From 34a11ca5577b49d74daa67ed40b866cefed8b45c Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Tue, 27 May 2025 19:10:48 -0700 Subject: [PATCH 01/13] Provide supports for decomposable structs --- clang/lib/Sema/SemaSYCL.cpp | 206 +++++++++++------- .../experimental/free_function_traits.hpp | 10 + sycl/include/sycl/handler.hpp | 4 +- 3 files changed, 146 insertions(+), 74 deletions(-) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 9a30b3e693ec2..7c2f2762e2133 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -8,6 +8,7 @@ // This implements Semantic Analysis for SYCL constructs. //===----------------------------------------------------------------------===// +#include #include "clang/Sema/SemaSYCL.h" #include "TreeTransform.h" #include "clang/AST/AST.h" @@ -1387,13 +1388,13 @@ class KernelObjVisitor { template void visitComplexRecord(const CXXRecordDecl *Owner, ParentTy &Parent, const CXXRecordDecl *Wrapper, QualType RecordTy, - HandlerTys &... Handlers) { + HandlerTys &...Handlers) { (void)std::initializer_list{ (Handlers.enterStruct(Owner, Parent, RecordTy), 0)...}; VisitRecordHelper(Wrapper, Wrapper->bases(), Handlers...); - VisitRecordHelper(Wrapper, Wrapper->fields(), Handlers...); - (void)std::initializer_list{ - (Handlers.leaveStruct(Owner, Parent, RecordTy), 0)...}; + VisitRecordHelper(Wrapper, Wrapper->fields(), Handlers...), + (void)std::initializer_list{ + (Handlers.leaveStruct(Owner, Parent, RecordTy), 0)...}; } template @@ -1499,7 +1500,9 @@ class KernelObjVisitor { void visitField(const CXXRecordDecl *Owner, FieldDecl *Field, QualType FieldTy, HandlerTys &... 
Handlers) { if (isSyclSpecialType(FieldTy, SemaSYCLRef)) + { FieldTy->dump(); KF_FOR_EACH(handleSyclSpecialType, Field, FieldTy); +} else if (FieldTy->isStructureOrClassType()) { if (KF_FOR_EACH(handleStructType, Field, FieldTy)) { CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); @@ -1526,9 +1529,12 @@ class KernelObjVisitor { void visitParam(ParmVarDecl *Param, QualType ParamTy, HandlerTys &...Handlers) { if (isSyclSpecialType(ParamTy, SemaSYCLRef)) + {ParamTy->dump(); KP_FOR_EACH(handleSyclSpecialType, Param, ParamTy); +} else if (ParamTy->isStructureOrClassType()) { if (KP_FOR_EACH(handleStructType, Param, ParamTy)) { + ParamTy->dump(); CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl(); visitRecord(nullptr, Param, RD, ParamTy, Handlers...); } @@ -1607,8 +1613,12 @@ class KernelObjVisitor { template void VisitFunctionParameters(FunctionDecl *FreeFunc, HandlerTys &...Handlers) { - for (ParmVarDecl *Param : FreeFunc->parameters()) + for (ParmVarDecl *Param : FreeFunc->parameters()) { +std::cout << "starting!" << std::endl; +Param->getType()->dump(); visitParam(Param, Param->getType(), Handlers...); +std::cout << "ending!" << std::endl; +} } #undef KF_FOR_EACH @@ -1731,10 +1741,6 @@ class SyclKernelFieldHandlerBase { virtual ~SyclKernelFieldHandlerBase() = default; }; - -// A class to act as the direct base for all the SYCL OpenCL Kernel construction -// tasks that contains a reference to Sema (and potentially any other -// universally required data). class SyclKernelFieldHandler : public SyclKernelFieldHandlerBase { protected: SemaSYCL &SemaSYCLRef; @@ -1818,7 +1824,11 @@ void KernelObjVisitor::visitRecord(const CXXRecordDecl *Owner, ParentTy &Parent, // If this container requires decomposition, we have to visit it as // 'complex', so all handlers are called in this case with the 'complex' // case. 
+ //RecordTy->dump(); visitComplexRecord(Owner, Parent, Wrapper, RecordTy, Handlers...); + // 'complex', so all handlers are called in this case with the 'complex' + // case. + //RecordTy->dump(); } else if (AnyTrue:: Value) { // We are currently in PointerHandler visitor. @@ -2141,31 +2151,14 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { } bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO manipulate struct depth once special types are supported for free - // function kernels. - // ++StructFieldDepth; + ++StructFieldDepth; return true; } bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType ParamTy) final { - // TODO manipulate struct depth once special types are supported for free - // function kernels. - // --StructFieldDepth; - // TODO We don't yet support special types and therefore structs that - // require decomposition and leaving/entering. Diagnose for better user - // experience. - CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl(); - if (RD->hasAttr()) { - Diag.Report(PD->getLocation(), - diag::err_bad_kernel_param_type) - << ParamTy; - Diag.Report(PD->getLocation(), - diag::note_free_function_kernel_param_type_not_supported) - << ParamTy; - IsInvalid = true; - } - return isValid(); + --StructFieldDepth; + return true; } bool enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &BS, @@ -2269,8 +2262,6 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler { } bool handleSyclSpecialType(ParmVarDecl *, QualType) final { - // TODO We don't support special types in free function kernel parameters, - // but track them to diagnose the case properly. 
CollectionStack.back() = true; return true; } @@ -2542,7 +2533,6 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType ParamTy) final { // TODO - unsupportedFreeFunctionParamType(); return true; } @@ -2563,7 +2553,6 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType ParamTy) final { // TODO - unsupportedFreeFunctionParamType(); return true; } @@ -2660,7 +2649,6 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType ParamTy) final { // TODO - unsupportedFreeFunctionParamType(); return true; } @@ -2694,6 +2682,9 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { class SyclKernelDeclCreator : public SyclKernelFieldHandler { FunctionDecl *KernelDecl = nullptr; llvm::SmallVector Params; + // Holds the last handled kernel struct parameter that contains a special type. + // Set in the enterStruct functions. + ParmVarDecl * CurrentStruct; Sema::ContextRAII FuncContext; // Holds the last handled field's first parameter. This doesn't store an // iterator as push_back invalidates iterators. @@ -2711,6 +2702,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { addParam(newParamDesc, ParamTy); } + void addParam(const CXXBaseSpecifier &BS, QualType FieldTy) { // TODO: There is no name for the base available, but duplicate names are // seemingly already possible, so we'll give them all the same name for now. @@ -2798,7 +2790,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { SourceLocation Loc) { handleAccessorPropertyList(Params.back(), RecordDecl, Loc); - // If "accessor" type check if read only + // If "accessor" type check if read only if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::accessor)) { // Get access mode of accessor. 
const auto *AccessorSpecializationDecl = @@ -2824,6 +2816,8 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { // lambda kernel by taking the value ParmVarDecl or FieldDecl respectively. template bool handleSpecialType(ParentDecl *decl, QualType Ty) { +std::cout << "Important one!" << std::endl; +Ty->dump(); const auto *RD = Ty->getAsCXXRecordDecl(); assert(RD && "The type must be a RecordDecl"); llvm::StringLiteral MethodName = @@ -2837,7 +2831,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { // (if any). size_t ParamIndex = Params.size(); for (const ParmVarDecl *Param : InitMethod->parameters()) { - QualType ParamTy = Param->getType(); + QualType ParamTy = Param->getType(); // For lambda kernels the arguments to the OpenCL kernel are named // based on the position they have as fields in the definition of the // special type structure i.e __arg_field1, __arg_field2 and so on. @@ -2863,6 +2857,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { handleAccessorType(Ty, RD, decl->getBeginLoc()); } LastParamIndex = ParamIndex; + std::cout << LastParamIndex << std::endl; return true; } @@ -2956,6 +2951,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { SYCLKernelAttr::CreateImplicit(SemaSYCLRef.getASTContext())); SemaSYCLRef.addSyclDeviceDecl(KernelDecl); + //KernelDecl->dump(); } bool enterStruct(const CXXRecordDecl *, FieldDecl *, QualType) final { @@ -2963,9 +2959,11 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return true; } - bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO - // ++StructDepth; + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType Ty) final { + ++StructDepth; + //StringRef Name = "_arg_struct"; + //addParam(Name, Ty); + //CurrentStruct = Params.back(); return true; } @@ -2975,8 +2973,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { } bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final 
{ - // TODO - // --StructDepth; + --StructDepth; return true; } @@ -2992,6 +2989,15 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return true; } + bool handleStructType(ParmVarDecl *PD, QualType Ty) final { + StringRef Name = "_arg_struct"; + addParam(Name, Ty); + CurrentStruct = Params.back(); + return true; + } + + bool handleStructType(FieldDecl *, QualType) final { return true; } + bool handleSyclSpecialType(const CXXRecordDecl *, const CXXBaseSpecifier &BS, QualType FieldTy) final { const auto *RecordDecl = FieldTy->getAsCXXRecordDecl(); @@ -3166,6 +3172,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return ArrayRef(std::begin(Params) + LastParamIndex, std::end(Params)); } + ParmVarDecl *getParentStructForCurrentField() { return CurrentStruct; } }; // This Visitor traverses the AST of the function with @@ -3619,8 +3626,11 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { SourceLocation LL = NewBody ? NewBody->getBeginLoc() : SourceLocation(); SourceLocation LR = NewBody ? NewBody->getEndLoc() : SourceLocation(); - return CompoundStmt::Create(SemaSYCLRef.getASTContext(), BodyStmts, + CompoundStmt::Create(SemaSYCLRef.getASTContext(), BodyStmts, + FPOptionsOverride(), LL, LR)->dumpPretty(SemaSYCLRef.getASTContext()); +return CompoundStmt::Create(SemaSYCLRef.getASTContext(), BodyStmts, FPOptionsOverride(), LL, LR); + } void annotateHierarchicalParallelismAPICalls() { @@ -4342,16 +4352,14 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { SyclKernelDeclCreator &DeclCreator; llvm::SmallVector BodyStmts; + llvm::SmallVector CurrentStructs; FunctionDecl *FreeFunc = nullptr; SourceLocation FreeFunctionSrcLoc; // Free function source location. llvm::SmallVector ArgExprs; - // Creates a DeclRefExpr to the ParmVar that represents the current free - // function parameter. 
- Expr *createParamReferenceExpr() { - ParmVarDecl *FreeFunctionParameter = - DeclCreator.getParamVarDeclsForCurrentField()[0]; - + // Creates a DeclRefExpr to the ParmVar that represents an arbitrary + // free function parameter + Expr *createParamReferenceExpr(ParmVarDecl *FreeFunctionParameter) { QualType FreeFunctionParamType = FreeFunctionParameter->getOriginalType(); Expr *DRE = SemaSYCLRef.SemaRef.BuildDeclRefExpr( FreeFunctionParameter, FreeFunctionParamType, VK_LValue, @@ -4360,6 +4368,14 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { return DRE; } + // Creates a DeclRefExpr to the ParmVar that represents the current free + // function parameter. + Expr *createParamReferenceExpr() { + ParmVarDecl *FreeFunctionParameter = + DeclCreator.getParamVarDeclsForCurrentField()[0]; + return createParamReferenceExpr(FreeFunctionParameter); + } + // Creates a DeclRefExpr to the ParmVar that represents the current pointer // parameter. Expr *createPointerParamReferenceExpr(QualType PointerTy) { @@ -4416,6 +4432,7 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { DRE = createReinterpretCastExpr( createGetAddressOf(DRE), SemaSYCLRef.getASTContext().getPointerType( OrigFunctionParameter->getType())); + DRE = createDerefOp(DRE); } @@ -4450,8 +4467,12 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { auto CallExpr = CallExpr::Create(Context, Fn, ArgExprs, ResultTy, VK, FreeFunctionSrcLoc, FPOptionsOverride()); BodyStmts.push_back(CallExpr); +CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), {}, + {})->dumpPretty(Context); + return CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), {}, {}); + } MemberExpr *buildMemberExpr(Expr *Base, ValueDecl *Member) { @@ -4468,15 +4489,17 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { void createSpecialMethodCall(const CXXRecordDecl *RD, StringRef MethodName, Expr *MemberBaseExpr, SmallVectorImpl &AddTo) { - 
CXXMethodDecl *Method = getMethodByName(RD, MethodName); +CXXMethodDecl *Method = getMethodByName(RD, MethodName); if (!Method) return; unsigned NumParams = Method->getNumParams(); llvm::SmallVector ParamDREs(NumParams); llvm::ArrayRef KernelParameters = DeclCreator.getParamVarDeclsForCurrentField(); + //std::cout << KernelParameters.size() << std::endl; for (size_t I = 0; I < NumParams; ++I) { QualType ParamType = KernelParameters[I]->getOriginalType(); + //ParamType->dump(); ParamDREs[I] = SemaSYCLRef.SemaRef.BuildDeclRefExpr( KernelParameters[I], ParamType, VK_LValue, FreeFunctionSrcLoc); } @@ -4495,7 +4518,7 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { public: static constexpr const bool VisitInsideSimpleContainers = false; - + FreeFunctionKernelBodyCreator(SemaSYCL &S, SyclKernelDeclCreator &DC, FunctionDecl *FF) : SyclKernelFieldHandler(S), DeclCreator(DC), FreeFunc(FF), @@ -4506,9 +4529,20 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { DeclCreator.setBody(KernelBody); } - bool handleSyclSpecialType(FieldDecl *FD, QualType Ty) final { - // TODO - unsupportedFreeFunctionParamType(); + bool handleSyclSpecialType(FieldDecl *FD, QualType FieldTy) final { + // Being inside this function means there is a struct parameter to the free + // function kernel that contains a special type. +std::cout << "Body!" 
<< std::endl; +FieldTy->dump(); + ParmVarDecl *ParentStruct = DeclCreator.getParentStructForCurrentField(); + // special_type_wrapper_map[ParentStruct->getType()] = true; + Expr *Base = createParamReferenceExpr(ParentStruct); + for (const auto &child : CurrentStructs) { + Base = buildMemberExpr(Base, child); + } + MemberExpr *MemberAccess = buildMemberExpr(Base, FD); + createSpecialMethodCall(FieldTy->getAsCXXRecordDecl(), InitMethodName, + MemberAccess, BodyStmts); return true; } @@ -4527,6 +4561,8 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { // wgm.__init(arg); // user_kernel(some arguments..., wgm, some arguments...); // } + std::cout << "Body!" << std::endl; + ParamTy->dump(); const auto *RecordDecl = ParamTy->getAsCXXRecordDecl(); AccessSpecifier DefaultConstructorAccess; auto DefaultConstructor = @@ -4559,8 +4595,8 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { BodyStmts.push_back(DS); Expr *MemberBaseExpr = SemaSYCLRef.SemaRef.BuildDeclRefExpr( SpecialObjectClone, ParamTy, VK_PRValue, FreeFunctionSrcLoc); - createSpecialMethodCall(RecordDecl, InitMethodName, MemberBaseExpr, - BodyStmts); + createSpecialMethodCall(RecordDecl, InitMethodName, MemberBaseExpr, + BodyStmts); ArgExprs.push_back(MemberBaseExpr); return true; } @@ -4636,26 +4672,24 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { } bool enterStruct(const CXXRecordDecl *RD, FieldDecl *FD, QualType Ty) final { - // TODO - unsupportedFreeFunctionParamType(); + CurrentStructs.push_back(FD); return true; } - bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + bool enterStruct(const CXXRecordDecl *RD, ParmVarDecl *PD, + QualType Ty) final { return true; } bool leaveStruct(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final { - // TODO - unsupportedFreeFunctionParamType(); + CurrentStructs.pop_back(); return true; } bool leaveStruct(const CXXRecordDecl *, 
ParmVarDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + ParmVarDecl *ParentStruct = DeclCreator.getParentStructForCurrentField(); + ArgExprs.push_back(SemaSYCLRef.SemaRef.BuildDeclRefExpr( + ParentStruct, ParentStruct->getType(), VK_PRValue, FreeFunctionSrcLoc)); return true; } @@ -4700,6 +4734,11 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { unsupportedFreeFunctionParamType(); return true; } + FieldDecl *getCurrentStruct() { + assert(CurrentStructs.size() && + "Current free function parameter is not inside a structure!"); + return CurrentStructs.back(); + } }; // Kernels are only the unnamed-lambda feature if the feature is enabled, AND @@ -4979,7 +5018,6 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { // TODO - unsupportedFreeFunctionParamType(); return true; } @@ -4991,7 +5029,6 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { // TODO - unsupportedFreeFunctionParamType(); return true; } @@ -5488,22 +5525,25 @@ void SemaSYCL::constructFreeFunctionKernel(FunctionDecl *FD, StringRef NameStr) { if (!checkAndAddRegisteredKernelName(*this, FD, NameStr)) return; - SyclKernelArgsSizeChecker argsSizeChecker(*this, FD->getLocation(), false /*IsSIMDKernel*/); SyclKernelDeclCreator kernel_decl(*this, FD->getLocation(), FD->isInlined(), false /*IsSIMDKernel */, FD); - FreeFunctionKernelBodyCreator kernel_body(*this, kernel_decl, FD); - SyclKernelIntHeaderCreator int_header(*this, getSyclIntegrationHeader(), FD->getType(), FD); - SyclKernelIntFooterCreator int_footer(*this, getSyclIntegrationFooter()); KernelObjVisitor Visitor{*this}; - Visitor.VisitFunctionParameters(FD, argsSizeChecker, kernel_decl, kernel_body, - int_header, int_footer); + Visitor.VisitFunctionParameters(FD, argsSizeChecker); + +Visitor.VisitFunctionParameters(FD, 
kernel_decl); + +Visitor.VisitFunctionParameters(FD, kernel_body); + +Visitor.VisitFunctionParameters(FD, int_header); + +Visitor.VisitFunctionParameters(FD, int_footer); assert(getKernelFDPairs().back().first == FD && "OpenCL Kernel not found for free function entry"); @@ -6984,6 +7024,26 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { } } ParmListWithNamesOstream.flush(); + for (ParmVarDecl *Param : K.SyclKernel->parameters()) { + // if (FreeFunctionKernelBodyCreator::isSpecialTypeWrapper( + // Param->getType())) { + // this is a struct that contains a special type so its neither a + // special type nor a trivially copyable type. We therefore need to + // explicitly communicate to the runtime that this argument should be + // allowed as a free function kernel argument. We do this by defining + // a certain trait recognized by the runtime to be true. + O << "template <>\n"; + O << "struct " + "sycl::ext::oneapi::experimental::detail::is_explicitly_allowed_" + "arg<"; + Policy.SuppressTagKeyword = true; + + Param->getType().print(O, Policy); + Policy.SuppressTagKeyword = false; + O << "> {\n"; + O << " static constexpr bool value = true;\n};\n"; + //} + } FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate(); Policy.PrintCanonicalTypes = false; Policy.SuppressDefinition = true; @@ -7720,7 +7780,7 @@ StmtResult SemaSYCL::BuildSYCLKernelCallStmt(FunctionDecl *FD, OutlinedFunctionDeclBodyInstantiator OFDBodyInstantiator(SemaRef, ParmMap); Stmt *OFDBody = OFDBodyInstantiator.TransformStmt(Body).get(); - OFD->setBody(OFDBody); +OFD->setBody(OFDBody); OFD->setNothrow(); Stmt *NewBody = new (getASTContext()) SYCLKernelCallStmt(Body, OFD); diff --git a/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp b/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp index 2b5d1f4190d21..0ca1c234c9070 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp +++ 
b/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp @@ -44,6 +44,16 @@ template struct is_kernel { template inline constexpr bool is_kernel_v = is_kernel::value; +namespace detail { +template struct is_explicitly_allowed_arg { + static constexpr bool value = false; +}; + +template +inline constexpr bool is_explicitly_allowed_arg_v = + is_explicitly_allowed_arg::value; + +} // namespace detail } // namespace ext::oneapi::experimental } // namespace _V1 } // namespace sycl diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index d7b304a130c83..54b093f05000a 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -1766,7 +1767,8 @@ class __SYCL_EXPORT handler { || (!is_same_type::value && std::is_pointer_v>) // USM || is_same_type::value // Interop - || is_same_type::value; // Stream + || is_same_type::value // Stream + || ext::oneapi::experimental::detail::is_explicitly_allowed_arg>::value; }; /// Sets argument for OpenCL interoperability kernels. From c9680a244bdd1f4e8fb979d2032131eba90c4aa3 Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Tue, 3 Jun 2025 11:46:24 -0700 Subject: [PATCH 02/13] Revert "Provide supports for decomposable structs" This reverts commit 34a11ca5577b49d74daa67ed40b866cefed8b45c. --- clang/lib/Sema/SemaSYCL.cpp | 206 +++++++----------- .../experimental/free_function_traits.hpp | 10 - sycl/include/sycl/handler.hpp | 4 +- 3 files changed, 74 insertions(+), 146 deletions(-) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index b1a7b3cedf70f..cf64331198c91 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -8,7 +8,6 @@ // This implements Semantic Analysis for SYCL constructs. 
//===----------------------------------------------------------------------===// -#include #include "clang/Sema/SemaSYCL.h" #include "TreeTransform.h" #include "clang/AST/AST.h" @@ -1409,13 +1408,13 @@ class KernelObjVisitor { template void visitComplexRecord(const CXXRecordDecl *Owner, ParentTy &Parent, const CXXRecordDecl *Wrapper, QualType RecordTy, - HandlerTys &...Handlers) { + HandlerTys &... Handlers) { (void)std::initializer_list{ (Handlers.enterStruct(Owner, Parent, RecordTy), 0)...}; VisitRecordHelper(Wrapper, Wrapper->bases(), Handlers...); - VisitRecordHelper(Wrapper, Wrapper->fields(), Handlers...), - (void)std::initializer_list{ - (Handlers.leaveStruct(Owner, Parent, RecordTy), 0)...}; + VisitRecordHelper(Wrapper, Wrapper->fields(), Handlers...); + (void)std::initializer_list{ + (Handlers.leaveStruct(Owner, Parent, RecordTy), 0)...}; } template @@ -1521,9 +1520,7 @@ class KernelObjVisitor { void visitField(const CXXRecordDecl *Owner, FieldDecl *Field, QualType FieldTy, HandlerTys &... Handlers) { if (isSyclSpecialType(FieldTy, SemaSYCLRef)) - { FieldTy->dump(); KF_FOR_EACH(handleSyclSpecialType, Field, FieldTy); -} else if (FieldTy->isStructureOrClassType()) { if (KF_FOR_EACH(handleStructType, Field, FieldTy)) { CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); @@ -1550,12 +1547,9 @@ class KernelObjVisitor { void visitParam(ParmVarDecl *Param, QualType ParamTy, HandlerTys &...Handlers) { if (isSyclSpecialType(ParamTy, SemaSYCLRef)) - {ParamTy->dump(); KP_FOR_EACH(handleSyclSpecialType, Param, ParamTy); -} else if (ParamTy->isStructureOrClassType()) { if (KP_FOR_EACH(handleStructType, Param, ParamTy)) { - ParamTy->dump(); CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl(); visitRecord(nullptr, Param, RD, ParamTy, Handlers...); } @@ -1634,12 +1628,8 @@ class KernelObjVisitor { template void VisitFunctionParameters(FunctionDecl *FreeFunc, HandlerTys &...Handlers) { - for (ParmVarDecl *Param : FreeFunc->parameters()) { -std::cout << "starting!" 
<< std::endl; -Param->getType()->dump(); + for (ParmVarDecl *Param : FreeFunc->parameters()) visitParam(Param, Param->getType(), Handlers...); -std::cout << "ending!" << std::endl; -} } #undef KF_FOR_EACH @@ -1762,6 +1752,10 @@ class SyclKernelFieldHandlerBase { virtual ~SyclKernelFieldHandlerBase() = default; }; + +// A class to act as the direct base for all the SYCL OpenCL Kernel construction +// tasks that contains a reference to Sema (and potentially any other +// universally required data). class SyclKernelFieldHandler : public SyclKernelFieldHandlerBase { protected: SemaSYCL &SemaSYCLRef; @@ -1845,11 +1839,7 @@ void KernelObjVisitor::visitRecord(const CXXRecordDecl *Owner, ParentTy &Parent, // If this container requires decomposition, we have to visit it as // 'complex', so all handlers are called in this case with the 'complex' // case. - //RecordTy->dump(); visitComplexRecord(Owner, Parent, Wrapper, RecordTy, Handlers...); - // 'complex', so all handlers are called in this case with the 'complex' - // case. - //RecordTy->dump(); } else if (AnyTrue:: Value) { // We are currently in PointerHandler visitor. @@ -2172,14 +2162,31 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { } bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - ++StructFieldDepth; + // TODO manipulate struct depth once special types are supported for free + // function kernels. + // ++StructFieldDepth; return true; } bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType ParamTy) final { - --StructFieldDepth; - return true; + // TODO manipulate struct depth once special types are supported for free + // function kernels. + // --StructFieldDepth; + // TODO We don't yet support special types and therefore structs that + // require decomposition and leaving/entering. Diagnose for better user + // experience. 
+ CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl(); + if (RD->hasAttr()) { + Diag.Report(PD->getLocation(), + diag::err_bad_kernel_param_type) + << ParamTy; + Diag.Report(PD->getLocation(), + diag::note_free_function_kernel_param_type_not_supported) + << ParamTy; + IsInvalid = true; + } + return isValid(); } bool enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &BS, @@ -2283,6 +2290,8 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler { } bool handleSyclSpecialType(ParmVarDecl *, QualType) final { + // TODO We don't support special types in free function kernel parameters, + // but track them to diagnose the case properly. CollectionStack.back() = true; return true; } @@ -2554,6 +2563,7 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType ParamTy) final { // TODO + unsupportedFreeFunctionParamType(); return true; } @@ -2574,6 +2584,7 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType ParamTy) final { // TODO + unsupportedFreeFunctionParamType(); return true; } @@ -2670,6 +2681,7 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType ParamTy) final { // TODO + unsupportedFreeFunctionParamType(); return true; } @@ -2703,9 +2715,6 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { class SyclKernelDeclCreator : public SyclKernelFieldHandler { FunctionDecl *KernelDecl = nullptr; llvm::SmallVector Params; - // Holds the last handled kernel struct parameter that contains a special type. - // Set in the enterStruct functions. - ParmVarDecl * CurrentStruct; Sema::ContextRAII FuncContext; // Holds the last handled field's first parameter. This doesn't store an // iterator as push_back invalidates iterators. 
@@ -2723,7 +2732,6 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { addParam(newParamDesc, ParamTy); } - void addParam(const CXXBaseSpecifier &BS, QualType FieldTy) { // TODO: There is no name for the base available, but duplicate names are // seemingly already possible, so we'll give them all the same name for now. @@ -2811,7 +2819,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { SourceLocation Loc) { handleAccessorPropertyList(Params.back(), RecordDecl, Loc); - // If "accessor" type check if read only + // If "accessor" type check if read only if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::accessor)) { // Get access mode of accessor. const auto *AccessorSpecializationDecl = @@ -2837,8 +2845,6 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { // lambda kernel by taking the value ParmVarDecl or FieldDecl respectively. template bool handleSpecialType(ParentDecl *decl, QualType Ty) { -std::cout << "Important one!" << std::endl; -Ty->dump(); const auto *RD = Ty->getAsCXXRecordDecl(); assert(RD && "The type must be a RecordDecl"); llvm::StringLiteral MethodName = @@ -2852,7 +2858,7 @@ Ty->dump(); // (if any). size_t ParamIndex = Params.size(); for (const ParmVarDecl *Param : InitMethod->parameters()) { - QualType ParamTy = Param->getType(); + QualType ParamTy = Param->getType(); // For lambda kernels the arguments to the OpenCL kernel are named // based on the position they have as fields in the definition of the // special type structure i.e __arg_field1, __arg_field2 and so on. 
@@ -2878,7 +2884,6 @@ Ty->dump(); handleAccessorType(Ty, RD, decl->getBeginLoc()); } LastParamIndex = ParamIndex; - std::cout << LastParamIndex << std::endl; return true; } @@ -2972,7 +2977,6 @@ Ty->dump(); SYCLKernelAttr::CreateImplicit(SemaSYCLRef.getASTContext())); SemaSYCLRef.addSyclDeviceDecl(KernelDecl); - //KernelDecl->dump(); } bool enterStruct(const CXXRecordDecl *, FieldDecl *, QualType) final { @@ -2980,11 +2984,9 @@ Ty->dump(); return true; } - bool enterStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType Ty) final { - ++StructDepth; - //StringRef Name = "_arg_struct"; - //addParam(Name, Ty); - //CurrentStruct = Params.back(); + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + // ++StructDepth; return true; } @@ -2994,7 +2996,8 @@ Ty->dump(); } bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - --StructDepth; + // TODO + // --StructDepth; return true; } @@ -3010,15 +3013,6 @@ Ty->dump(); return true; } - bool handleStructType(ParmVarDecl *PD, QualType Ty) final { - StringRef Name = "_arg_struct"; - addParam(Name, Ty); - CurrentStruct = Params.back(); - return true; - } - - bool handleStructType(FieldDecl *, QualType) final { return true; } - bool handleSyclSpecialType(const CXXRecordDecl *, const CXXBaseSpecifier &BS, QualType FieldTy) final { const auto *RecordDecl = FieldTy->getAsCXXRecordDecl(); @@ -3193,7 +3187,6 @@ Ty->dump(); return ArrayRef(std::begin(Params) + LastParamIndex, std::end(Params)); } - ParmVarDecl *getParentStructForCurrentField() { return CurrentStruct; } }; // This Visitor traverses the AST of the function with @@ -3647,11 +3640,8 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { SourceLocation LL = NewBody ? NewBody->getBeginLoc() : SourceLocation(); SourceLocation LR = NewBody ? 
NewBody->getEndLoc() : SourceLocation(); - CompoundStmt::Create(SemaSYCLRef.getASTContext(), BodyStmts, - FPOptionsOverride(), LL, LR)->dumpPretty(SemaSYCLRef.getASTContext()); -return CompoundStmt::Create(SemaSYCLRef.getASTContext(), BodyStmts, + return CompoundStmt::Create(SemaSYCLRef.getASTContext(), BodyStmts, FPOptionsOverride(), LL, LR); - } void annotateHierarchicalParallelismAPICalls() { @@ -4373,14 +4363,16 @@ return CompoundStmt::Create(SemaSYCLRef.getASTContext(), BodyStmts, class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { SyclKernelDeclCreator &DeclCreator; llvm::SmallVector BodyStmts; - llvm::SmallVector CurrentStructs; FunctionDecl *FreeFunc = nullptr; SourceLocation FreeFunctionSrcLoc; // Free function source location. llvm::SmallVector ArgExprs; - // Creates a DeclRefExpr to the ParmVar that represents an arbitrary - // free function parameter - Expr *createParamReferenceExpr(ParmVarDecl *FreeFunctionParameter) { + // Creates a DeclRefExpr to the ParmVar that represents the current free + // function parameter. + Expr *createParamReferenceExpr() { + ParmVarDecl *FreeFunctionParameter = + DeclCreator.getParamVarDeclsForCurrentField()[0]; + QualType FreeFunctionParamType = FreeFunctionParameter->getOriginalType(); Expr *DRE = SemaSYCLRef.SemaRef.BuildDeclRefExpr( FreeFunctionParameter, FreeFunctionParamType, VK_LValue, @@ -4389,14 +4381,6 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { return DRE; } - // Creates a DeclRefExpr to the ParmVar that represents the current free - // function parameter. - Expr *createParamReferenceExpr() { - ParmVarDecl *FreeFunctionParameter = - DeclCreator.getParamVarDeclsForCurrentField()[0]; - return createParamReferenceExpr(FreeFunctionParameter); - } - // Creates a DeclRefExpr to the ParmVar that represents the current pointer // parameter. 
Expr *createPointerParamReferenceExpr(QualType PointerTy) { @@ -4453,7 +4437,6 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { DRE = createReinterpretCastExpr( createGetAddressOf(DRE), SemaSYCLRef.getASTContext().getPointerType( OrigFunctionParameter->getType())); - DRE = createDerefOp(DRE); } @@ -4488,12 +4471,8 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { auto CallExpr = CallExpr::Create(Context, Fn, ArgExprs, ResultTy, VK, FreeFunctionSrcLoc, FPOptionsOverride()); BodyStmts.push_back(CallExpr); -CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), {}, - {})->dumpPretty(Context); - return CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), {}, {}); - } MemberExpr *buildMemberExpr(Expr *Base, ValueDecl *Member) { @@ -4510,17 +4489,15 @@ CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), {}, void createSpecialMethodCall(const CXXRecordDecl *RD, StringRef MethodName, Expr *MemberBaseExpr, SmallVectorImpl &AddTo) { -CXXMethodDecl *Method = getMethodByName(RD, MethodName); + CXXMethodDecl *Method = getMethodByName(RD, MethodName); if (!Method) return; unsigned NumParams = Method->getNumParams(); llvm::SmallVector ParamDREs(NumParams); llvm::ArrayRef KernelParameters = DeclCreator.getParamVarDeclsForCurrentField(); - //std::cout << KernelParameters.size() << std::endl; for (size_t I = 0; I < NumParams; ++I) { QualType ParamType = KernelParameters[I]->getOriginalType(); - //ParamType->dump(); ParamDREs[I] = SemaSYCLRef.SemaRef.BuildDeclRefExpr( KernelParameters[I], ParamType, VK_LValue, FreeFunctionSrcLoc); } @@ -4539,7 +4516,7 @@ CXXMethodDecl *Method = getMethodByName(RD, MethodName); public: static constexpr const bool VisitInsideSimpleContainers = false; - + FreeFunctionKernelBodyCreator(SemaSYCL &S, SyclKernelDeclCreator &DC, FunctionDecl *FF) : SyclKernelFieldHandler(S), DeclCreator(DC), FreeFunc(FF), @@ -4550,20 +4527,9 @@ CXXMethodDecl *Method = getMethodByName(RD, MethodName); 
DeclCreator.setBody(KernelBody); } - bool handleSyclSpecialType(FieldDecl *FD, QualType FieldTy) final { - // Being inside this function means there is a struct parameter to the free - // function kernel that contains a special type. -std::cout << "Body!" << std::endl; -FieldTy->dump(); - ParmVarDecl *ParentStruct = DeclCreator.getParentStructForCurrentField(); - // special_type_wrapper_map[ParentStruct->getType()] = true; - Expr *Base = createParamReferenceExpr(ParentStruct); - for (const auto &child : CurrentStructs) { - Base = buildMemberExpr(Base, child); - } - MemberExpr *MemberAccess = buildMemberExpr(Base, FD); - createSpecialMethodCall(FieldTy->getAsCXXRecordDecl(), InitMethodName, - MemberAccess, BodyStmts); + bool handleSyclSpecialType(FieldDecl *FD, QualType Ty) final { + // TODO + unsupportedFreeFunctionParamType(); return true; } @@ -4582,8 +4548,6 @@ FieldTy->dump(); // wgm.__init(arg); // user_kernel(some arguments..., wgm, some arguments...); // } - std::cout << "Body!" << std::endl; - ParamTy->dump(); const auto *RecordDecl = ParamTy->getAsCXXRecordDecl(); AccessSpecifier DefaultConstructorAccess; auto DefaultConstructor = @@ -4616,8 +4580,8 @@ FieldTy->dump(); BodyStmts.push_back(DS); Expr *MemberBaseExpr = SemaSYCLRef.SemaRef.BuildDeclRefExpr( SpecialObjectClone, ParamTy, VK_PRValue, FreeFunctionSrcLoc); - createSpecialMethodCall(RecordDecl, InitMethodName, MemberBaseExpr, - BodyStmts); + createSpecialMethodCall(RecordDecl, InitMethodName, MemberBaseExpr, + BodyStmts); ArgExprs.push_back(MemberBaseExpr); return true; } @@ -4693,24 +4657,26 @@ FieldTy->dump(); } bool enterStruct(const CXXRecordDecl *RD, FieldDecl *FD, QualType Ty) final { - CurrentStructs.push_back(FD); + // TODO + unsupportedFreeFunctionParamType(); return true; } - bool enterStruct(const CXXRecordDecl *RD, ParmVarDecl *PD, - QualType Ty) final { + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); return true; 
} bool leaveStruct(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final { - CurrentStructs.pop_back(); + // TODO + unsupportedFreeFunctionParamType(); return true; } bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - ParmVarDecl *ParentStruct = DeclCreator.getParentStructForCurrentField(); - ArgExprs.push_back(SemaSYCLRef.SemaRef.BuildDeclRefExpr( - ParentStruct, ParentStruct->getType(), VK_PRValue, FreeFunctionSrcLoc)); + // TODO + unsupportedFreeFunctionParamType(); return true; } @@ -4755,11 +4721,6 @@ FieldTy->dump(); unsupportedFreeFunctionParamType(); return true; } - FieldDecl *getCurrentStruct() { - assert(CurrentStructs.size() && - "Current free function parameter is not inside a structure!"); - return CurrentStructs.back(); - } }; // Kernels are only the unnamed-lambda feature if the feature is enabled, AND @@ -5049,6 +5010,7 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { // TODO + unsupportedFreeFunctionParamType(); return true; } @@ -5060,6 +5022,7 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { // TODO + unsupportedFreeFunctionParamType(); return true; } @@ -5556,25 +5519,22 @@ void SemaSYCL::constructFreeFunctionKernel(FunctionDecl *FD, StringRef NameStr) { if (!checkAndAddRegisteredKernelName(*this, FD, NameStr)) return; + SyclKernelArgsSizeChecker argsSizeChecker(*this, FD->getLocation(), false /*IsSIMDKernel*/); SyclKernelDeclCreator kernel_decl(*this, FD->getLocation(), FD->isInlined(), false /*IsSIMDKernel */, FD); + FreeFunctionKernelBodyCreator kernel_body(*this, kernel_decl, FD); + SyclKernelIntHeaderCreator int_header(*this, getSyclIntegrationHeader(), FD->getType(), FD); + SyclKernelIntFooterCreator int_footer(*this, getSyclIntegrationFooter()); KernelObjVisitor Visitor{*this}; - Visitor.VisitFunctionParameters(FD, 
argsSizeChecker); - -Visitor.VisitFunctionParameters(FD, kernel_decl); - -Visitor.VisitFunctionParameters(FD, kernel_body); - -Visitor.VisitFunctionParameters(FD, int_header); - -Visitor.VisitFunctionParameters(FD, int_footer); + Visitor.VisitFunctionParameters(FD, argsSizeChecker, kernel_decl, kernel_body, + int_header, int_footer); assert(getKernelFDPairs().back().first == FD && "OpenCL Kernel not found for free function entry"); @@ -7056,26 +7016,6 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { } } ParmListWithNamesOstream.flush(); - for (ParmVarDecl *Param : K.SyclKernel->parameters()) { - // if (FreeFunctionKernelBodyCreator::isSpecialTypeWrapper( - // Param->getType())) { - // this is a struct that contains a special type so its neither a - // special type nor a trivially copyable type. We therefore need to - // explicitly communicate to the runtime that this argument should be - // allowed as a free function kernel argument. We do this by defining - // a certain trait recognized by the runtime to be true. 
- O << "template <>\n"; - O << "struct " - "sycl::ext::oneapi::experimental::detail::is_explicitly_allowed_" - "arg<"; - Policy.SuppressTagKeyword = true; - - Param->getType().print(O, Policy); - Policy.SuppressTagKeyword = false; - O << "> {\n"; - O << " static constexpr bool value = true;\n};\n"; - //} - } FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate(); Policy.PrintAsCanonical = false; Policy.SuppressDefinition = true; @@ -7812,7 +7752,7 @@ StmtResult SemaSYCL::BuildSYCLKernelCallStmt(FunctionDecl *FD, OutlinedFunctionDeclBodyInstantiator OFDBodyInstantiator(SemaRef, ParmMap); Stmt *OFDBody = OFDBodyInstantiator.TransformStmt(Body).get(); -OFD->setBody(OFDBody); + OFD->setBody(OFDBody); OFD->setNothrow(); Stmt *NewBody = new (getASTContext()) SYCLKernelCallStmt(Body, OFD); diff --git a/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp b/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp index 0ca1c234c9070..2b5d1f4190d21 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp @@ -44,16 +44,6 @@ template struct is_kernel { template inline constexpr bool is_kernel_v = is_kernel::value; -namespace detail { -template struct is_explicitly_allowed_arg { - static constexpr bool value = false; -}; - -template -inline constexpr bool is_explicitly_allowed_arg_v = - is_explicitly_allowed_arg::value; - -} // namespace detail } // namespace ext::oneapi::experimental } // namespace _V1 } // namespace sycl diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 3506d1f34f493..fdd27ffe3f5cb 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -36,7 +36,6 @@ #include #include #include -#include #include #include #include @@ -1816,8 +1815,7 @@ class __SYCL_EXPORT handler { || (!is_same_type::value && std::is_pointer_v>) // USM || is_same_type::value // Interop - || is_same_type::value 
// Stream - || ext::oneapi::experimental::detail::is_explicitly_allowed_arg>::value; + || is_same_type::value; // Stream }; /// Sets argument for OpenCL interoperability kernels. From 01b6936ad176de1a4b4ece80bb62469c0e00b428 Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Mon, 11 Aug 2025 08:57:18 -0700 Subject: [PATCH 03/13] Remove debugging artifacts from E2E test --- sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp b/sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp index 674bcb6af977d..2f1bbcd086e86 100644 --- a/sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp +++ b/sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp @@ -1,7 +1,7 @@ // Ensure -fsycl-allow-device-dependencies can work with free function kernels. // REQUIRES: aspect-usm_shared_allocations -// RUN: %{build} --save-temps -o %t.out -fsycl-allow-device-image-dependencies +// RUN: %{build} -o %t.out -fsycl-allow-device-image-dependencies // RUN: %{run} %t.out #include From 3d3a6186b64bfb94d4bb395a8ce1dcc2c67e5f08 Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Tue, 12 Aug 2025 12:25:52 -0700 Subject: [PATCH 04/13] Disable dead argument elimination for free function kernels --- .../lib/Transforms/IPO/DeadArgumentElimination.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index 190be0aa3919d..44200c30343f2 100644 --- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -15,7 +15,7 @@ // interprocedural passes, which add possibly-dead arguments or return values.
// //===----------------------------------------------------------------------===// - +#include #include "llvm/Transforms/IPO/DeadArgumentElimination.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -575,6 +575,18 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) { return; } + // Do not modify arguments when the SYCL kernel is a free function kernel. + // In this case, the user sets the arguments of the kernel by themselves + // and dead argument elimination may interfere with their expectations. + bool FuncIsSyclFreeFunctionKernel = + F.hasFnAttribute("sycl-single-task-kernel") || + F.hasFnAttribute("sycl-nd-range-kernel"); + if (FuncIsSyclFreeFunctionKernel) { + std::cout << "Frozen!" << std::endl; + markFrozen(F); + return; + } + LLVM_DEBUG( dbgs() << "DeadArgumentEliminationPass - Inspecting callers for fn: " << F.getName() << "\n"); From 20bf5813db65a3adab8610e10be7b5c0172d6c88 Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Tue, 12 Aug 2025 15:30:18 -0400 Subject: [PATCH 05/13] Update DeadArgumentElimination.cpp --- llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index 44200c30343f2..61e84bb6808c2 100644 --- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -15,7 +15,7 @@ // interprocedural passes, which add possibly-dead arguments or return values. 
// //===----------------------------------------------------------------------===// -#include + #include "llvm/Transforms/IPO/DeadArgumentElimination.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -582,7 +582,6 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) { F.hasFnAttribute("sycl-single-task-kernel") || F.hasFnAttribute("sycl-nd-range-kernel"); if (FuncIsSyclFreeFunctionKernel) { - std::cout << "Frozen!" << std::endl; markFrozen(F); return; } From a95f22d7e58b1c0ba128f04d3559814441e2aaf9 Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Tue, 12 Aug 2025 15:30:42 -0400 Subject: [PATCH 06/13] Update free_function_kernels.cpp --- sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp b/sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp index 2f1bbcd086e86..674bcb6af977d 100644 --- a/sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp +++ b/sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp @@ -1,7 +1,7 @@ // Ensure -fsycl-allow-device-dependencies can work with free function kernels. 
// REQUIRES: aspect-usm_shared_allocations -// RUN: %{build} -o %t.out -fsycl-allow-device-image-dependencies +// RUN: %{build} --save-temps -o %t.out -fsycl-allow-device-image-dependencies // RUN: %{run} %t.out #include From 1060e14ef87536c715bd2a8a6dbf853dc4d4c9c0 Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Wed, 13 Aug 2025 11:48:26 -0700 Subject: [PATCH 07/13] Add test for dead argument elimination disabling for free function kernels --- .../free_function_dead_arg_elimination.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp diff --git a/clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp b/clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp new file mode 100644 index 0000000000000..dd1699e90da90 --- /dev/null +++ b/clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 -fsycl-is-device -triple spir64 -emit-llvm --disable-passes %s -o %t.ll +// RUN: opt < %t.ll -passes=deadargelim-sycl -S | FileCheck %s + +// CHECK-NOT: !sycl_kernel_omit_args + +__attribute__((sycl_device)) +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]] +void func1(int a, int b) { + a = 42; +} +// CHECK: define dso_local spir_kernel void @_Z19__sycl_kernel_func1ii(i32 noundef %__arg_a, i32 noundef %__arg_b) + +__attribute__((sycl_device)) +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 0)]] +void func2(int a, int b) { + b = 42; +} +// CHECK: define dso_local spir_kernel void @_Z19__sycl_kernel_func2ii(i32 noundef %__arg_a, i32 noundef %__arg_b) From 2cb0ca7458bd64b35060e369d8e455c119c5d5f1 Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Wed, 13 Aug 2025 15:00:16 -0400 Subject: [PATCH 08/13] Fix typo in RUN command --- clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp b/clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp index dd1699e90da90..6ec0a31b271ba 100644 --- a/clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp +++ b/clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp @@ -1,4 +1,4 @@ -// RUN: %clang_cc1 -fsycl-is-device -triple spir64 -emit-llvm --disable-passes %s -o %t.ll +// RUN: %clang_cc1 -fsycl-is-device -triple spir64 -emit-llvm -disable-llvm-passes %s -o %t.ll // RUN: opt < %t.ll -passes=deadargelim-sycl -S | FileCheck %s // CHECK-NOT: !sycl_kernel_omit_args From 964e5af19d4ab1f0db7d08b1017d513240b37a39 Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Wed, 13 Aug 2025 12:04:35 -0700 Subject: [PATCH 09/13] Rename test and add an explanatory comment --- ...n_dead_arg_elimination.cpp => free_function_kernel_DAE.cpp} | 3 +++ 1 file changed, 3 insertions(+) rename clang/test/CodeGenSYCL/{free_function_dead_arg_elimination.cpp => free_function_kernel_DAE.cpp} (85%) diff --git a/clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp b/clang/test/CodeGenSYCL/free_function_kernel_DAE.cpp similarity index 85% rename from clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp rename to clang/test/CodeGenSYCL/free_function_kernel_DAE.cpp index 6ec0a31b271ba..37459610f0096 100644 --- a/clang/test/CodeGenSYCL/free_function_dead_arg_elimination.cpp +++ b/clang/test/CodeGenSYCL/free_function_kernel_DAE.cpp @@ -1,6 +1,9 @@ // RUN: %clang_cc1 -fsycl-is-device -triple spir64 -emit-llvm -disable-llvm-passes %s -o %t.ll // RUN: opt < %t.ll -passes=deadargelim-sycl -S | FileCheck %s +// This test verifies that the dead argument elimination optimization pass does not affect +// free function kernels.
+ // CHECK-NOT: !sycl_kernel_omit_args __attribute__((sycl_device)) From 6704b61919e4939f1d28bd4ccbba54219683e0fc Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Thu, 14 Aug 2025 09:33:15 -0700 Subject: [PATCH 10/13] Refactor tests --- .../CodeGenSYCL/free_function_kernel_DAE.cpp | 21 --------------- .../DeadArgElim/sycl-kernels-neg.ll | 26 ++++++++++++++++++- 2 files changed, 25 insertions(+), 22 deletions(-) delete mode 100644 clang/test/CodeGenSYCL/free_function_kernel_DAE.cpp diff --git a/clang/test/CodeGenSYCL/free_function_kernel_DAE.cpp b/clang/test/CodeGenSYCL/free_function_kernel_DAE.cpp deleted file mode 100644 index 37459610f0096..0000000000000 --- a/clang/test/CodeGenSYCL/free_function_kernel_DAE.cpp +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: %clang_cc1 -fsycl-is-device -triple spir64 -emit-llvm -disable-llvm-passes %s -o %t.ll -// RUN: opt < %t.ll -passes=deadargelim-sycl -S | FileCheck %s - -// This test verifies that the dead argument elimination optimization pass does not affect -// free function kernels. 
- -// CHECK-NOT: !sycl_kernel_omit_args - -__attribute__((sycl_device)) -[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]] -void func1(int a, int b) { - a = 42; -} -// CHECK: define dso_local spir_kernel void @_Z19__sycl_kernel_func1ii(i32 noundef %__arg_a, i32 noundef %__arg_b) - -__attribute__((sycl_device)) -[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 0)]] -void func2(int a, int b) { - b = 42; -} -// CHECK: define dso_local spir_kernel void @_Z19__sycl_kernel_func2ii(i32 noundef %__arg_a, i32 noundef %__arg_b) diff --git a/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll b/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll index 9f09f37de8012..ad140b4d62499 100644 --- a/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll +++ b/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll @@ -5,7 +5,8 @@ target triple = "spir64" ; This test ensures dead arguments are not eliminated -; from a global function that is not a SPIR kernel. +; from a global function that is not a SPIR kernel and +; from kernels that are free functions. 
; CHECK-NOT: !sycl_kernel_omit_args @@ -32,6 +33,29 @@ define weak_odr void @ESIMDKernel(float %arg1, float %arg2) !sycl_explicit_simd ret void } +define weak_odr spir_kernel void @FreeFuncKernelSingleTask(float %arg1, float %arg2) "sycl-single-task-kernel"="0" { +; CHECK-LABEL: define {{[^@]+}}@FreeFuncKernelSingleTask +; CHECK-SAME: (float [[ARG1:%.*]], float [[ARG2:%.*]]) #[[SINGLE_TASK_ATTR:[0-9]]] { +; CHECK-NEXT: call void @foo(float [[ARG1]]) +; CHECK-NEXT: ret void +; + call void @foo(float %arg1) + ret void +} + +define weak_odr spir_kernel void @FreeFuncKernelNdRange(float %arg1, float %arg2) "sycl-nd-range-kernel"="0" { +; CHECK-LABEL: define {{[^@]+}}@FreeFuncKernelNdRange +; CHECK-SAME: (float [[ARG1:%.*]], float [[ARG2:%.*]]) #[[ND_RANGE_ATTR:[0-9]]] { +; CHECK-NEXT: call void @foo(float [[ARG1]]) +; CHECK-NEXT: ret void +; + call void @foo(float %arg1) + ret void +} + declare void @foo(float %arg) +; CHECK: attributes #[[SINGLE_TASK_ATTR]] = { "sycl-single-task-kernel"="0" } +; CHECK: attributes #[[ND_RANGE_ATTR]] = { "sycl-nd-range-kernel"="0" } + !0 = !{} From e6740fda5dac19a169273ec58be1985bbb955536 Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Thu, 14 Aug 2025 09:56:34 -0700 Subject: [PATCH 11/13] Improve comments --- llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll b/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll index ad140b4d62499..9bea23f8a2585 100644 --- a/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll +++ b/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll @@ -5,8 +5,7 @@ target triple = "spir64" ; This test ensures dead arguments are not eliminated -; from a global function that is not a SPIR kernel and -; from kernels that are free functions. +; from a global function that is not a SPIR kernel. 
; CHECK-NOT: !sycl_kernel_omit_args @@ -33,9 +32,12 @@ define weak_odr void @ESIMDKernel(float %arg1, float %arg2) !sycl_explicit_simd ret void } -define weak_odr spir_kernel void @FreeFuncKernelSingleTask(float %arg1, float %arg2) "sycl-single-task-kernel"="0" { +; The following two tests ensure that dead arguments are not eliminated +; from a free function kernel. + +define weak_odr spir_kernel void @FreeFuncKernelSingleTask(float %arg1, float %arg2) { ; CHECK-LABEL: define {{[^@]+}}@FreeFuncKernelSingleTask -; CHECK-SAME: (float [[ARG1:%.*]], float [[ARG2:%.*]]) #[[SINGLE_TASK_ATTR:[0-9]]] { +; CHECK-SAME: (float [[ARG1:%.*]], float [[ARG2:%.*]]) { ; CHECK-NEXT: call void @foo(float [[ARG1]]) ; CHECK-NEXT: ret void ; From 4f76afe0d97584d3f0d5b06ccce68d79d718b12d Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Thu, 14 Aug 2025 09:57:32 -0700 Subject: [PATCH 12/13] Improve comments --- llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll b/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll index 9bea23f8a2585..40656f6ce343b 100644 --- a/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll +++ b/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll @@ -35,7 +35,7 @@ define weak_odr void @ESIMDKernel(float %arg1, float %arg2) !sycl_explicit_simd ; The following two tests ensure that dead arguments are not eliminated ; from a free function kernel. 
-define weak_odr spir_kernel void @FreeFuncKernelSingleTask(float %arg1, float %arg2) { +define weak_odr spir_kernel void @FreeFuncKernelSingleTask(float %arg1, float %arg2) "sycl-single-task-kernel"="0" { ; CHECK-LABEL: define {{[^@]+}}@FreeFuncKernelSingleTask ; CHECK-SAME: (float [[ARG1:%.*]], float [[ARG2:%.*]]) { ; CHECK-NEXT: call void @foo(float [[ARG1]]) From 54b372b7ecbac874d8abb157d935f8f5c316dad1 Mon Sep 17 00:00:00 2001 From: Lorenc Bushi Date: Thu, 14 Aug 2025 23:56:27 -0400 Subject: [PATCH 13/13] Update sycl-kernels-neg.ll --- llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll b/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll index 40656f6ce343b..5b371ae58f6ee 100644 --- a/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll +++ b/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll @@ -37,7 +37,7 @@ define weak_odr void @ESIMDKernel(float %arg1, float %arg2) !sycl_explicit_simd define weak_odr spir_kernel void @FreeFuncKernelSingleTask(float %arg1, float %arg2) "sycl-single-task-kernel"="0" { ; CHECK-LABEL: define {{[^@]+}}@FreeFuncKernelSingleTask -; CHECK-SAME: (float [[ARG1:%.*]], float [[ARG2:%.*]]) { +; CHECK-SAME: (float [[ARG1:%.*]], float [[ARG2:%.*]]) #[[SINGLE_TASK_ATTR]] { ; CHECK-NEXT: call void @foo(float [[ARG1]]) ; CHECK-NEXT: ret void ;