diff --git a/fjformer/__init__.py b/fjformer/__init__.py index a2075d5..9884bb0 100644 --- a/fjformer/__init__.py +++ b/fjformer/__init__.py @@ -151,4 +151,4 @@ count_num_params ) -__version__ = '0.0.24' +__version__ = '0.0.25' diff --git a/fjformer/checkpoint/streamer.py b/fjformer/checkpoint/streamer.py index 3ea7c9c..60a9f24 100644 --- a/fjformer/checkpoint/streamer.py +++ b/fjformer/checkpoint/streamer.py @@ -95,13 +95,19 @@ def save_state_to_file( with open(path, "wb") as stream: for key, value in pbar: if gather_fns is not None: - callable_func = gather_fns[key] - if callable_func is None and not mismatch_allowed: - raise KeyError(f"Gather Function {key} is None and NoneType OBJ is not callable.") - value = callable_func(value) if callable_func is not None else value - if callable_func is None: - gather_functions_mismatch += 1 - pbar.set_postfix(gather_functions_mismatch=gather_functions_mismatch) + try: + callable_func = gather_fns[key] + if callable_func is None and not mismatch_allowed: + raise KeyError(f"Gather Function {key} is None and NoneType OBJ is not callable.") + value = callable_func(value) if callable_func is not None else value + if callable_func is None: + gather_functions_mismatch += 1 + except KeyError as k_err: + if mismatch_allowed: + gather_functions_mismatch += 1 + else: + raise KeyError(k_err) + pbar.set_postfix(gather_functions_mismatch=gather_functions_mismatch) value = get_dtype(value, float_dtype) stream.write(packer.pack((key, to_bytes(value)))) @@ -228,12 +234,18 @@ def load_checkpoint( tensor = from_bytes(None, value) if shard_fns is not None: - callable_func = shard_fns[key] - if callable_func is None and not mismatch_allowed: - raise KeyError(f"Shard Function {key} is None and NoneType OBJ is not callable.") - tensor = callable_func(tensor) if callable_func is not None else tensor - if callable_func is None: - shard_functions_mismatch += 1 + try: + callable_func = shard_fns[key] + if callable_func is None and not mismatch_allowed: + raise KeyError(f"Shard Function {key} is None and NoneType OBJ is not callable.") + tensor = callable_func(tensor) if callable_func is not None else tensor + if callable_func is None: + shard_functions_mismatch += 1 + except KeyError as k_err: + if mismatch_allowed: + shard_functions_mismatch += 1 + else: + raise KeyError(k_err) flatten_state[key] = tensor pbar.set_postfix(shard_functions_mismatch=shard_functions_mismatch) if target is not None: diff --git a/setup.py b/setup.py index 465c41a..10d902a 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setuptools.setup( name="FJFormer", - version="0.0.24", + version="0.0.25", author="Erfan Zare Chavoshi", author_email="erfanzare82@yahoo.com", long_description=long_description,