Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

[HLSL] Update Sema Checking Diagnostics for builtins #138429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
Loading
from

Conversation

spall
Copy link
Contributor

@spall spall commented May 4, 2025

Update how Sema Checking is done for HLSL builtins to allow for better error messages, mainly using 'err_builtin_invalid_arg_type'.
Try to follow the formula outlined in issue #134721
Closes #134721

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels May 4, 2025
@llvmbot
Copy link
Member

llvmbot commented May 4, 2025

@llvm/pr-subscribers-hlsl

Author: Sarah Spall (spall)

Changes

Update how Sema Checking is done for HLSL builtins to allow for better error messages, mainly using 'err_builtin_invalid_arg_type'.
Try to follow the formula outlined in issue #134721
Closes #134721


Patch is 61.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138429.diff

21 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+1-1)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+131-185)
  • (modified) clang/test/SemaHLSL/BuiltIns/AddUint64-errors.hlsl (+3-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl (+10)
  • (modified) clang/test/SemaHLSL/BuiltIns/clamp-errors.hlsl (+16-11)
  • (modified) clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl (+9-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/degrees-errors.hlsl (+3-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl (+17-16)
  • (modified) clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/half-float-only-errors.hlsl (+2-2)
  • (modified) clang/test/SemaHLSL/BuiltIns/half-float-only-errors2.hlsl (+2-2)
  • (modified) clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl (+6-6)
  • (modified) clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl (+10-10)
  • (modified) clang/test/SemaHLSL/BuiltIns/logical-operator-errors.hlsl (+6)
  • (modified) clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl (+11-11)
  • (modified) clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/radians-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/rcp-errors.hlsl (+4-5)
  • (modified) clang/test/SemaHLSL/BuiltIns/reversebits-errors.hlsl (+1-1)
  • (modified) clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl (+5-6)
  • (modified) clang/test/SemaHLSL/BuiltIns/step-errors.hlsl (+4-4)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index ccb14e9927adf..c94d34e0259be 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12649,7 +12649,7 @@ def err_builtin_invalid_arg_type: Error<
   // An 'or' if non-empty second and third components are combined
   "%plural{0:|:%plural{0:|:or }2}3"
   // Third component: floating-point types
-  "%select{|floating-point}3"
+  "%select{|floating-point|16 or 32 bit floating-point}3"
   // A space after a non-empty third component
   "%plural{0:|: }3"
   "%plural{[0,3]:type|:types}1 (was %4)">;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 70aacaa2aadbe..6486ef765f32e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2005,68 +2005,6 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
   DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
 }
 
-// Helper function for CheckHLSLBuiltinFunctionCall
-static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
-  assert(TheCall->getNumArgs() > 1);
-  ExprResult A = TheCall->getArg(0);
-
-  QualType ArgTyA = A.get()->getType();
-
-  auto *VecTyA = ArgTyA->getAs<VectorType>();
-  SourceLocation BuiltinLoc = TheCall->getBeginLoc();
-
-  bool AllBArgAreVectors = true;
-  for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
-    ExprResult B = TheCall->getArg(i);
-    QualType ArgTyB = B.get()->getType();
-    auto *VecTyB = ArgTyB->getAs<VectorType>();
-    if (VecTyB == nullptr)
-      AllBArgAreVectors &= false;
-    if (VecTyA && VecTyB == nullptr) {
-      // Note: if we get here 'B' is scalar which
-      // requires a VectorSplat on ArgN
-      S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
-          << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-          << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-      return true;
-    }
-    if (VecTyA && VecTyB) {
-      bool retValue = false;
-      if (!S->Context.hasSameUnqualifiedType(VecTyA->getElementType(),
-                                             VecTyB->getElementType())) {
-        // Note: type promotion is intended to be handeled via the intrinsics
-        //  and not the builtin itself.
-        S->Diag(TheCall->getBeginLoc(),
-                diag::err_vec_builtin_incompatible_vector)
-            << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-            << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-        retValue = true;
-      }
-      if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
-        // You should only be hitting this case if you are calling the builtin
-        // directly. HLSL intrinsics should avoid this case via a
-        // HLSLVectorTruncation.
-        S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
-            << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-            << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-        retValue = true;
-      }
-      if (retValue)
-        return retValue;
-    }
-  }
-
-  if (VecTyA == nullptr && AllBArgAreVectors) {
-    // Note: if we get here 'A' is a scalar which
-    // requires a VectorSplat on Arg0
-    S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
-        << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-        << SourceRange(A.get()->getBeginLoc(), A.get()->getEndLoc());
-    return true;
-  }
-  return false;
-}
-
 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() > 1);
   QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -2094,63 +2032,46 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-static bool CheckArgTypeIsCorrect(
-    Sema *S, Expr *Arg, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  QualType PassedType = Arg->getType();
-  if (Check(PassedType)) {
-    if (auto *VecTyA = PassedType->getAs<VectorType>())
-      ExpectedType = S->Context.getVectorType(
-          ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
-    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
-        << PassedType << ExpectedType << 1 << 0 << 0;
-    return true;
-  }
-  return false;
-}
-
 static bool CheckAllArgTypesAreCorrect(
-    Sema *S, CallExpr *TheCall, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
-    Expr *Arg = TheCall->getArg(i);
-    if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                            clang::QualType PassedType)>
+        Check) {
+  for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
+    Expr *Arg = TheCall->getArg(I);
+    if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
       return true;
-    }
   }
   return false;
 }
 
-static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasFloatingRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkAllFloatTypes);
-}
+static bool CheckFloatOrHalfVecRepresentation(Sema *S, SourceLocation Loc,
+                                              int ArgOrdinal,
+                                              clang::QualType PassedType) {
+  QualType EltTy = PassedType;
+  if (auto *VecTy = EltTy->getAs<VectorType>())
+    EltTy = VecTy->getElementType();
 
-static bool CheckUnsignedIntRepresentations(Sema *S, CallExpr *TheCall) {
-  auto checkUnsignedInteger = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !BaseType->isUnsignedIntegerType();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                    checkUnsignedInteger);
+  if (!PassedType->getAs<VectorType>() ||
+      !(EltTy->isHalfType() || EltTy->isFloat32Type()))
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
 }
 
-static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
-  auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !BaseType->isHalfType() && !BaseType->isFloat32Type();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkFloatorHalf);
+static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->getAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
 }
 
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
@@ -2164,30 +2085,49 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
   return true;
 }
 
-static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
-  auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
-    if (const auto *VecTy = PassedType->getAs<VectorType>())
-      return VecTy->getElementType()->isDoubleType();
-    return false;
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkDoubleVector);
+static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                                 clang::QualType PassedType) {
+  if (const auto *VecTy = PassedType->getAs<VectorType>())
+    if (VecTy->getElementType()->isDoubleType())
+      return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
+             << PassedType;
+  return false;
 }
-static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasIntegerRepresentation() &&
-           !PassedType->hasFloatingRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
-                                    checkAllSignedTypes);
+
+static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
+                                             int ArgOrdinal,
+                                             clang::QualType PassedType) {
+  if (!PassedType->hasIntegerRepresentation() &&
+      !PassedType->hasFloatingRepresentation())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
+           << /* fp */ 1 << PassedType;
+  return false;
 }
 
-static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasUnsignedIntegerRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                    checkAllUnsignedTypes);
+static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
+                                              int ArgOrdinal,
+                                              clang::QualType PassedType) {
+  QualType EltTy = PassedType;
+  if (auto *VecTy = EltTy->getAs<VectorType>())
+    EltTy = VecTy->getElementType();
+
+  if (!PassedType->getAs<VectorType>() || !EltTy->isUnsignedIntegerType())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
+           << PassedType;
+  return false;
+}
+
+static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  if (!PassedType->hasUnsignedIntegerRepresentation())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
+           << /* no fp */ 0 << PassedType;
+  return false;
 }
 
 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
@@ -2343,23 +2283,12 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_adduint64: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
-    if (CheckUnsignedIntRepresentations(&SemaRef, TheCall))
-      return true;
-
-    // CheckVectorElementCallArgs(...) guarantees both args are the same type.
-    assert(TheCall->getArg(0)->getType() == TheCall->getArg(1)->getType() &&
-           "Both args must be of the same type");
 
-    // ensure both args are vectors
-    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
-    if (!VTy) {
-      SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_non_vector)
-          << TheCall->getDirectCallee() << /*all*/ 1;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckUnsignedIntVecRepresentation))
       return true;
-    }
 
+    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
     // ensure arg integers are 32-bits
     uint64_t ElementBitCount = getASTContext()
                                    .getTypeSizeInChars(VTy->getElementType())
@@ -2380,6 +2309,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     }
 
+    // ensure first arg and second arg have the same type
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
+      return true;
+
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
     // return type is the same as the input type
@@ -2431,10 +2364,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_or: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
       return true;
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
+      return true;
 
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
@@ -2446,37 +2379,41 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_any: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+      return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_asdouble: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            0))
+      return true;
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            1))
+      return true;
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
 
     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
-        CheckAllArgsHaveSameType(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
             TheCall, /*ArgTyRestr=*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
-                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
-                : Sema::EltwiseBuiltinArgTyRestriction::None))
+            Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_cross: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
+
+    // ensure args are a half3 or float3
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfVecRepresentation))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
     // ensure both args have 3 elements
     int NumElementsArg1 =
@@ -2507,13 +2444,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_dot: {
-    if (SemaRef.checkArgCount(TheCall, 2))
-      return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinVectorToScalarMath(TheCall))
       return true;
-    if (CheckNoDoubleVectors(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, CheckNoDoubleVectors))
       return true;
     break;
   }
@@ -2560,8 +2493,15 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   }
   case Builtin::BI__builtin_hlsl_elementwise_saturate:
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
-    if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    if (!TheCall->getArg(0)
+             ->getType()
+             ->hasFloatingRepresentation()) // half or float or double
+      return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+                          diag::err_builtin_invalid_arg_type)
+             << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
+             << /* fp */ 1 << TheCall->getArg(0)->getType();
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
     break;
@@ -2570,14 +2510,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_elementwise_radians:
   case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_isinf: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
@@ -2587,34 +2533,28 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_lerp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
-    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
-        CheckAllArgsHaveSameType(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
-    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_mad: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
             TheCall, /*ArgTyRestr=*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
-                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
-                : Sema::EltwiseBuiltinArgTyRestriction::None))
+            Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_normalize: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
-      return true;
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
-
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
+      return true;
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
     // return type is the same as the input type
@@ -2622,17 +2562,19 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
-    if (CheckFloatingOrIntRepresentation(&SemaRef, TheCall))
-      return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatingOrIntRepresentation))
+      return true;
     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_step: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 4, 2025

@llvm/pr-subscribers-clang

Author: Sarah Spall (spall)

Changes

Update how Sema Checking is done for HLSL builtins to allow for better error messages, mainly using 'err_builtin_invalid_arg_type'.
Try to follow the formula outlined in issue #134721
Closes #134721


Patch is 61.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138429.diff

21 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+1-1)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+131-185)
  • (modified) clang/test/SemaHLSL/BuiltIns/AddUint64-errors.hlsl (+3-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl (+10)
  • (modified) clang/test/SemaHLSL/BuiltIns/clamp-errors.hlsl (+16-11)
  • (modified) clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl (+9-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/degrees-errors.hlsl (+3-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl (+17-16)
  • (modified) clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/half-float-only-errors.hlsl (+2-2)
  • (modified) clang/test/SemaHLSL/BuiltIns/half-float-only-errors2.hlsl (+2-2)
  • (modified) clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl (+6-6)
  • (modified) clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl (+10-10)
  • (modified) clang/test/SemaHLSL/BuiltIns/logical-operator-errors.hlsl (+6)
  • (modified) clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl (+11-11)
  • (modified) clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/radians-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/rcp-errors.hlsl (+4-5)
  • (modified) clang/test/SemaHLSL/BuiltIns/reversebits-errors.hlsl (+1-1)
  • (modified) clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl (+5-6)
  • (modified) clang/test/SemaHLSL/BuiltIns/step-errors.hlsl (+4-4)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index ccb14e9927adf..c94d34e0259be 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12649,7 +12649,7 @@ def err_builtin_invalid_arg_type: Error<
   // An 'or' if non-empty second and third components are combined
   "%plural{0:|:%plural{0:|:or }2}3"
   // Third component: floating-point types
-  "%select{|floating-point}3"
+  "%select{|floating-point|16 or 32 bit floating-point}3"
   // A space after a non-empty third component
   "%plural{0:|: }3"
   "%plural{[0,3]:type|:types}1 (was %4)">;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 70aacaa2aadbe..6486ef765f32e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2005,68 +2005,6 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
   DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
 }
 
-// Helper function for CheckHLSLBuiltinFunctionCall
-static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
-  assert(TheCall->getNumArgs() > 1);
-  ExprResult A = TheCall->getArg(0);
-
-  QualType ArgTyA = A.get()->getType();
-
-  auto *VecTyA = ArgTyA->getAs<VectorType>();
-  SourceLocation BuiltinLoc = TheCall->getBeginLoc();
-
-  bool AllBArgAreVectors = true;
-  for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
-    ExprResult B = TheCall->getArg(i);
-    QualType ArgTyB = B.get()->getType();
-    auto *VecTyB = ArgTyB->getAs<VectorType>();
-    if (VecTyB == nullptr)
-      AllBArgAreVectors &= false;
-    if (VecTyA && VecTyB == nullptr) {
-      // Note: if we get here 'B' is scalar which
-      // requires a VectorSplat on ArgN
-      S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
-          << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-          << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-      return true;
-    }
-    if (VecTyA && VecTyB) {
-      bool retValue = false;
-      if (!S->Context.hasSameUnqualifiedType(VecTyA->getElementType(),
-                                             VecTyB->getElementType())) {
-        // Note: type promotion is intended to be handeled via the intrinsics
-        //  and not the builtin itself.
-        S->Diag(TheCall->getBeginLoc(),
-                diag::err_vec_builtin_incompatible_vector)
-            << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-            << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-        retValue = true;
-      }
-      if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
-        // You should only be hitting this case if you are calling the builtin
-        // directly. HLSL intrinsics should avoid this case via a
-        // HLSLVectorTruncation.
-        S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
-            << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-            << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-        retValue = true;
-      }
-      if (retValue)
-        return retValue;
-    }
-  }
-
-  if (VecTyA == nullptr && AllBArgAreVectors) {
-    // Note: if we get here 'A' is a scalar which
-    // requires a VectorSplat on Arg0
-    S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
-        << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-        << SourceRange(A.get()->getBeginLoc(), A.get()->getEndLoc());
-    return true;
-  }
-  return false;
-}
-
 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() > 1);
   QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -2094,63 +2032,46 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-static bool CheckArgTypeIsCorrect(
-    Sema *S, Expr *Arg, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  QualType PassedType = Arg->getType();
-  if (Check(PassedType)) {
-    if (auto *VecTyA = PassedType->getAs<VectorType>())
-      ExpectedType = S->Context.getVectorType(
-          ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
-    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
-        << PassedType << ExpectedType << 1 << 0 << 0;
-    return true;
-  }
-  return false;
-}
-
 static bool CheckAllArgTypesAreCorrect(
-    Sema *S, CallExpr *TheCall, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
-    Expr *Arg = TheCall->getArg(i);
-    if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                            clang::QualType PassedType)>
+        Check) {
+  for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
+    Expr *Arg = TheCall->getArg(I);
+    if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
       return true;
-    }
   }
   return false;
 }
 
-static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasFloatingRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkAllFloatTypes);
-}
+static bool CheckFloatOrHalfVecRepresentation(Sema *S, SourceLocation Loc,
+                                              int ArgOrdinal,
+                                              clang::QualType PassedType) {
+  QualType EltTy = PassedType;
+  if (auto *VecTy = EltTy->getAs<VectorType>())
+    EltTy = VecTy->getElementType();
 
-static bool CheckUnsignedIntRepresentations(Sema *S, CallExpr *TheCall) {
-  auto checkUnsignedInteger = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !BaseType->isUnsignedIntegerType();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                    checkUnsignedInteger);
+  if (!PassedType->getAs<VectorType>() ||
+      !(EltTy->isHalfType() || EltTy->isFloat32Type()))
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
 }
 
-static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
-  auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !BaseType->isHalfType() && !BaseType->isFloat32Type();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkFloatorHalf);
+static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->getAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
 }
 
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
@@ -2164,30 +2085,49 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
   return true;
 }
 
-static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
-  auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
-    if (const auto *VecTy = PassedType->getAs<VectorType>())
-      return VecTy->getElementType()->isDoubleType();
-    return false;
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkDoubleVector);
+static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                                 clang::QualType PassedType) {
+  if (const auto *VecTy = PassedType->getAs<VectorType>())
+    if (VecTy->getElementType()->isDoubleType())
+      return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
+             << PassedType;
+  return false;
 }
-static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasIntegerRepresentation() &&
-           !PassedType->hasFloatingRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
-                                    checkAllSignedTypes);
+
+static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
+                                             int ArgOrdinal,
+                                             clang::QualType PassedType) {
+  if (!PassedType->hasIntegerRepresentation() &&
+      !PassedType->hasFloatingRepresentation())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
+           << /* fp */ 1 << PassedType;
+  return false;
 }
 
-static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasUnsignedIntegerRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                    checkAllUnsignedTypes);
+static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
+                                              int ArgOrdinal,
+                                              clang::QualType PassedType) {
+  QualType EltTy = PassedType;
+  if (auto *VecTy = EltTy->getAs<VectorType>())
+    EltTy = VecTy->getElementType();
+
+  if (!PassedType->getAs<VectorType>() || !EltTy->isUnsignedIntegerType())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
+           << PassedType;
+  return false;
+}
+
+static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  if (!PassedType->hasUnsignedIntegerRepresentation())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
+           << /* no fp */ 0 << PassedType;
+  return false;
 }
 
 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
@@ -2343,23 +2283,12 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_adduint64: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
-    if (CheckUnsignedIntRepresentations(&SemaRef, TheCall))
-      return true;
-
-    // CheckVectorElementCallArgs(...) guarantees both args are the same type.
-    assert(TheCall->getArg(0)->getType() == TheCall->getArg(1)->getType() &&
-           "Both args must be of the same type");
 
-    // ensure both args are vectors
-    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
-    if (!VTy) {
-      SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_non_vector)
-          << TheCall->getDirectCallee() << /*all*/ 1;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckUnsignedIntVecRepresentation))
       return true;
-    }
 
+    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
     // ensure arg integers are 32-bits
     uint64_t ElementBitCount = getASTContext()
                                    .getTypeSizeInChars(VTy->getElementType())
@@ -2380,6 +2309,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     }
 
+    // ensure first arg and second arg have the same type
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
+      return true;
+
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
     // return type is the same as the input type
@@ -2431,10 +2364,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_or: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
       return true;
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
+      return true;
 
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
@@ -2446,37 +2379,41 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_any: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+      return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_asdouble: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            0))
+      return true;
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            1))
+      return true;
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
 
     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
-        CheckAllArgsHaveSameType(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
             TheCall, /*ArgTyRestr=*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
-                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
-                : Sema::EltwiseBuiltinArgTyRestriction::None))
+            Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_cross: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
+
+    // ensure args are a half3 or float3
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfVecRepresentation))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
     // ensure both args have 3 elements
     int NumElementsArg1 =
@@ -2507,13 +2444,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_dot: {
-    if (SemaRef.checkArgCount(TheCall, 2))
-      return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinVectorToScalarMath(TheCall))
       return true;
-    if (CheckNoDoubleVectors(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, CheckNoDoubleVectors))
       return true;
     break;
   }
@@ -2560,8 +2493,15 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   }
   case Builtin::BI__builtin_hlsl_elementwise_saturate:
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
-    if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    if (!TheCall->getArg(0)
+             ->getType()
+             ->hasFloatingRepresentation()) // half or float or double
+      return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+                          diag::err_builtin_invalid_arg_type)
+             << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
+             << /* fp */ 1 << TheCall->getArg(0)->getType();
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
     break;
@@ -2570,14 +2510,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_elementwise_radians:
   case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_isinf: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
@@ -2587,34 +2533,28 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_lerp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
-    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
-        CheckAllArgsHaveSameType(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
-    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_mad: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
             TheCall, /*ArgTyRestr=*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
-                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
-                : Sema::EltwiseBuiltinArgTyRestriction::None))
+            Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_normalize: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
-      return true;
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
-
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
+      return true;
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
     // return type is the same as the input type
@@ -2622,17 +2562,19 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
-    if (CheckFloatingOrIntRepresentation(&SemaRef, TheCall))
-      return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatingOrIntRepresentation))
+      return true;
     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_step: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall...
[truncated]

Copy link
Member

@hekota hekota left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! One suggestion and a one question why the diag message uses has double when the argument is float.

@@ -42,47 +42,47 @@ float2 test_mad_element_type_mismatch(half2 p0, float2 p1) {

float2 test_builtin_mad_float2_splat(float p0, float2 p1) {
return __builtin_hlsl_mad(p0, p1, p1);
// expected-error@-1 {{all arguments to '__builtin_hlsl_mad' must be vectors}}
// expected-error@-1 {{arguments are of different types ('double' vs 'float2' (aka 'vector<float, 2>'))}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is reporting double in the error message while the function argument is float, and I see the same thing on lines 50, 55 and 85.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it's because float arguments to variadic functions are promoted to double.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should make __builtin_hlsl_mad CustomTypeChecking. I changed a bunch of these that were just float\half definitions. I left the ones that had ints because I wasn't as sure what to do, but this makes it more clear.

clang/lib/Sema/SemaHLSL.cpp Outdated Show resolved Hide resolved
Comment on lines +2086 to 2094
static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
clang::QualType PassedType) {
if (const auto *VecTy = PassedType->getAs<VectorType>())
if (VecTy->getElementType()->isDoubleType())
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
<< PassedType;
return false;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small Nit: I really like to try to avoid nesting if possible. seems like we might be able to if we do the negation checks first.

Suggested change
static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
clang::QualType PassedType) {
if (const auto *VecTy = PassedType->getAs<VectorType>())
if (VecTy->getElementType()->isDoubleType())
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
<< PassedType;
return false;
}
static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
clang::QualType PassedType) {
const auto *VecTy = PassedType->getAs<VectorType>();
if (!VecTy)
return false;
if (!VecTy->getElementType()->isDoubleType())
return false;
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
<< PassedType;
}

break;
}
case Builtin::BI__builtin_hlsl_asdouble: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
0)) // only check for uint
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add an inline arg comment so we can distinguish 0 from 1 args on lines 2387 and 2389

Comment on lines +2389 to +2390
if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
1)) // only check for uint
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
1)) // only check for uint
if (CheckScalarOrVector(&SemaRef, TheCall,
/*only check for uint*/ SemaRef.Context.UnsignedIntTy,
1))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or remove the comment about line 2386 so you aren't duplicating the comment on line 2387 and 2390

return true;
break;
}
case Builtin::BI__builtin_hlsl_cross: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
if (CheckVectorElementCallArgs(&SemaRef, TheCall))

// ensure args are a half3 or float3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment doesn not seem to match what CheckAllArgTypesAreCorrect does since there is no specific vec3 checking. Now that I am seeing this. I don't think cross product should be handled in sema since we know it is only valid for vec3s we should probably just create a typed builtin like so

def HLSLCrossHalf: Builtin { 
  let Spellings = ["__builtin_hlsl_crossf16"];
  let Attributes = [NoThrow, Const];
   let Prototype = "_ExtVector<3, _Float16>(_ExtVector<3, _Float16>, _ExtVector<3, _Float16>)";
}
}

def HLSLCrossFloat: Builtin { 
  let Spellings = ["__builtin_hlsl_crossf];
  let Attributes = [NoThrow, Const];
  let Prototype = "_ExtVector<3, float>(_ExtVector<3, float>, _ExtVector<3, float>)";
}

or

class HLSLFPMathTemplate : Template<["_Float16", "float"],  ["f16",    "f"]>;

def HLSLCross : Builtin, HLSLFPMathTemplate {
  let Spellings = ["__builtin_hlsl_cross"];
  let Attributes = [FunctionWithBuiltinPrefix, NoThrow,Const];
  let Prototype = "_ExtVector<3, T>(_ExtVector<3, T>, _ExtVector<3, T>)";
}
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this is a weird one to typecheck. The comment was really meant to just refer to that entire block of code below it, rather than just the CheckAllArgTypesAreCorrect.

@@ -2507,13 +2441,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
break;
}
case Builtin::BI__builtin_hlsl_dot: {
if (SemaRef.checkArgCount(TheCall, 2))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove this? Since you are leaving it for many of the other builtins this one now looks like the odd man out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BuiltinVectorToScalarMath checks that the arg count is 2. I removed checkArgCount checks in places where the next thing checked was something like BuiltinVectorMath, which checks arg count.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[HLSL] Update HLSL's Sema Checking and Diagnostics for builtins
7 participants
Morty Proxy This is a proxified and sanitized view of the page, visit original site.