Skip to content

Commit

Permalink
Merge branch 'main' into sync-inject-api-links
Browse files Browse the repository at this point in the history
  • Loading branch information
aversey authored Nov 8, 2024
2 parents c0e5885 + 3b8fed0 commit 741b1ad
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 42 deletions.
58 changes: 35 additions & 23 deletions python/hsml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
from hsml.engine import model_engine
from hsml.inference_batcher import InferenceBatcher
from hsml.inference_logger import InferenceLogger
from hsml.model_schema import ModelSchema
from hsml.predictor import Predictor
from hsml.resources import PredictorResources
from hsml.schema import Schema
from hsml.transformer import Transformer


Expand All @@ -54,7 +56,6 @@ def __init__(
program=None,
user_full_name=None,
model_schema=None,
training_dataset=None,
input_example=None,
framework=None,
model_registry_id=None,
Expand Down Expand Up @@ -84,7 +85,6 @@ def __init__(
self._input_example = input_example
self._framework = framework
self._model_schema = model_schema
self._training_dataset = training_dataset

# This is needed for update_from_response_json function to not overwrite name of the shared registry this model originates from
if not hasattr(self, "_shared_registry_project_name"):
Expand All @@ -95,17 +95,6 @@ def __init__(
self._model_engine = model_engine.ModelEngine()
self._feature_view = feature_view
self._training_dataset_version = training_dataset_version
if training_dataset_version is None and feature_view is not None:
if feature_view.get_last_accessed_training_dataset() is not None:
self._training_dataset_version = (
feature_view.get_last_accessed_training_dataset()
)
else:
warnings.warn(
"Provenance cached data - feature view provided, but training dataset version is missing",
util.ProvenanceWarning,
stacklevel=1,
)

@usage.method_logger
def save(
Expand All @@ -131,6 +120,39 @@ def save(
# Returns
`Model`: The model metadata object.
"""
if self._training_dataset_version is None and self._feature_view is not None:
if self._feature_view.get_last_accessed_training_dataset() is not None:
self._training_dataset_version = (
self._feature_view.get_last_accessed_training_dataset()
)
else:
warnings.warn(
"Provenance cached data - feature view provided, but training dataset version is missing",
util.ProvenanceWarning,
stacklevel=1,
)
if self._model_schema is None:
if (
self._feature_view is not None
and self._training_dataset_version is not None
):
all_features = self._feature_view.get_training_dataset_schema(
self._training_dataset_version
)
features, labels = [], []
for feature in all_features:
(labels if feature.label else features).append(feature.to_dict())
self._model_schema = ModelSchema(
input_schema=Schema(features) if features else None,
output_schema=Schema(labels) if labels else None,
)
else:
warnings.warn(
"Model schema cannot not be inferred without both the feature view and the training dataset version.",
util.ProvenanceWarning,
stacklevel=1,
)

return self._model_engine.save(
model_instance=self,
model_path=model_path,
Expand Down Expand Up @@ -375,7 +397,6 @@ def to_dict(self):
"inputExample": self._input_example,
"framework": self._framework,
"metrics": self._training_metrics,
"trainingDataset": self._training_dataset,
"environment": self._environment,
"program": self._program,
"featureView": util.feature_view_to_json(self._feature_view),
Expand Down Expand Up @@ -510,15 +531,6 @@ def model_schema(self):
def model_schema(self, model_schema):
self._model_schema = model_schema

@property
def training_dataset(self):
"""training_dataset of the model."""
return self._training_dataset

@training_dataset.setter
def training_dataset(self, training_dataset):
self._training_dataset = training_dataset

@property
def project_name(self):
"""project_name of the model."""
Expand Down
13 changes: 5 additions & 8 deletions python/hsml/utils/schema/columnar_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@
from hsml.utils.schema.column import Column


try:
import hsfs
except ImportError:
pass

try:
import pyspark
except ImportError:
Expand All @@ -35,6 +30,10 @@ class ColumnarSchema:
"""Metadata object representing a columnar schema for a model."""

def __init__(self, columnar_obj=None):
from hsfs.training_dataset import (
TrainingDataset, # import performed here to prevent circular dependencies when importing ModelSchema
)

if isinstance(columnar_obj, list):
self.columns = self._convert_list_to_schema(columnar_obj)
elif isinstance(columnar_obj, pandas.DataFrame):
Expand All @@ -45,9 +44,7 @@ def __init__(self, columnar_obj=None):
columnar_obj, pyspark.sql.dataframe.DataFrame
):
self.columns = self._convert_spark_to_schema(columnar_obj)
elif importlib.util.find_spec("hsfs") is not None and isinstance(
columnar_obj, hsfs.training_dataset.TrainingDataset
):
elif isinstance(columnar_obj, TrainingDataset):
self.columns = self._convert_td_to_schema(columnar_obj)
else:
raise TypeError(
Expand Down
8 changes: 0 additions & 8 deletions python/tests/fixtures/model_fixtures.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"program": "program",
"user_full_name": "Full Name",
"model_schema": "model_schema.json",
"training_dataset": "training_dataset",
"input_example": "input_example.json",
"model_registry_id": 1,
"tags": [],
Expand All @@ -42,7 +41,6 @@
"program": "program",
"user_full_name": "Full Name",
"model_schema": "model_schema.json",
"training_dataset": "training_dataset",
"input_example": "input_example.json",
"model_registry_id": 1,
"tags": [],
Expand All @@ -69,7 +67,6 @@
"program": "program",
"user_full_name": "Full Name",
"model_schema": "model_schema.json",
"training_dataset": "training_dataset",
"input_example": "input_example.json",
"model_registry_id": 1,
"tags": [],
Expand All @@ -96,7 +93,6 @@
"program": "program",
"user_full_name": "Full Name",
"model_schema": "model_schema.json",
"training_dataset": "training_dataset",
"input_example": "input_example.json",
"model_registry_id": 1,
"tags": [],
Expand All @@ -123,7 +119,6 @@
"program": "program",
"user_full_name": "Full Name",
"model_schema": "model_schema.json",
"training_dataset": "training_dataset",
"input_example": "input_example.json",
"model_registry_id": 1,
"tags": [],
Expand All @@ -150,7 +145,6 @@
"program": "program",
"user_full_name": "Full Name",
"model_schema": "model_schema.json",
"training_dataset": "training_dataset",
"input_example": "input_example.json",
"model_registry_id": 1,
"tags": [],
Expand All @@ -177,7 +171,6 @@
"program": "program",
"user_full_name": "Full Name",
"model_schema": "model_schema.json",
"training_dataset": "training_dataset",
"input_example": "input_example.json",
"model_registry_id": 1,
"tags": [],
Expand All @@ -197,7 +190,6 @@
"program": "program",
"user_full_name": "Full Name",
"model_schema": "model_schema.json",
"training_dataset": "training_dataset",
"input_example": "input_example.json",
"model_registry_id": 1,
"tags": [],
Expand Down
1 change: 0 additions & 1 deletion python/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,6 @@ def assert_model(self, mocker, m, m_json, model_framework):
assert m.project_name == m_json["project_name"]
assert m.training_metrics == m_json["metrics"]
assert m._user_full_name == m_json["user_full_name"]
assert m.training_dataset == m_json["training_dataset"]
assert m.model_registry_id == m_json["model_registry_id"]

if model_framework is None:
Expand Down
4 changes: 2 additions & 2 deletions python/tests/utils/schema/test_columnar_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_constructor_default(self, mocker):
mock_convert_pandas_series_to_schema.assert_not_called()
mock_convert_spark_to_schema.assert_not_called()
mock_convert_td_to_schema.assert_not_called()
assert mock_find_spec.call_count == 2
assert mock_find_spec.call_count == 1

def test_constructor_list(self, mocker):
# Arrange
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_constructor_hsfs_td(self, mocker):
mock_convert_pandas_series_to_schema.assert_not_called()
mock_convert_spark_to_schema.assert_not_called()
mock_convert_td_to_schema.assert_called_once_with(columnar_obj)
assert mock_find_spec.call_count == 2
assert mock_find_spec.call_count == 1

# convert list to schema

Expand Down

0 comments on commit 741b1ad

Please sign in to comment.