diff --git a/generator/wrapper_gen.py b/generator/wrapper_gen.py index 790bf431b3..eb214b8714 100755 --- a/generator/wrapper_gen.py +++ b/generator/wrapper_gen.py @@ -1003,6 +1003,8 @@ def __cinit__(self): # this is our actual algorithm class for Python cdef class {{algo}}{{'('+iface[0]|lower+'__iface__)' if iface[0] else ''}}: + cdef tuple _params + ''' {{algo}} {{params_all|fmt('{}', 'sphinx', sep='\n')|indent(4)}} @@ -1017,6 +1019,17 @@ def __cinit__(self, self.c_ptr = mk_{{algo}}( {{params_all|fmt('{}', 'arg_cyext', sep=',\n')|indent(25+(algo|length))}} ) + current_locals = locals() + ordered_input_args = ''' + {{params_all|fmt('{}', 'name', sep=' ')|indent(0)}} + '''.strip().split() + self._params = tuple( + current_locals[arg] + for arg in ordered_input_args + ) + + def __reduce__(self): + return (self.__class__, self._params) {% if not iface[0] %} # the C++ manager__iface__ (de-templatized) diff --git a/tests/test_daal4py_serialization.py b/tests/test_daal4py_serialization.py new file mode 100644 index 0000000000..a626085c42 --- /dev/null +++ b/tests/test_daal4py_serialization.py @@ -0,0 +1,48 @@ +# ============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import pickle +import unittest + +import numpy as np + +import daal4py + + +class Test(unittest.TestCase): + def test_serialization_of_qr(self): + obj_original = daal4py.qr(fptype="float") + obj_deserialized = pickle.loads(pickle.dumps(obj_original)) + + rng = np.random.default_rng(seed=123) + X = rng.standard_normal(size=(10, 5)) + + Q_orig = obj_original.compute(X).matrixQ + Q_deserialized = obj_deserialized.compute(X).matrixQ + np.testing.assert_almost_equal(Q_orig, Q_deserialized) + assert Q_orig.dtype == Q_deserialized.dtype + + def test_serialization_of_kmeans(self): + obj_original = daal4py.kmeans_init(nClusters=4) + obj_deserialized = pickle.loads(pickle.dumps(obj_original)) + + rng = np.random.default_rng(seed=123) + X = rng.standard_normal(size=(100, 20)) + + np.testing.assert_almost_equal( + obj_original.compute(X).centroids, + obj_deserialized.compute(X).centroids, + )