From 23ea67388ffa8834f796c35000a69d0a46c97635 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 10 Dec 2024 14:56:32 -0800 Subject: [PATCH] Fix anyvalue marshalling for matrix types. --- .../slang/slang-ir-any-value-marshalling.cpp | 41 ++++++++++++++----- .../anyvalue-matrix-layout.slang | 30 ++++++++++++++ 2 files changed, 60 insertions(+), 11 deletions(-) create mode 100644 tests/language-feature/anyvalue-matrix-layout.slang diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 163c5a8089..73eb04bc59 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -166,17 +166,36 @@ struct AnyValueMarshallingContext auto matrixType = static_cast(dataType); auto colCount = getIntVal(matrixType->getColumnCount()); auto rowCount = getIntVal(matrixType->getRowCount()); - for (IRIntegerValue i = 0; i < colCount; i++) + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { - auto col = builder->emitElementAddress( - concreteTypedVar, - builder->getIntValue(builder->getIntType(), i)); - for (IRIntegerValue j = 0; j < rowCount; j++) + for (IRIntegerValue i = 0; i < colCount; i++) + { + for (IRIntegerValue j = 0; j < rowCount; j++) + { + auto row = builder->emitElementAddress( + concreteTypedVar, + builder->getIntValue(builder->getIntType(), j)); + auto element = builder->emitElementAddress( + row, + builder->getIntValue(builder->getIntType(), i)); + emitMarshallingCode(builder, context, element); + } + } + } + else + { + for (IRIntegerValue i = 0; i < rowCount; i++) { - auto element = builder->emitElementAddress( - col, - builder->getIntValue(builder->getIntType(), j)); - emitMarshallingCode(builder, context, element); + auto row = builder->emitElementAddress( + concreteTypedVar, + builder->getIntValue(builder->getIntType(), i)); + for (IRIntegerValue j = 0; j < colCount; j++) + { + auto element = builder->emitElementAddress( + row, + builder->getIntValue(builder->getIntType(), j)); + emitMarshallingCode(builder, context, element); + } } } break; @@ -762,9 +781,9 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) auto elementType = matrixType->getElementType(); auto colCount = getIntVal(matrixType->getColumnCount()); auto rowCount = getIntVal(matrixType->getRowCount()); - for (IRIntegerValue i = 0; i < colCount; i++) + for (IRIntegerValue i = 0; i < rowCount; i++) { - for (IRIntegerValue j = 0; j < rowCount; j++) + for (IRIntegerValue j = 0; j < colCount; j++) { offset = _getAnyValueSizeRaw(elementType, offset); if (offset < 0) diff --git a/tests/language-feature/anyvalue-matrix-layout.slang b/tests/language-feature/anyvalue-matrix-layout.slang new file mode 100644 index 0000000000..351eec81be --- /dev/null +++ b/tests/language-feature/anyvalue-matrix-layout.slang @@ -0,0 +1,30 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type + +interface IFoo +{ + float getVal(); +} + +struct Foo : IFoo +{ + column_major float3x2 m; + float getVal() + { + return m[2][0]; + } +} + +//TEST_INPUT: type_conformance Foo:IFoo = 0 + +//TEST_INPUT: set gFoo = ubuffer(data=[0 0 0 0 1.0 2.0 3.0 4.0 5.0 6.0], stride=4) +RWStructuredBuffer gFoo; + +//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + // CHECK: 3.0 + outputBuffer[0] = gFoo[0].getVal(); +} \ No newline at end of file