Skip to content

Commit

Permalink
[Mosaic] Fix lowering for _dot_general_lowering_rule to match the n…
Browse files Browse the repository at this point in the history
…ew `vector.MultiDimReductionOp` signature.

PiperOrigin-RevId: 662933072
  • Loading branch information
bchetioui authored and jax authors committed Aug 14, 2024
1 parent b0a144a commit df2e9c3
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,9 +1460,7 @@ def _dot_general_lowering_rule(
ir.Attribute.parse("#vector.kind<add>"),
arith.MulFOp(x, y),
acc,
ir.ArrayAttr.get(
[ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 1)]
),
[1]
)
return vector.ShapeCastOp(out_type, red).result

Expand Down

0 comments on commit df2e9c3

Please sign in to comment.