Skip to content

Commit

Permalink
version 0.0.24
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Jan 10, 2024
1 parent 6f7cff8 commit 4b6a949
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
2 changes: 1 addition & 1 deletion fjformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,4 @@
count_num_params
)

__version__ = '0.0.24'
__version__ = '0.0.25'
38 changes: 25 additions & 13 deletions fjformer/checkpoint/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,19 @@ def save_state_to_file(
with open(path, "wb") as stream:
for key, value in pbar:
if gather_fns is not None:
callable_func = gather_fns[key]
if callable_func is None and not mismatch_allowed:
raise KeyError(f"Gather Function {key} is None and NoneType OBJ is not callable.")
value = callable_func(value) if callable_func is not None else value
if callable_func is None:
gather_functions_mismatch += 1
pbar.set_postfix(gather_functions_mismatch=gather_functions_mismatch)
try:
callable_func = gather_fns[key]
if callable_func is None and not mismatch_allowed:
raise KeyError(f"Gather Function {key} is None and NoneType OBJ is not callable.")
value = callable_func(value) if callable_func is not None else value
if callable_func is None:
gather_functions_mismatch += 1
except KeyError as k_err:
if mismatch_allowed:
gather_functions_mismatch += 1
else:
raise KeyError(k_err)
pbar.set_postfix(gather_functions_mismatch=gather_functions_mismatch)
value = get_dtype(value, float_dtype)
stream.write(packer.pack((key, to_bytes(value))))

Expand Down Expand Up @@ -228,12 +234,18 @@ def load_checkpoint(

tensor = from_bytes(None, value)
if shard_fns is not None:
callable_func = shard_fns[key]
if callable_func is None and not mismatch_allowed:
raise KeyError(f"Shard Function {key} is None and NoneType OBJ is not callable.")
tensor = callable_func(tensor) if callable_func is not None else tensor
if callable_func is None:
shard_functions_mismatch += 1
try:
callable_func = shard_fns[key]
if callable_func is None and not mismatch_allowed:
raise KeyError(f"Shard Function {key} is None and NoneType OBJ is not callable.")
tensor = callable_func(tensor) if callable_func is not None else tensor
if callable_func is None:
shard_functions_mismatch += 1
except KeyError as k_err:
if mismatch_allowed:
shard_functions_mismatch += 1
else:
raise KeyError(k_err)
flatten_state[key] = tensor
pbar.set_postfix(shard_functions_mismatch=shard_functions_mismatch)
if target is not None:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setuptools.setup(
name="FJFormer",
version="0.0.24",
version="0.0.25",
author="Erfan Zare Chavoshi",
author_email="erfanzare82@yahoo.com",
long_description=long_description,
Expand Down

0 comments on commit 4b6a949

Please sign in to comment.