Add support for mlflow #77
base: main
A status from me. I got a bit further with the signature (thanks to @sadamov), but I am seeing an error that I am currently trying to understand. @elbdmi will also have a look at this. The code:

def log_model(self, data_module, model):
    input_example = self.create_input_example(data_module)
    with torch.no_grad():
        model_output = model.common_step(input_example)[0]  # expects batch, returns tuple (ar_model)
    # TODO: Are we sure we can hardcode the input names?
    signature = infer_signature(
        {name: tensor.cpu().numpy() for name, tensor in zip(['init_states', 'target_states', 'forcing', 'target_times'], input_example)},
        model_output.cpu().numpy()
    )
    mlflow.pytorch.log_model(
        model,
        "model",
        input_example=input_example[0].cpu().numpy(),
        signature=signature
    )

def create_input_example(self, data_module):
    if data_module.val_dataset is None:
        data_module.setup(stage="fit")
    data_loader = data_module.train_dataloader()
    batch_sample = next(iter(data_loader))
    return batch_sample

For example, when I log the model with training_logger.log_model(data_module, model), it fails to validate "serving the input example" and prints the full tensor:
Input and output data examples are uploaded to mlflow, see this example.
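On the TODO about hardcoding the input names: one way to make the pairing less fragile is to check that the number of names matches the number of entries in the batch before zipping them. This is only a sketch. The helper name_batch and the constant INPUT_NAMES are hypothetical, and plain lists stand in for the tensors a real batch would hold.

```python
# Hypothetical helper, not part of neural_lam: pair input names with batch entries.
INPUT_NAMES = ("init_states", "target_states", "forcing", "target_times")


def name_batch(batch, names=INPUT_NAMES):
    """Return {name: entry}, raising if the batch and name counts disagree."""
    # A plain zip() would silently drop unmatched entries, so check first.
    if len(batch) != len(names):
        raise ValueError(
            f"Expected {len(names)} batch entries {names}, got {len(batch)}"
        )
    return dict(zip(names, batch))


# Stand-in values; a real batch would hold torch tensors.
batch = (["s0", "s1"], ["t0"], ["f0"], [0, 1, 2])
named = name_batch(batch)
```

With this check, a datastore that yields a different batch layout fails at the naming step rather than producing a signature with missing inputs.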
I think it's best to skip the log model feature of mlflow for now. Discarding the log model feature means mlflow will only be supported for training (like wandb). When I call
Instead, I can infer the signature and give that, which kind of works. However, there are two issues with that. 1: 2: I get some warnings from MLFlow about modules not being found. It doesn't seem to affect the results, but I have not figured out what the reason is. However, it is related to when it tries to load the model, where it seems to refer to a name similar to the processor:
I have also tried to use mlflows

The whole class:

class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
    def __init__(self, experiment_name, tracking_uri):
        super().__init__(
            experiment_name=experiment_name, tracking_uri=tracking_uri
        )
        mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
        mlflow.log_param("run_id", self.run_id)

    @property
    def save_dir(self):
        return "mlruns"

    def log_image(self, key, images, step=None):
        # Third-party
        from PIL import Image

        if step is not None:
            key = f"{key}_{step}"

        # Need to save the image to a temporary file, then log that file
        # mlflow.log_image, should do this automatically, but is buggy
        temporary_image = f"{key}.png"
        images[0].savefig(temporary_image)

        img = Image.open(temporary_image)
        mlflow.log_image(img, f"{key}.png")

    def log_model(self, data_module, model):
        input_example = self.create_input_example(data_module)

        with torch.no_grad():
            model_output = model.common_step(input_example)[
                0
            ]  # expects batch, returns tuple (prediction, target, pred_std, _)

        log_model_input_example = {
            name: tensor.cpu().numpy()
            for name, tensor in zip(
                ["init_states", "target_states", "forcing", "target_times"],
                input_example,
            )
        }

        signature = infer_signature(
            log_model_input_example, model_output.cpu().numpy()
        )

        mlflow.pytorch.log_model(
            model,
            "model",
            signature=signature,
        )

        # validate_serving_input(model_uri, validate_example)

    def create_input_example(self, data_module):
        if data_module.val_dataset is None:
            data_module.setup(stage="fit")

        data_loader = data_module.train_dataloader()
        batch_sample = next(iter(data_loader))
        return batch_sample
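A note on the temporary image in log_image: as written, the PNG is saved into the working directory and stays there. A possible alternative, sketched below, writes into a temporary directory that is removed once the file has been logged. The function log_via_tempfile and its save and log_artifact callbacks are hypothetical. In the logger they could be something like images[0].savefig and mlflow.log_artifact, assuming the logging call has finished reading the file when it returns.

```python
import os
import tempfile


def log_via_tempfile(key, save, log_artifact, step=None):
    """Save to a temporary PNG, pass its path to log_artifact, then clean up."""
    if step is not None:
        key = f"{key}_{step}"
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, f"{key}.png")
        save(path)          # e.g. a matplotlib figure's savefig
        log_artifact(path)  # must be done with the file before the block exits
    return key


# Stand-ins so the sketch runs without matplotlib or mlflow.
logged = []


def fake_save(path):
    with open(path, "wb") as handle:
        handle.write(b"placeholder bytes, not a real PNG")


def fake_log(path):
    logged.append((os.path.basename(path), os.path.exists(path)))


final_key = log_via_tempfile("loss_map", fake_save, fake_log, step=3)
```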
neural_lam/train_model.py
Outdated
    def log_model(self, data_module, model):
        input_example = self.create_input_example(data_module)

        with torch.no_grad():
            model_output = model.common_step(input_example)[
                0
            ]  # common_step returns tuple (prediction, target, pred_std, _)

        log_model_input_example = {
            name: tensor.cpu().numpy()
            for name, tensor in zip(
                ["init_states", "target_states", "forcing", "target_times"],
                input_example,
            )
        }

        signature = infer_signature(
            log_model_input_example, model_output.cpu().numpy()
        )

        mlflow.pytorch.log_model(
            model,
            "model",
            signature=signature,
        )

    def create_input_example(self, data_module):

        if data_module.val_dataset is None:
            data_module.setup(stage="fit")

        data_loader = data_module.train_dataloader()
        batch_sample = next(iter(data_loader))
        return batch_sample
log_model, and thereby also create_input_example, is not used, so both can be removed. However, it can be used to log a model if one wishes to do so, but with an input example that is not validated.
I vote for removing this and revisiting it in another PR if needed.
Removed with 821443a
neural_lam/utils.py
Outdated
warnings.warn(
    "Only WandbLogger & MLFlowLogger is supported for tracking metrics"
)
I guess what happens in this case is that experiment results will only go to stdout? It would be good to clarify that in this warning.
Changed with d503048
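For illustration, a warning that also states the consequence could look roughly like the sketch below. The wording about stdout follows the reviewer's guess above and is an assumption, not a confirmed description of what the code does. The helper warn_unsupported_logger is hypothetical and is not taken from the commit.

```python
import warnings

# Logger class names treated as supported in this sketch.
SUPPORTED = ("WandbLogger", "MLFlowLogger")


def warn_unsupported_logger(logger):
    """Warn and return True if the logger is not a supported type."""
    name = type(logger).__name__
    if name in SUPPORTED:
        return False
    warnings.warn(
        f"{name} is not supported for tracking metrics; only "
        f"{' and '.join(SUPPORTED)} are. "
        # Assumption from the review comment: results then only reach stdout.
        "Experiment results will only be printed to stdout."
    )
    return True
```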
logger: str = "wandb"
logger_url: str = ""
What do we actually want the config to contain, vs what should be cmd-line arguments? I would have thought that the choice of logger would be an argparse flag, in a similar way as the plotting choices. My thought process is that logging/plotting does not affect the end product (trained model) whereas all the current options in the config do. But we are not really consistent with this divide either, as there are plenty of argparse options currently that change the model training.
I think it sounds reasonable to have the logging choices as cmd-line arguments given that the plot arguments are already in that category. On the other hand, don't we risk getting too many cmd-line arguments? I sometimes find it quite hard to remember the correct names. Either way, I agree that either both plot and logger should be cmd-line or both should be in a config.
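If the logger choice moves to the command line, a minimal sketch could look like the following. The flag names --logger and --logger_url mirror the config fields in the diff above, but they are assumptions, not the names the project settled on.

```python
import argparse


def build_parser():
    """Build a parser holding only the (hypothetical) logging flags."""
    parser = argparse.ArgumentParser(description="Logging options (sketch)")
    parser.add_argument(
        "--logger",
        choices=["wandb", "mlflow"],
        default="wandb",
        help="Experiment tracker to use (default: wandb)",
    )
    parser.add_argument(
        "--logger_url",
        type=str,
        default="",
        help="Tracking server URL, only relevant for mlflow",
    )
    return parser


# Example invocation with an illustrative local tracking URL.
args = build_parser().parse_args(
    ["--logger", "mlflow", "--logger_url", "http://localhost:5000"]
)
```

Using choices means a mistyped logger name is rejected by argparse with a list of valid options, which helps a little with the concern about remembering names.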
@@ -27,6 +27,8 @@ dependencies = [
    "parse>=1.20.2",
    "dataclass-wizard<0.31.0",
    "mllam-data-prep>=0.5.0",
    "mlflow>=2.16.2",
    "boto3>=1.35.32",
Maybe we should add pynvml as a requirement - this was the reason I did not get any gpu stats in the runs we discussed earlier @khintz
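As a quick way to see whether an environment is affected, one could check if pynvml is importable before starting a run. This is a sketch, and the helper gpu_metrics_available is hypothetical. It only tests that the module can be found, not that a GPU or driver is present.

```python
import importlib.util


def gpu_metrics_available(module_name="pynvml"):
    """Return True if the module needed for GPU stats can be imported."""
    return importlib.util.find_spec(module_name) is not None


if not gpu_metrics_available():
    print("pynvml not found: GPU stats will likely be missing from the run")
```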
Describe your changes
Add support for mlflow logger by utilising pytorch_lightning.loggers
The native wandb module is replaced with the pytorch_lightning wandb logger, and the pytorch_lightning mlflow logger is introduced.
https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loggers/logger.py
This will allow people to choose between wandb and mlflow.
Builds upon #66. That is not strictly necessary for this change, but I am using that feature to work with our dataset.
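Roughly, choosing between the two pytorch_lightning loggers can be sketched as below. The function make_logger and its parameters are hypothetical and not the code in this PR. Raising on an unknown name is a choice made for the sketch, and the imports are kept inside the branches so that only the selected backend needs to be installed.

```python
def make_logger(name, project, run_name, tracking_uri=""):
    """Return a pytorch_lightning logger for 'wandb' or 'mlflow'."""
    if name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        return WandbLogger(project=project, name=run_name)
    if name == "mlflow":
        from pytorch_lightning.loggers import MLFlowLogger

        kwargs = {"experiment_name": project, "run_name": run_name}
        if tracking_uri:
            # Only override the default tracking location when a URL is given.
            kwargs["tracking_uri"] = tracking_uri
        return MLFlowLogger(**kwargs)
    raise ValueError(f"Unsupported logger {name!r}; expected 'wandb' or 'mlflow'")
```

The returned object would then be passed to the trainer, e.g. pl.Trainer(logger=make_logger(...)).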
Issue Link
Closes #76