Skip to content

Commit

Permalink
[numpy] Fix test failures under NumPy 2.0.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662528924
Change-Id: I0a1215f5c7d097cc37c3953a05cd4aebab4701a4
  • Loading branch information
hawkinsp authored and copybara-github committed Aug 13, 2024
1 parent 8d76656 commit d9c4319
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions acme/wrappers/single_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def _convert_value(nested_value: types.Nest) -> types.Nest:

def _convert_single_value(value):
if value is not None:
value = np.array(value, copy=False)
value = np.asarray(value)
if np.issubdtype(value.dtype, np.float64):
value = np.array(value, copy=False, dtype=np.float32)
value = np.asarray(value, dtype=np.float32)
elif np.issubdtype(value.dtype, np.int64):
value = np.array(value, copy=False, dtype=np.int32)
value = np.asarray(value, dtype=np.int32)
return value

return tree.map_structure(_convert_single_value, nested_value)

0 comments on commit d9c4319

Please sign in to comment.