Custom serialization for non-user types and non-serializable types for Hera runner (Parameter/Artifact inputs and outputs) #1166
Also related, and a complete blocker to using the new decorators: I have no way to output a bytes Artifact from a template. Using the following input class gets an error when building the workflow:
class ModelTrainingInput(Input):
    X_train: Annotated[list, Artifact(name="X_train", loader=ArtifactLoader.json)]
    y_train: Annotated[dict, Artifact(name="y_train", loader=ArtifactLoader.json)]
    model: Annotated[Path, Artifact(name="model", output=True)]
And using the following output class and script gets an error when running on the cluster:
class ModelTrainingOutput(Output):
    model: Annotated[bytes, Artifact(name="model", archive=NoneArchiveStrategy())]

@w.script()
def model_training(model_training_input: ModelTrainingInput) -> ModelTrainingOutput:
    X_train = np.array(model_training_input.X_train)
    y_train = pd.Series(model_training_input.y_train)
    model = LogisticRegression(random_state=42)
    model.fit(X_train, y_train)
    return ModelTrainingOutput(model={"model": pickle.dumps(model)})
The workaround is to use the old syntax with an "output" artifact in the function inputs, i.e.
@script(constructor="runner")
def model_training(
    X_train: Annotated[list, Artifact(name="X_train", loader=ArtifactLoader.json)],
    y_train: Annotated[dict, Artifact(name="y_train", loader=ArtifactLoader.json)],
    model_path: Annotated[Path, Artifact(name="model", archive=NoneArchiveStrategy(), output=True)],
):
and then doing model_path.write_bytes(pickle.dumps(model)) in the function body.
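The core of that workaround is that the function serializes the model to bytes itself and writes them to the artifact path, so the runner never has to serialize the value. A minimal standalone sketch of that step, assuming a plain dict as a stand-in for the trained model and a temporary directory in place of the artifact path (both are illustrative, not part of Hera):

```python
import pickle
import tempfile
from pathlib import Path


def save_model(model, model_path: Path) -> None:
    # Serialize to bytes ourselves and write them to the output path.
    model_path.write_bytes(pickle.dumps(model))


def load_model(model_path: Path):
    # Only unpickle files you produced yourself; unpickling can run code.
    return pickle.loads(model_path.read_bytes())


# Stand-in for a fitted estimator; any picklable object works the same way.
model = {"coef": [0.5, -1.0], "intercept": 0.1}

with tempfile.TemporaryDirectory() as tmp:
    model_path = Path(tmp) / "model"
    save_model(model, model_path)
    restored = load_model(model_path)
```

A downstream step would read the same file back with `read_bytes` and `pickle.loads`, as `load_model` does here.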
Good idea from #903 - if the loader is a
def fan_in(*, responses: Annotated[list[Magic], Parameter(loader=Magic)])
Otherwise the loader could be any
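The comment above puts the loader inside the parameter's annotation. As a rough illustration of how a runner could pick such a loader up, here is a sketch that reads `Annotated` metadata and applies the loader to the raw string input. `Param`, `load_inputs`, and `parse_points` are hypothetical names for this example and are not Hera's `Parameter` class or runner code.

```python
import json
from dataclasses import dataclass
from typing import Annotated, Any, Callable, get_type_hints


@dataclass(frozen=True)
class Param:
    """Hypothetical annotation marker; not Hera's Parameter class."""

    name: str
    loader: Callable[[str], Any] = json.loads


def load_inputs(func: Callable, raw: dict[str, str]) -> dict[str, Any]:
    """Convert raw string inputs to Python values using each argument's loader."""
    hints = get_type_hints(func, include_extras=True)
    loaded = {}
    for arg, hint in hints.items():
        if arg == "return":
            continue
        # Annotated[...] exposes its extra arguments as __metadata__.
        markers = [m for m in getattr(hint, "__metadata__", ()) if isinstance(m, Param)]
        if markers:
            marker = markers[0]
            loaded[arg] = marker.loader(raw[marker.name])
    return loaded


def parse_points(text: str) -> list[tuple[int, int]]:
    # Custom loader: JSON yields lists, this converts each pair to a tuple.
    return [tuple(pair) for pair in json.loads(text)]


def fan_in(*, responses: Annotated[list, Param("responses", loader=parse_points)]):
    return responses


values = load_inputs(fan_in, {"responses": "[[1, 2], [3, 4]]"})
result = fan_in(**values)
```

Keeping the loader next to the type means each argument can be deserialized differently without any global state.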
Is your feature request related to a problem? Please describe.
Tried to use pandas.DataFrame for outputs and got an error.
Pandas DataFrames have a to_json method which would make things easier, but I have no way to tell the serialize function in hera.shared.serialization what to do with DataFrames. I also can't change the class code, hence "non-user" type (I could subclass it though?).
Describe the solution you'd like
An easy way to plug in the "how" for serializing custom types in the runner, e.g. as part of the type annotation, or a global setter such as global_config.serializer = my_serializer, or maybe in the RunnerScriptConstructor? (Needs some more thought)
Describe alternatives you've considered
strs and use the DataFrame.to_json method
RunnerScriptConstructor so I can use my own serialize function
Additional context
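To make the "global setter" idea more concrete, here is a minimal sketch of a type-keyed serializer registry. It is only an illustration of the request: `register_serializer`, `serialize`, and `Table` are hypothetical names, `Table` stands in for a third-party type with a `to_json` method, and none of this is the actual `hera.shared.serialization` code.

```python
import json
from typing import Any, Callable

# Hypothetical registry mapping a type to a function that returns a string.
_SERIALIZERS: dict[type, Callable[[Any], str]] = {}


def register_serializer(tp: type, fn: Callable[[Any], str]) -> None:
    _SERIALIZERS[tp] = fn


def serialize(value: Any) -> str:
    # In this sketch, strings pass through unchanged.
    if isinstance(value, str):
        return value
    # Walk the MRO so subclasses reuse a serializer registered for a base class.
    for tp in type(value).__mro__:
        if tp in _SERIALIZERS:
            return _SERIALIZERS[tp](value)
    # Fall back to plain JSON, which raises TypeError for unknown types.
    return json.dumps(value)


class Table:
    """Stand-in for a type we cannot modify but which has a to_json method."""

    def __init__(self, columns: dict[str, list]):
        self.columns = columns

    def to_json(self) -> str:
        return json.dumps(self.columns)


table = Table({"a": [1, 2]})

# Without a registered serializer, the JSON fallback rejects the object.
try:
    serialize(table)
    failed_before = False
except TypeError:
    failed_before = True

# After registration, the type's own to_json method is used.
register_serializer(Table, Table.to_json)
after = serialize(table)
```

With pandas installed, the analogous registration would presumably be `register_serializer(pd.DataFrame, pd.DataFrame.to_json)`, since `DataFrame.to_json` returns a string when no path is given.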