From 0ce296842b5e8a178a0b88ee4d2dbd85dacd7b29 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 17 Jun 2024 00:21:49 -0700 Subject: [PATCH 1/5] WIP --- Libraries/DirectMLX.h | 187 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) diff --git a/Libraries/DirectMLX.h b/Libraries/DirectMLX.h index cee6d2ad..513574fd 100644 --- a/Libraries/DirectMLX.h +++ b/Libraries/DirectMLX.h @@ -4291,6 +4291,193 @@ namespace dml return output; } + + struct MultiHeadAttentionOutputs + { + Expression output; + Optional outputPresentKey; + Optional outputPresentValue; + }; + + inline MultiHeadAttentionOutputs MultiHeadAttention( + Expression query, + Expression key, + Expression value, + Expression stackedQueryKey, + Expression stackedKeyValue, + Expression stackedQueryKeyValue, + Expression bias, + Expression mask, + Expression relativePositionBias, + Expression pastKey, + Expression pastValue, + Expression pastSequenceLengths, + float scale, + float maskFilterValue, + uint32_t queryHeadCount, + uint32_t keyValueHeadCount, + DML_MULTIHEAD_ATTENTION_MASK_TYPE maskType, + bool computeOutputPresentKeyValue, + Optional maxSequenceLength = {}) + { + assert(query || stackedQueryKey || stackedQueryKeyValue); + + detail::GraphBuilder* = nullptr; + + if (query) + { + assert(!stackedQueryKey); + assert(!stackedQueryKeyValue); + builder = query->Impl()->GetGraphBuilder(); + } + else if (stackedQueryKey) + { + assert(!query); + assert(!key); + assert(value); + assert(!stackedKeyValue); + assert(!stackedQueryKeyValue); + builder = stackedQueryKey->Impl()->GetGraphBuilder(); + } + else + { + assert(stackedQueryKeyValue); + assert(!query); + assert(!key); + assert(!value); + assert(!stackedQueryKey); + assert(!stackedKeyValue); + assert(!stackedQueryKeyValue); + builder = stackedQueryKeyValue->Impl()->GetGraphBuilder(); + } + + TensorDesc queryTensor = query ? query->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc keyTensor = key ? key->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc valueTensor = value ? value->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc stackedQueryKeyTensor = stackedQueryKey ? stackedQueryKey->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc stackedKeyValueTensor = stackedKeyValue ? stackedKeyValue->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc stackedQueryKeyValueTensor = stackedQueryKeyValue ? stackedQueryKeyValue->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc biasTensor = bias ? bias->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc maskTensor = mask ? mask->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc relativePositionBiasTensor = relativePositionBias ? relativePositionBias->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc pastKeyTensor = pastKey ? pastKey->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc pastValueTensor = pastValue ? pastValue->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc pastSequenceLengthsTensor = pastSequenceLengths ? pastSequenceLengths->Impl()->GetOutputDesc() : TensorDesc(); + + uint32_t batchSize; + uint32_t sequenceLength; + uint32_t headSize; + uint32_t valueHeadSize; + DML_TENSOR_DATA_TYPE dataType; + + if (query) + { + assert(queryTensor.sizes.size() >= 3); + batchSize = queryTensor.sizes[queryTensor.sizes.size() - 3]; + sequenceLength = queryTensor.sizes[queryTensor.sizes.size() - 2]; + headSize = queryTensor.sizes[queryTensor.sizes.size() - 1] / queryHeadCount; + dataType = queryTensor.dataType; + + if (value) + { + assert(valueTensor.sizes.size() >= 3); + valueHeadSize = valueTensor.sizes[valueTensor.sizes.size() - 1]; + } + else if (stackedKeyValue) + { + assert(valueTensor.sizes.size() >= 3); + valueHeadSize = valueTensor.sizes[valueTensor.sizes.size() - 1]; + } + } + else if (stackedQueryKey) + { + assert(stackedQueryKeyTensor.sizes.size() >= 5); + batchSize = stackedQueryKeyTensor.sizes[stackedQueryKeyTensor.sizes.size() - 5]; + sequenceLength = stackedQueryKeyTensor.sizes[stackedQueryKeyTensor.sizes.size() - 4]; + headSize = stackedQueryKeyTensor.sizes[stackedQueryKeyTensor.sizes.size() - 1]; + dataType = stackedQueryKeyTensor.dataType; + } + else + { + assert(stackedQueryKeyValue); + assert(stackedQueryKeyValueTensor.sizes.size() >= 5); + batchSize = stackedQueryKeyValueTensor.sizes[stackedQueryKeyValueTensor.sizes.size() - 5]; + sequenceLength = stackedQueryKeyValueTensor.sizes[stackedQueryKeyValueTensor.sizes.size() - 4]; + headSize = stackedQueryKeyValueTensor.sizes[stackedQueryKeyValueTensor.sizes.size() - 1]; + valueHeadSize = headSize; + dataType = stackedQueryKeyValueTensor.dataType; + } + + assert(inputGradientTensor.sizes.size() > 1); + + uint32_t outputHiddenSize = valueHeadSize * queryHeadCount; + + TensorDesc::Dimensions outputSizes({batchSize, sequenceLength, outputHiddenSize}); + TensorDesc outputTensor = TensorDesc(dataType, outputSizes, builder->GetTensorPolicy()); + + TensorDesc outputPresentKeyTensor; + TensorDesc outputPresentValueTensor; + if (computeOutputPresentKeyValue) + { + assert(maxSequenceLength); + + TensorDesc::Dimensions outputPresentKeySizes({batchSize, keyValueHeadCount, *maxSequenceLength, headSize}); + outputPresentKeyTensor = TensorDesc(dataType, outputPresentKeySizes, builder->GetTensorPolicy()); + + TensorDesc::Dimensions outputPresentValueSizes({batchSize, keyValueHeadCount, *maxSequenceLength, valueHeadSize}); + outputPresentValueTensor = TensorDesc(dataType, outputPresentValueSizes, builder->GetTensorPolicy()); + } + + DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC desc = {}; + desc.QueryTensor = query ? queryTensor.AsPtr() : nullptr; + desc.KeyTensor = key ? keyTensor.AsPtr() : nullptr; + desc.ValueTensor = value ? valueTensor.AsPtr() : nullptr; + desc.StackedQueryKeyTensor = stackedQueryKey ? stackedQueryKeyTensor.AsPtr() : nullptr; + desc.StackedKeyValueTensor = stackedKeyValue ? stackedKeyValueTensor.AsPtr() : nullptr; + desc.StackedQueryKeyValueTensor = stackedQueryKeyValue ? stackedQueryKeyValueTensor.AsPtr() : nullptr; + desc.BiasTensor = bias ? biasTensor.AsPtr() : nullptr; + desc.MaskTensor = mask ? maskTensor.AsPtr() : nullptr; + desc.RelativePositionBiasTensor = relativePositionBias ? relativePositionBiasTensor.AsPtr() : nullptr; + desc.PastKeyTensor = pastKey ? pastKeyTensor.AsPtr() : nullptr; + desc.PastValueTensor = pastValue ? pastValueTensor.AsPtr() : nullptr; + desc.PastSequenceLengthsTensor = pastSequenceLengths ? pastSequenceLengthsTensor.AsPtr() : nullptr; + desc.OutputTensor = outputTensor.AsPtr(); + desc.OutputPresentKeyTensor = computeOutputPresentKeyValue ? outputPresentKeyTensor.AsPtr() : nullptr; + desc.OutputPresentValueTensor = computeOutputPresentKeyValue ? outputPresentValueTensor.AsPtr() : nullptr; + desc.Scale = scale; + desc.MaskFilterValue = maskFilterValue; + desc.QueryHeadCount = queryHeadCount; + desc.KeyValueHeadCount = keyValueHeadCount; + desc.MaskType = maskType; + + detail::NodeOutput* const inputs[] = { + query ? query->Impl() : nullptr, + key ? key->Impl() : nullptr, + value ? value->Impl() : nullptr, + stackedQueryKey ? stackedQueryKey->Impl() : nullptr, + stackedKeyValue ? stackedKeyValue->Impl() : nullptr, + stackedQueryKeyValue ? stackedQueryKeyValue->Impl() : nullptr, + bias ? bias->Impl() : nullptr, + mask ? mask->Impl() : nullptr, + relativePositionBias ? relativePositionBias->Impl() : nullptr, + pastKey ? pastKey->Impl() : nullptr, + pastValue ? pastValue->Impl() : nullptr, + pastSequenceLengths ? pastSequenceLengths->Impl() : nullptr, + }; + detail::NodeID node = builder->CreateOperatorNode(static_cast(DML_OPERATOR_MULTIHEAD_ATTENTION1), &desc, inputs); + + MultiHeadAttentionOutputs outputs {}; + + outputs.output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + if (computeOutputPresentKeyValue) + { + outputs.outputPresentKey = builder->CreateNodeOutput(node, 1, std::move(outputPresentKeyTensor)); + outputs.outputPresentValue = builder->CreateNodeOutput(node, 2, std::move(outputPresentValueTensor)); + } + + return outputs; + } #endif // Reinterprets the memory of a tensor with a different type and dimensions (analogously to using From 99da23eb81f5d563572fef9a838ecc31f4a5d4c1 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 17 Jun 2024 00:23:18 -0700 Subject: [PATCH 2/5] WIP --- Libraries/DirectMLX.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Libraries/DirectMLX.h b/Libraries/DirectMLX.h index 513574fd..e8605a58 100644 --- a/Libraries/DirectMLX.h +++ b/Libraries/DirectMLX.h @@ -4322,7 +4322,7 @@ namespace dml { assert(query || stackedQueryKey || stackedQueryKeyValue); - detail::GraphBuilder* = nullptr; + detail::GraphBuilder* builder = nullptr; if (query) { From 6f4d6753953bdd3b10eb00b59befde6225f0959a Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 17 Jun 2024 00:24:05 -0700 Subject: [PATCH 3/5] WIP --- Libraries/DirectMLX.h | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/Libraries/DirectMLX.h b/Libraries/DirectMLX.h index e8605a58..7599deb5 100644 --- a/Libraries/DirectMLX.h +++ b/Libraries/DirectMLX.h @@ -4300,18 +4300,18 @@ namespace dml }; inline MultiHeadAttentionOutputs MultiHeadAttention( - Expression query, - Expression key, - Expression value, - Expression stackedQueryKey, - Expression stackedKeyValue, - Expression stackedQueryKeyValue, - Expression bias, - Expression mask, - Expression relativePositionBias, - Expression pastKey, - Expression pastValue, - Expression pastSequenceLengths, + Optional query, + Optional key, + Optional value, + Optional stackedQueryKey, + Optional stackedKeyValue, + Optional stackedQueryKeyValue, + Optional bias, + Optional mask, + Optional relativePositionBias, + Optional pastKey, + Optional pastValue, + Optional pastSequenceLengths, float scale, float maskFilterValue, uint32_t queryHeadCount, From b931d8dde9f58f5e3c3b005123b7d8e8947dfbdc Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 17 Jun 2024 00:24:43 -0700 Subject: [PATCH 4/5] WIP --- Libraries/DirectMLX.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/Libraries/DirectMLX.h b/Libraries/DirectMLX.h index 7599deb5..d71026c1 100644 --- a/Libraries/DirectMLX.h +++ b/Libraries/DirectMLX.h @@ -4408,8 +4408,6 @@ namespace dml dataType = stackedQueryKeyValueTensor.dataType; } - assert(inputGradientTensor.sizes.size() > 1); - uint32_t outputHiddenSize = valueHeadSize * queryHeadCount; TensorDesc::Dimensions outputSizes({batchSize, sequenceLength, outputHiddenSize}); From e67c4f89a130684392d4312315e01ae5596a2887 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 17 Jun 2024 01:31:15 -0700 Subject: [PATCH 5/5] WIP --- Libraries/DirectMLX.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Libraries/DirectMLX.h b/Libraries/DirectMLX.h index d71026c1..da29c09c 100644 --- a/Libraries/DirectMLX.h +++ b/Libraries/DirectMLX.h @@ -4381,7 +4381,7 @@ namespace dml if (value) { assert(valueTensor.sizes.size() >= 3); - valueHeadSize = valueTensor.sizes[valueTensor.sizes.size() - 1]; + valueHeadSize = valueTensor.sizes[valueTensor.sizes.size() - 1] / keyValueHeadCount; } else if (stackedKeyValue) {