diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 1706b8f201a4..4db430ecde9e 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1460,9 +1460,7 @@ def _dot_general_lowering_rule( ir.Attribute.parse("#vector.kind"), arith.MulFOp(x, y), acc, - ir.ArrayAttr.get( - [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 1)] - ), + [1] ) return vector.ShapeCastOp(out_type, red).result