From 0ec1088e164abe44dbcab6f0cc1aeb1df725b22a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?=
Date: Fri, 25 Oct 2024 17:38:10 +0200
Subject: [PATCH] save transform artifacts in every step dir, properly nest
 inference config within train config (#134)

---
 eole/config/inference.py     | 22 +++++-----------------
 eole/config/run.py           | 20 ++++++++++++++++----
 eole/models/model_saver.py   | 15 ++++++++++-----
 eole/transforms/transform.py | 15 +++++++++------
 4 files changed, 40 insertions(+), 32 deletions(-)

diff --git a/eole/config/inference.py b/eole/config/inference.py
index e120ddb3..0bcd2125 100644
--- a/eole/config/inference.py
+++ b/eole/config/inference.py
@@ -105,23 +105,6 @@ class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig)
     model_config = get_config_dict()
     model_config["arbitrary_types_allowed"] = True  # to allow torch.dtype
 
-    # TODO: clarify models vs model (model config retrieved from checkpoint)
-    model_path: str | List[str] = Field(
-        description="Path to model .pt file(s). "
-        "Multiple models can be specified for ensemble decoding."
-    )  # some specific (mapping to "models") in legacy code, need to investigate
-    src: str = Field(description="Source file to decode (one line per sequence).")
-    tgt: str | None = Field(
-        default=None,
-        description="True target sequences, useful for scoring or prefix decoding.",
-    )
-    tgt_file_prefix: bool = Field(
-        default=False, description="Generate predictions using provided tgt as prefix."
-    )
-    output: str = Field(
-        default="pred.txt",
-        description="Path to output the predictions (each line will be the decoded sequence).",
-    )
     report_align: bool = Field(
         default=False, description="Report alignment for each translation."
     )
@@ -148,6 +131,11 @@ class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig)
     data_type: str | None = (
         "text"  # deprecated? hopefully will change with input streams logic
     )
+    chat_template: str | None = None
+    optional_eos: List[str] | None = Field(
+        default=[],
+        description="Optional EOS tokens that would stop generation, e.g. <|eot_id|> for Llama3",
+    )
 
     def get_model_path(self):
         return self.model_path[0]
diff --git a/eole/config/run.py b/eole/config/run.py
index 4032f202..b2a0dc03 100644
--- a/eole/config/run.py
+++ b/eole/config/run.py
@@ -34,6 +34,7 @@ class TrainConfig(
     )  # not sure this still works
     model: ModelConfig | None = None  # TypeAdapter handling discrimination directly
     training: TrainingConfig | None = Field(default_factory=TrainingConfig)
+    inference: InferenceConfig | None = Field(default=None)
 
     def get_model_path(self):
         return self.training.get_model_path()
@@ -100,10 +101,21 @@ class PredictConfig(
         None  # patch for CT2 inference engine (to improve later)
     )
     model: ModelConfig | None = None
-    chat_template: str | None = None
-    optional_eos: List[str] | None = Field(
-        default=[],
-        description="Optional EOS tokens that would stop generation, e.g. <|eot_id|> for Llama3",
+    model_path: str | List[str] = Field(
+        description="Path to model .pt file(s). "
+        "Multiple models can be specified for ensemble decoding."
+    )  # some specific (mapping to "models") in legacy code, need to investigate
+    src: str = Field(description="Source file to decode (one line per sequence).")
+    tgt: str | None = Field(
+        default=None,
+        description="True target sequences, useful for scoring or prefix decoding.",
+    )
+    tgt_file_prefix: bool = Field(
+        default=False, description="Generate predictions using provided tgt as prefix."
+    )
+    output: str = Field(
+        default="pred.txt",
+        description="Path to output the predictions (each line will be the decoded sequence).",
     )
 
     @model_validator(mode="after")
diff --git a/eole/models/model_saver.py b/eole/models/model_saver.py
index 41239679..5204af50 100644
--- a/eole/models/model_saver.py
+++ b/eole/models/model_saver.py
@@ -52,9 +52,6 @@ def load_checkpoint(model_path):
             config_dict = json.loads(os.path.expandvars(f.read()))
             # drop data to prevent validation issues
             config_dict["data"] = {}
-            # drop inference to prevent validation issues
-            if "inference" in config_dict.keys():
-                config_dict.pop("inference")
             if "training" in config_dict.keys():
                 config_dict["training"]["dummy_load"] = True
             else:
@@ -290,13 +287,18 @@ def _save_config(self):
 
     def _save_transforms_artifacts(self):
         if self.transforms is not None:
+            checkpoint_path = os.path.join(self.model_path, self.step_dir)
             for transform_name, transform in self.transforms.items():
-                transform_save_config = transform._save_artifacts(self.model_path)
+                transform_save_config, artifacts = transform._save_artifacts(
+                    checkpoint_path
+                )
                 setattr(
                     self.config.transforms_configs,
                     transform_name,
                     transform_save_config,
                 )
+                for artifact in artifacts:
+                    self._make_symlink(artifact)
         # we probably do not need to save transforms artifacts for each checkpoint
         # transform._save_artifacts(os.path.join(self.model_path, self.step_dir))
 
@@ -323,7 +325,10 @@ def _save(self, step):
             )
         self._save_optimizer()
         self._save_weights(model_state_dict)
-        logger.info(f"Saving transforms artifacts, if any, to {self.model_path}")
+        logger.info(
+            "Saving transforms artifacts, if any, "
+            f"to {os.path.join(self.model_path, self.step_dir)}"
+        )
         self._save_transforms_artifacts()
         logger.info(f"Saving config and vocab to {self.model_path}")
         self._save_vocab()
diff --git a/eole/transforms/transform.py b/eole/transforms/transform.py
index 90837568..ec2dd0ef 100644
--- a/eole/transforms/transform.py
+++ b/eole/transforms/transform.py
@@ -58,6 +58,7 @@ def warm_up(self, vocabs=None):
 
     def _save_artifacts(self, model_path):
         save_config = copy.deepcopy(self.config)
+        artifacts = []
         for artifact in self.artifacts:
             maybe_artifact = getattr(self, artifact, None)
             if maybe_artifact is not None and os.path.exists(maybe_artifact):
@@ -66,12 +67,14 @@
                     shutil.copy(maybe_artifact, model_path)
                 except shutil.SameFileError:
                     pass
-                setattr(
-                    save_config,
-                    artifact,
-                    os.path.join("${MODEL_PATH}", os.path.basename(maybe_artifact)),
-                )
-        return save_config
+                finally:
+                    artifacts.append(os.path.basename(maybe_artifact))
+                    setattr(
+                        save_config,
+                        artifact,
+                        os.path.join("${MODEL_PATH}", os.path.basename(maybe_artifact)),
+                    )
+        return save_config, artifacts
 
     @classmethod
     def add_options(cls, parser):
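
Note: before this patch, load_checkpoint() popped the "inference" section from
the stored config to avoid validation issues; with the new
TrainConfig.inference field it now round-trips instead. A minimal sketch of
reading a saved config back with the section preserved (the directory and file
names are illustrative, and validating the dict directly via TrainConfig(**...)
is an assumption about the pydantic model, not necessarily what
load_checkpoint() does internally):

    import json
    import os

    from eole.config.run import TrainConfig

    # Illustrative path; the real checkpoint layout is resolved by
    # load_checkpoint() in eole/models/model_saver.py.
    with open(os.path.join("my_model_dir", "config.json")) as f:
        config_dict = json.loads(os.path.expandvars(f.read()))

    # Mirror the load_checkpoint() pattern shown in the hunk above.
    config_dict["data"] = {}  # drop data to prevent validation issues
    if "training" in config_dict.keys():
        config_dict["training"]["dummy_load"] = True

    config = TrainConfig(**config_dict)
    # The nested inference section is no longer dropped at load time.
    print(config.inference)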