Skip to content

Commit

Permalink
Avoid depending on JAX internals, which are about to change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688958442
  • Loading branch information
dougalm authored and The LAST Authors committed Oct 23, 2024
1 parent def72ab commit 37fef33
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions last/lattices.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ def step(weight_fn, carry, inputs):
#
# For the tropical semiring, this should be equivalent to no remat.
def save_small(prim, *args, **params):
if prim.multiple_results:
return False
y, _ = prim.abstract_eval(*args, **params)
greater_than_1_dims = len([None for i in y.shape if i > 1])
save = greater_than_1_dims <= (len(batch_dims) + 1)
Expand Down

0 comments on commit 37fef33

Please sign in to comment.