Skip to content

Commit

Permalink
Custom objects support when pickling keras models
Browse files Browse the repository at this point in the history
  • Loading branch information
mthiboust committed Jun 17, 2024
1 parent f6cf6a0 commit 8625e1b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
25 changes: 25 additions & 0 deletions keras/src/models/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,23 @@
from keras.src.models.functional import Functional
from keras.src.models.model import Model
from keras.src.models.model import model_from_json
from keras.src.saving.object_registration import register_keras_serializable


@register_keras_serializable(package="MyLayers", name="CustomDense")
class CustomDense(layers.Layer):
def __init__(self, units, **kwargs):
super().__init__(**kwargs)
self.units = units
self.dense = layers.Dense(units)

def call(self, x):
return self.dense(x)

def get_config(self):
config = super().get_config()
config.update({"units": self.units})
return config


def _get_model():
Expand Down Expand Up @@ -68,6 +85,13 @@ def _get_model_multi_outputs_dict():
return model


def _get_model_custom_layer():
x = Input(shape=(3,), name="input_a")
output_a = CustomDense(10, name="output_a")(x)
model = Model(x, output_a)
return model


@pytest.mark.requires_trainable_backend
class ModelTest(testing.TestCase, parameterized.TestCase):
def test_functional_rerouting(self):
Expand Down Expand Up @@ -127,6 +151,7 @@ def call(self, x):
("single_list_output_2", _get_model_single_output_list),
("single_list_output_3", _get_model_single_output_list),
("single_list_output_4", _get_model_single_output_list),
("custom_layer", _get_model_custom_layer),
)
def test_functional_pickling(self, model_fn):
model = model_fn()
Expand Down
22 changes: 17 additions & 5 deletions keras/src/saving/keras_saveable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import io
import pickle

from keras.src.saving.object_registration import get_custom_objects


class KerasSaveable:
Expand All @@ -14,12 +17,16 @@ def _obj_type(self):
)

@classmethod
def _unpickle_model(cls, bytesio):
def _unpickle_model(cls, model_buf, custom_objects_buf):
import keras.src.saving.saving_lib as saving_lib

# pickle is not safe regardless of what you do.
custom_objects = pickle.load(custom_objects_buf)
return saving_lib._load_model_from_fileobj(
bytesio, custom_objects=None, compile=True, safe_mode=False
model_buf,
custom_objects=custom_objects,
compile=True,
safe_mode=False,
)

def __reduce__(self):
Expand All @@ -30,9 +37,14 @@ def __reduce__(self):
keras saving library."""
import keras.src.saving.saving_lib as saving_lib

buf = io.BytesIO()
saving_lib._save_model_to_fileobj(self, buf, "h5")
model_buf = io.BytesIO()
saving_lib._save_model_to_fileobj(self, model_buf, "h5")

custom_objects_buf = io.BytesIO()
pickle.dump(get_custom_objects(), custom_objects_buf)
custom_objects_buf.seek(0)

return (
self._unpickle_model,
(buf,),
(model_buf, custom_objects_buf),
)

0 comments on commit 8625e1b

Please sign in to comment.