diff --git a/open_musiclm/config.py b/open_musiclm/config.py index c356a90..7af684a 100644 --- a/open_musiclm/config.py +++ b/open_musiclm/config.py @@ -188,7 +188,7 @@ def load_model(model, path): path = Path(path) assert path.exists(), f'checkpoint does not exist at {str(path)}' pkg = torch.load(str(path)) - model.load_state_dict(pkg, strict=False) + model.load_state_dict(pkg) class disable_print: def __enter__(self):