Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
josel-amd committed Dec 11, 2024
1 parent dbd8c05 commit 57e44ca
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -988,23 +988,21 @@ module {
// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0) -> (d0 floordiv 4)>

func.func @fuse_and_collapse(%arg0: tensor<3x4xi32>) -> tensor<2x12xi32> {
%1 = tensor.empty() : tensor<2x3x4xi32>
func.func @fuse_and_collapse(%arg0: tensor<3x4xindex>) -> tensor<2x12xindex> {
%1 = tensor.empty() : tensor<2x3x4xindex>
// CHECK: linalg.generic {
// CHECK: %[[INDEX1:[a-zA-Z0-9_]+]] = linalg.index 1 : index
// CHECK-NEXT: %[[MAP:[a-zA-Z0-9_]+]] = affine.apply #map1(%[[INDEX1]])
// CHECK-NEXT: %[[CAST:[a-zA-Z0-9_]+]] = arith.index_cast %[[MAP]] : index to i32
// CHECK-NEXT: linalg.yield %[[CAST]] : i32
%2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0: tensor<3x4xi32>) outs(%1 : tensor<2x3x4xi32>) {
^bb0(%in: i32, %out: i32):
// CHECK-NEXT: linalg.yield %[[MAP]] : index
%2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0: tensor<3x4xindex>) outs(%1 : tensor<2x3x4xindex>) {
^bb0(%in: index, %out: index):
%3 = linalg.index 1 : index
%cast = arith.index_cast %3 : index to i32
linalg.yield %cast : i32
} -> tensor<2x3x4xi32>
%7 = tensor.empty() : tensor<2x12xi32>
%8 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<2x3x4xi32>) outs(%7 : tensor<2x12xi32>) {
^bb0(%in: i32, %out: i32):
linalg.yield %in : i32
} -> tensor<2x12xi32>
return %8 : tensor<2x12xi32>
linalg.yield %3: index
} -> tensor<2x3x4xindex>
%7 = tensor.empty() : tensor<2x12xindex>
%8 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<2x3x4xindex>) outs(%7 : tensor<2x12xindex>) {
^bb0(%in: index, %out: index):
linalg.yield %in : index
} -> tensor<2x12xindex>
return %8 : tensor<2x12xindex>
}

0 comments on commit 57e44ca

Please sign in to comment.