diff --git a/python/hsml/util.py b/python/hsml/util.py index 96380b6f4..c47733d50 100644 --- a/python/hsml/util.py +++ b/python/hsml/util.py @@ -28,16 +28,6 @@ import pandas as pd from hsml import client from hsml.constants import DEFAULT, MODEL, PREDICTOR -from hsml.model import Model as BaseModel -from hsml.predictor import Predictor as BasePredictor -from hsml.python.model import Model as PyModel -from hsml.python.predictor import Predictor as PyPredictor -from hsml.sklearn.model import Model as SkLearnModel -from hsml.sklearn.predictor import Predictor as SkLearnPredictor -from hsml.tensorflow.model import Model as TFModel -from hsml.tensorflow.predictor import Predictor as TFPredictor -from hsml.torch.model import Model as TorchModel -from hsml.torch.predictor import Predictor as TorchPredictor from six import string_types @@ -105,6 +95,11 @@ def default(self, obj): # pylint: disable=E0202 def set_model_class(model): + from hsml.model import Model as BaseModel + from hsml.python.model import Model as PyModel + from hsml.sklearn.model import Model as SkLearnModel + from hsml.tensorflow.model import Model as TFModel + from hsml.torch.model import Model as TorchModel if "href" in model: _ = model.pop("href") if "type" in model: # backwards compatibility @@ -236,6 +231,16 @@ def validate_metrics(metrics): def get_predictor_for_model(model, **kwargs): + from hsml.model import Model as BaseModel + from hsml.python.model import Model as PyModel + from hsml.sklearn.model import Model as SkLearnModel + from hsml.tensorflow.model import Model as TFModel + from hsml.torch.model import Model as TorchModel + from hsml.torch.predictor import Predictor as TorchPredictor + from hsml.predictor import Predictor as BasePredictor + from hsml.tensorflow.predictor import Predictor as TFPredictor + from hsml.sklearn.predictor import Predictor as SkLearnPredictor + from hsml.python.predictor import Predictor as PyPredictor if not isinstance(model, BaseModel): raise ValueError( "model is of type {}, but an instance of {} class is expected".format( diff --git a/python/pyproject.toml b/python/pyproject.toml index 94aecc7fe..9452f5061 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -94,16 +94,7 @@ build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] exclude = ["tests*"] -include = [ - "../README.md", - "../LICENSE", - "hopsworks", - "hopsworks.*", - "hsfs", - "hsfs.*", - "hsml", - "hsml.*", -] +include = ["../README.md", "../LICENSE", "hopsworks*", "hsfs*", "hsml*"] [tool.setuptools.dynamic] version = { attr = "hopsworks.version.__version__" }