Skip to content

Commit

Permalink
GPU Data-tiled multi-mma: subgroup dimensions should be outer (#19521)
Browse files Browse the repository at this point in the history
This was already the idea, but there was an accidental exception: in the
accumulator tensor, if there was both a `unroll_m` dimension and
`subgroup_n` dimension, then the `subgroup_n` dimension wasn't on the
outside of `unroll_m` as it was meant to be.

Noticed this when it required corresponding strides in the ukernel.

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
  • Loading branch information
bjacob authored Dec 19, 2024
1 parent 16097c1 commit ed9a028
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ func.func @set_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
// CHECK-SAME : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x4x2x16xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xf32>)
// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x2x4x16x4xf32>)
// CHECK-SAME: permutation = [0, 1, 2, 5, 6, 3, 7, 4]
// CHECK-SAME: outs({{.*}} : tensor<2x5x4x8x2x4x16x4xf32>)
// CHECK-SAME: permutation = [0, 1, 5, 2, 6, 3, 7, 4]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]

// -----
Expand All @@ -255,9 +255,9 @@ func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {

// CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xf32>)
// CHECK-SAME: ins(%{{.+}} : tensor<2x5x4x8x2x4x16x4xf32>)
// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xf32>)
// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
// CHECK-SAME: permutation = [0, 1, 3, 5, 7, 2, 4, 6]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
// CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xf32> into tensor<2x5x128x128xf32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
Expand Down Expand Up @@ -298,9 +298,9 @@ func.func @unset_encoding_ACC_dynamic_unroll8x8x4_MFMA_F32_16x16x4_F32() {
}
// CHECK-LABEL: func.func @unset_encoding_ACC_dynamic_unroll8x8x4_MFMA_F32_16x16x4_F32
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
// CHECK-SAME: ins(%{{.+}} : tensor<?x?x8x4x2x4x16x4xf32>)
// CHECK-SAME: ins(%{{.+}} : tensor<?x?x4x8x2x4x16x4xf32>)
// CHECK-SAME: outs({{.*}} : tensor<?x?x8x4x4x4x2x16xf32>)
// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
// CHECK-SAME: permutation = [0, 1, 3, 5, 7, 2, 4, 6]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
// CHECK-SAME: : tensor<?x?x8x4x4x4x2x16xf32> into tensor<?x?x128x128xf32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
Expand Down Expand Up @@ -362,7 +362,7 @@ func.func @matmul_lowering_MFMA_F32_16x16x4_F32() {
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4xf32>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xf32>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x8x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
Expand Down Expand Up @@ -422,7 +422,7 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x4_F32() {
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x4xf32>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x4xf32>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x8x4x2x4x16x4xf32>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x4x8x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
Expand Down Expand Up @@ -528,8 +528,8 @@ func.func @set_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
// CHECK-SAME : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x4x2x16xi32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xi32>)
// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x2x4x16x4xi32>)
// CHECK-SAME: permutation = [0, 1, 2, 5, 6, 3, 7, 4]
// CHECK-SAME: outs({{.*}} : tensor<2x5x4x8x2x4x16x4xi32>)
// CHECK-SAME: permutation = [0, 1, 5, 2, 6, 3, 7, 4]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]

// -----
Expand All @@ -553,9 +553,9 @@ func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {

// CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xi32>)
// CHECK-SAME: ins(%{{.+}} : tensor<2x5x4x8x2x4x16x4xi32>)
// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xi32>)
// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
// CHECK-SAME: permutation = [0, 1, 3, 5, 7, 2, 4, 6]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
// CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xi32> into tensor<2x5x128x128xi32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
Expand Down Expand Up @@ -618,7 +618,7 @@ func.func @matmul_lowering_MFMA_I32_16x16x32_I8() {
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x2x8xi8>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xi32>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x8x2x4x16x4xi32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
Expand Down Expand Up @@ -1124,7 +1124,7 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x32_F8E4M3FNUZ() {
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x2x8xf8E4M3FNUZ>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x2x8xf8E4M3FNUZ>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x8x4x2x4x16x4xf32>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x4x8x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
Expand Down Expand Up @@ -1184,7 +1184,7 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16() {
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x2x4xbf16>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x2x4xbf16>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x8x4x2x4x16x4xf32>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x4x8x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,12 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
if (mma.getUnrollN() > 1) {
expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollN()});
}
if (mma.getSubgroupsN() > 1) {
expand(swizzle, 1, {Kind::CrossThread, mma.getSubgroupsN()});
}
if (mma.getUnrollM() > 1) {
expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
}
if (mma.getSubgroupsN() > 1) {
expand(swizzle, 1, {Kind::CrossThread, mma.getSubgroupsN()});
}
if (mma.getSubgroupsM() > 1) {
expand(swizzle, 0, {Kind::CrossThread, mma.getSubgroupsM()});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -755,11 +755,11 @@ hal.executable public @main {
// CHECK: gpu.barrier
// CHECK-DAG: %[[A_READ:.+]] = vector.transfer_read %[[A_ALLOC]]{{.*}} vector<8x1x1x4xf32>
// CHECK-DAG: %[[B_READ:.+]] = vector.transfer_read %[[B_ALLOC]]{{.*}} vector<2x1x1x4xf32>
// CHECK-DAG: %[[C_READ:.+]] = vector.transfer_read %[[BINDING_C]]{{.*}} vector<8x1x2x1x1x4xf32>
// CHECK-DAG: %[[C_00_0:.+]] = vector.extract %[[C_READ]][0, 0, 0, 0, 0] : vector<4xf32> from vector<8x1x2x1x1x4xf32>
// CHECK-DAG: %[[C_01_0:.+]] = vector.extract %[[C_READ]][0, 0, 1, 0, 0] : vector<4xf32> from vector<8x1x2x1x1x4xf32>
// CHECK-DAG: %[[C_70_0:.+]] = vector.extract %[[C_READ]][7, 0, 0, 0, 0] : vector<4xf32> from vector<8x1x2x1x1x4xf32>
// CHECK-DAG: %[[C_71_0:.+]] = vector.extract %[[C_READ]][7, 0, 1, 0, 0] : vector<4xf32> from vector<8x1x2x1x1x4xf32>
// CHECK-DAG: %[[C_READ:.+]] = vector.transfer_read %[[BINDING_C]]{{.*}} vector<8x2x1x1x4xf32>
// CHECK-DAG: %[[C_00_0:.+]] = vector.extract %[[C_READ]][0, 0, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
// CHECK-DAG: %[[C_01_0:.+]] = vector.extract %[[C_READ]][0, 1, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
// CHECK-DAG: %[[C_70_0:.+]] = vector.extract %[[C_READ]][7, 0, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
// CHECK-DAG: %[[C_71_0:.+]] = vector.extract %[[C_READ]][7, 1, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
// CHECK-DAG: %[[A_EXTRACT00:.+]] = vector.extract %[[A_READ]][0, 0, 0, 0] : f32 from vector<8x1x1x4xf32>
// CHECK-DAG: %[[A_EXTRACT01:.+]] = vector.extract %[[A_READ]][0, 0, 0, 1] : f32 from vector<8x1x1x4xf32>
// CHECK-DAG: %[[A_EXTRACT02:.+]] = vector.extract %[[A_READ]][0, 0, 0, 2] : f32 from vector<8x1x1x4xf32>
Expand Down

0 comments on commit ed9a028

Please sign in to comment.