Skip to content

Commit

Permalink
[SYCL] Support TaskSequenceINTEL type as an extension type
Browse files Browse the repository at this point in the history
  • Loading branch information
premanandrao committed Feb 23, 2024
1 parent c9b017c commit 91e9f2b
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
23 changes: 23 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,25 @@ llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
}

/// ConvertSYCLTaskSequenceINTELType - Convert SYCL task_sequence type
/// which is represented as a pointer to a structure to LLVM extension type.
/// The expected representation is:
/// target("spirv.TaskSequenceINTEL", %element_type)
llvm::Type *CodeGenTypes::ConvertSYCLTaskSequenceINTELType(RecordDecl *RD) {
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();

assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
"1st TaskSequenceINTEL template parameter must be a type");
assert((TemplateArgs.size() == 1) &&
"TaskSequenceINTEL must have one and only one template parameter");

llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());
return llvm::TargetExtType::get(CompTy->getContext(),
"spirv.TaskSequenceINTEL", {CompTy});
}

/// ConvertType - Convert the specified type to its LLVM form.
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
T = Context.getCanonicalType(T);
Expand Down Expand Up @@ -654,6 +673,10 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
"__spv::__spirv_JointMatrixINTEL") {
ResultType = ConvertSYCLJointMatrixINTELType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_TaskSequenceINTEL") {
ResultType = ConvertSYCLTaskSequenceINTELType(RD);
break;
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ class CodeGenTypes {
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);

/// ConvertSYCLTaskSequenceINTELType - Convert SYCL task_sequence type
/// which is represented as a pointer to a structure to LLVM extension type.
/// The expected representation is:
/// target("spirv.TaskSequenceINTEL", %element_type)
llvm::Type *ConvertSYCLTaskSequenceINTELType(RecordDecl *RD);

/// GetFunctionType - Get the LLVM function type for \arg Info.
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);

Expand Down
35 changes: 35 additions & 0 deletions clang/test/CodeGenSYCL/task_sequence.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s

// Test that SPIR-V codegen generates the expected LLVM struct name for the
// TaskSequenceINTEL type.

#include <stddef.h>
#include <stdint.h>

namespace __spv {
template <typename T>
struct __spirv_TaskSequenceINTEL;
}

struct S {
char c;
float f;
};

// CHECK: @_Z2f1{{.*}}(target("spirv.TaskSequenceINTEL", float)
void f1(__spv::__spirv_TaskSequenceINTEL<float> *task_seq) {}

// CHECK: @_Z2f2{{.*}}(target("spirv.TaskSequenceINTEL", i64)
void f2(__spv::__spirv_TaskSequenceINTEL<uint64_t> *task_seq) {}

// CHECK: @_Z2f3{{.*}}(target("spirv.TaskSequenceINTEL", i8)
void f3(__spv::__spirv_TaskSequenceINTEL<char> *task_seq) {}

// CHECK: @_Z2f4{{.*}}(target("spirv.TaskSequenceINTEL", i128)
void f4(__spv::__spirv_TaskSequenceINTEL<_BitInt(128)> *task_seq) {}

// CHECK: @_Z2f5{{.*}}(target("spirv.TaskSequenceINTEL", double)
void f5(__spv::__spirv_TaskSequenceINTEL<double> *task_seq) {}

// CHECK: @_Z2f6{{.*}}(target("spirv.TaskSequenceINTEL", %struct.S = type { i8, float })
void f6(__spv::__spirv_TaskSequenceINTEL<S> *task_seq) {}

0 comments on commit 91e9f2b

Please sign in to comment.