Skip to content

Commit

Permalink
Return sharding whenever the leaf has .sharding attribute, not checki…
Browse files Browse the repository at this point in the history
…ng explicitly for jax.Array.

PiperOrigin-RevId: 565770203
  • Loading branch information
IvyZX authored and Flax Authors committed Sep 15, 2023
1 parent 654ae1a commit ddb04bd
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion flax/training/orbax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def restore_args_from_target(target: Any, mesh: Optional[Mesh] = None) -> Any:
"""

def find_sharding(x):
if isinstance(x, jax.Array):
if hasattr(x, 'sharding'):
return x.sharding
return None

Expand Down

0 comments on commit ddb04bd

Please sign in to comment.