Commit 0ec1088
save transform artifacts in every step dir, properly nest inference config within train config (#134)
francoishernandez authored Oct 25, 2024
1 parent 5bf241b commit 0ec1088
Showing 4 changed files with 40 additions and 32 deletions.
22 changes: 5 additions & 17 deletions eole/config/inference.py
@@ -105,23 +105,6 @@ class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig)
     model_config = get_config_dict()
     model_config["arbitrary_types_allowed"] = True  # to allow torch.dtype
 
-    # TODO: clarify models vs model (model config retrieved from checkpoint)
-    model_path: str | List[str] = Field(
-        description="Path to model .pt file(s). "
-        "Multiple models can be specified for ensemble decoding."
-    )  # some specific (mapping to "models") in legacy code, need to investigate
-    src: str = Field(description="Source file to decode (one line per sequence).")
-    tgt: str | None = Field(
-        default=None,
-        description="True target sequences, useful for scoring or prefix decoding.",
-    )
-    tgt_file_prefix: bool = Field(
-        default=False, description="Generate predictions using provided tgt as prefix."
-    )
-    output: str = Field(
-        default="pred.txt",
-        description="Path to output the predictions (each line will be the decoded sequence).",
-    )
     report_align: bool = Field(
         default=False, description="Report alignment for each translation."
     )
@@ -148,6 +131,11 @@ class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig)
     data_type: str | None = (
         "text"  # deprecated? hopefully will change with input streams logic
     )
+    chat_template: str | None = None
+    optional_eos: List[str] | None = Field(
+        default=[],
+        description="Optional EOS tokens that would stop generation, e.g. <|eot_id|> for Llama3",
+    )
 
     def get_model_path(self):
         return self.model_path[0]
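
For orientation, a hedged sketch (not part of the commit) of what the slimmed-down InferenceConfig allows: with model_path, src, tgt, and output moved to PredictConfig, an instance holding only decoding preferences can be built on its own, assuming every remaining inherited field has a default.

```python
# Hedged sketch, not from the commit: field names come from the diff above;
# the assumption is that all remaining InferenceConfig fields have defaults.
from eole.config.inference import InferenceConfig

inference = InferenceConfig(
    chat_template=None,           # no chat template applied
    optional_eos=["<|eot_id|>"],  # extra stop token, e.g. for Llama3
)
print(inference.optional_eos)
```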
20 changes: 16 additions & 4 deletions eole/config/run.py
@@ -34,6 +34,7 @@ class TrainConfig(
     )  # not sure this still works
     model: ModelConfig | None = None  # TypeAdapter handling discrimination directly
     training: TrainingConfig | None = Field(default_factory=TrainingConfig)
+    inference: InferenceConfig | None = Field(default=None)
 
     def get_model_path(self):
         return self.training.get_model_path()
@@ -100,10 +101,21 @@ class PredictConfig(
         None  # patch for CT2 inference engine (to improve later)
     )
     model: ModelConfig | None = None
-    chat_template: str | None = None
-    optional_eos: List[str] | None = Field(
-        default=[],
-        description="Optional EOS tokens that would stop generation, e.g. <|eot_id|> for Llama3",
+    model_path: str | List[str] = Field(
+        description="Path to model .pt file(s). "
+        "Multiple models can be specified for ensemble decoding."
+    )  # some specific (mapping to "models") in legacy code, need to investigate
+    src: str = Field(description="Source file to decode (one line per sequence).")
+    tgt: str | None = Field(
+        default=None,
+        description="True target sequences, useful for scoring or prefix decoding.",
     )
+    tgt_file_prefix: bool = Field(
+        default=False, description="Generate predictions using provided tgt as prefix."
+    )
+    output: str = Field(
+        default="pred.txt",
+        description="Path to output the predictions (each line will be the decoded sequence).",
+    )
 
     @model_validator(mode="after")
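
To illustrate the nesting named in the commit title, a hedged sketch of the dict shape TrainConfig now accepts; only the `inference` field is grounded in the diff, the surrounding keys and any required fields are assumptions.

```python
# Hedged sketch, not from the commit: only the nested "inference" key comes
# from the diff; surrounding keys and required fields are assumed.
train_config_dict = {
    "model": {},      # ModelConfig contents would go here
    "training": {},   # TrainingConfig contents
    "inference": {    # new: nested InferenceConfig within the train config
        "chat_template": None,
        "optional_eos": ["<|eot_id|>"],
    },
}
```

This also explains the model_saver.py change below: a nested inference section now validates as part of TrainConfig, so load_checkpoint no longer needs to drop it.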
15 changes: 10 additions & 5 deletions eole/models/model_saver.py
@@ -52,9 +52,6 @@ def load_checkpoint(model_path):
         config_dict = json.loads(os.path.expandvars(f.read()))
         # drop data to prevent validation issues
         config_dict["data"] = {}
-        # drop inference to prevent validation issues
-        if "inference" in config_dict.keys():
-            config_dict.pop("inference")
         if "training" in config_dict.keys():
             config_dict["training"]["dummy_load"] = True
         else:
@@ -290,13 +287,18 @@ def _save_config(self):
 
     def _save_transforms_artifacts(self):
         if self.transforms is not None:
+            checkpoint_path = os.path.join(self.model_path, self.step_dir)
             for transform_name, transform in self.transforms.items():
-                transform_save_config = transform._save_artifacts(self.model_path)
+                transform_save_config, artifacts = transform._save_artifacts(
+                    checkpoint_path
+                )
                 setattr(
                     self.config.transforms_configs,
                     transform_name,
                     transform_save_config,
                 )
+                for artifact in artifacts:
+                    self._make_symlink(artifact)
             # we probably do not need to save transforms artifacts for each checkpoint
             # transform._save_artifacts(os.path.join(self.model_path, self.step_dir))
@@ -323,7 +325,10 @@ def _save(self, step):
         )
         self._save_optimizer()
         self._save_weights(model_state_dict)
-        logger.info(f"Saving transforms artifacts, if any, to {self.model_path}")
+        logger.info(
+            "Saving transforms artifacts, if any, "
+            f"to {os.path.join(self.model_path, self.step_dir)}"
+        )
         self._save_transforms_artifacts()
         logger.info(f"Saving config and vocab to {self.model_path}")
         self._save_vocab()
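
A hedged sketch tying these pieces together: artifacts now land in the per-step checkpoint directory, and since the saved transform config records them under ${MODEL_PATH} (see transform.py below), load_checkpoint can resolve them with os.path.expandvars. The path names here are hypothetical.

```python
import os

# Hedged sketch, not from the commit; "runs/my_model" and "step_1000" are
# hypothetical names. Artifacts are copied into the per-step checkpoint dir,
# and the saved config references them via the ${MODEL_PATH} variable that
# load_checkpoint expands with os.path.expandvars.
model_path = "runs/my_model"
step_dir = "step_1000"
checkpoint_path = os.path.join(model_path, step_dir)

os.environ["MODEL_PATH"] = checkpoint_path
saved_artifact = "${MODEL_PATH}/bpe.model"  # as recorded by _save_artifacts
print(os.path.expandvars(saved_artifact))   # runs/my_model/step_1000/bpe.model
```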
15 changes: 9 additions & 6 deletions eole/transforms/transform.py
@@ -58,6 +58,7 @@ def warm_up(self, vocabs=None):
 
     def _save_artifacts(self, model_path):
         save_config = copy.deepcopy(self.config)
+        artifacts = []
         for artifact in self.artifacts:
             maybe_artifact = getattr(self, artifact, None)
             if maybe_artifact is not None and os.path.exists(maybe_artifact):
@@ -66,12 +67,14 @@ def _save_artifacts(self, model_path):
                     shutil.copy(maybe_artifact, model_path)
                 except shutil.SameFileError:
                     pass
-                setattr(
-                    save_config,
-                    artifact,
-                    os.path.join("${MODEL_PATH}", os.path.basename(maybe_artifact)),
-                )
-        return save_config
+                finally:
+                    artifacts.append(os.path.basename(maybe_artifact))
+                setattr(
+                    save_config,
+                    artifact,
+                    os.path.join("${MODEL_PATH}", os.path.basename(maybe_artifact)),
+                )
+        return save_config, artifacts
 
     @classmethod
     def add_options(cls, parser):
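
Finally, a hedged, self-contained sketch of the new return contract: _save_artifacts now yields a (config, artifact basenames) pair, matching how model_saver.py unpacks it above. ToyTransform and all file names are stand-ins, not eole code.

```python
import os
import shutil
import tempfile

# Hedged sketch, not from the commit: ToyTransform is a hypothetical stand-in
# that mirrors only the copy-and-record behaviour shown in the diff above.
class ToyTransform:
    artifacts = ["bpe_model_path"]  # names of attributes holding file paths

    def __init__(self, path):
        self.bpe_model_path = path
        self.config = {}

    def _save_artifacts(self, model_path):
        saved = []
        for name in self.artifacts:
            src = getattr(self, name)
            shutil.copy(src, model_path)         # copy artifact into the step dir
            saved.append(os.path.basename(src))  # record basename for symlinking
        return self.config, saved

with tempfile.TemporaryDirectory() as src_dir, tempfile.TemporaryDirectory() as step_dir:
    src = os.path.join(src_dir, "dummy.model")
    open(src, "w").close()  # create an empty artifact file
    cfg, artifacts = ToyTransform(src)._save_artifacts(step_dir)
    print(artifacts)  # ['dummy.model']
```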
