diff --git a/hydit/inference.py b/hydit/inference.py index c49e207..c7bcc15 100644 --- a/hydit/inference.py +++ b/hydit/inference.py @@ -202,6 +202,7 @@ def __init__(self, args, models_root_path): self.infer_mode = self.args.infer_mode if self.infer_mode in ['fa', 'torch']: + model_dir = self.root / "model" model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt" if not model_path.exists(): raise ValueError(f"model_path not exists: {model_path}")