Skip to content

Commit

Permalink
Rename MistralNeMo2407Config12B to MistralNeMoConfig12B per review's …
Browse files Browse the repository at this point in the history
…suggestion

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
  • Loading branch information
akoumpa committed Sep 26, 2024
1 parent 30f7104 commit a02dbfc
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
MaskedTokenLossReduction,
MistralConfig7B,
MistralModel,
MistralNeMo2407Config12B,
MistralNeMoConfig12B,
MixtralConfig8x3B,
MixtralConfig8x7B,
MixtralConfig8x22B,
Expand Down Expand Up @@ -111,7 +111,7 @@
"gpt_forward_step",
"MaskedTokenLossReduction",
"MistralConfig7B",
"MistralNeMo2407Config12B",
"MistralNeMoConfig12B",
"MistralModel",
"MixtralConfig8x3B",
"MixtralConfig8x7B",
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
LlamaConfig,
LlamaModel,
)
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel, MistralNeMo2407Config12B
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel, MistralNeMoConfig12B
from nemo.collections.llm.gpt.model.mixtral import (
MixtralConfig8x3B,
MixtralConfig8x7B,
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/llm/gpt/model/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class MistralConfig7B(GPTConfig):


@dataclass
class MistralNeMo2407Config12B(MistralConfig7B):
class MistralNeMoConfig12B(MistralConfig7B):
"""
https://mistral.ai/news/mistral-nemo/
"""
Expand All @@ -75,7 +75,7 @@ class MistralNeMo2407Config12B(MistralConfig7B):


@dataclass
class MistralNeMo2407Config123B(MistralConfig7B):
class MistralNeMoConfig123B(MistralConfig7B):
"""
https://mistral.ai/news/mistral-large-2407/
"""
Expand Down
26 changes: 13 additions & 13 deletions nemo/collections/llm/recipes/mistral_nemo_12b.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,32 @@
from nemo import lightning as nl
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.model.mistral import MistralModel, MistralNeMo2407Config12B
from nemo.collections.llm.gpt.model.mistral import MistralModel, MistralNeMoConfig12B
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.utils.exp_manager import TimingCallback

NAME = "mistral_nemo_base_2407"
NAME = "mistral_nemo_base_12b"


@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Mistral-Nemo-Base-2407 model configuration.
Factory function to create a Mistral-Nemo-Base-12B model configuration.
Returns:
run.Config[pl.LightningModule]: Configuration for the Mistral-Nemo-Base-2407 model.
run.Config[pl.LightningModule]: Configuration for the Mistral-Nemo-Base-12B model.
Examples:
CLI usage:
$ nemo llm pretrain model=mistral_nemo_base_2407 ...
$ nemo llm pretrain model=mistral_nemo_base_12b ...
Python API usage:
>>> model_config = model()
>>> print(model_config)
"""
return run.Config(MistralModel, config=run.Config(MistralNeMo2407Config12B))
return run.Config(MistralModel, config=run.Config(MistralNeMoConfig12B))


def trainer(
Expand All @@ -65,7 +65,7 @@ def trainer(
callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
"""
Configure the NeMo Lightning Trainer for Mistral-Nemo-Base-2407 model.
Configure the NeMo Lightning Trainer for Mistral-Nemo-Base-12B model.
This function sets up the distributed training strategy and other training parameters.
Expand All @@ -86,7 +86,7 @@ def trainer(
Examples:
CLI usage:
$ nemo llm pretrain trainer=mistral_nemo_base_2407 ...
$ nemo llm pretrain trainer=mistral_nemo_base_12b ...
Python API usage:
>>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
Expand Down Expand Up @@ -139,7 +139,7 @@ def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a pre-training recipe for Mistral-Nemo-Base-2407 model.
Create a pre-training recipe for Mistral-Nemo-Base-12B model.
This function sets up a complete configuration for pre-training, including
model, trainer, data, logging, optimization, and resumption settings.
Expand All @@ -156,8 +156,8 @@ def pretrain_recipe(
Examples:
CLI usage:
$ nemo llm pretrain --factory mistral_nemo_base_2407
$ nemo llm pretrain --factory "mistral_nemo_base_2407(num_nodes=2, name='my_mistral_pretrain')"
$ nemo llm pretrain --factory mistral_nemo_base_12b
$ nemo llm pretrain --factory "mistral_nemo_base_12b(num_nodes=2, name='my_mistral_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe(name="mistral_pretrain", num_nodes=2)
Expand Down Expand Up @@ -187,7 +187,7 @@ def pretrain_recipe(
@run.cli.factory(name=NAME + "_hf")
def hf_resume() -> run.Config[nl.AutoResume]:
"""
Configure automatic resumption from a Hugging Face checkpoint for Mistral-Nemo-Base-2407 model.
Configure automatic resumption from a Hugging Face checkpoint for Mistral-Nemo-Base-12B model.
This function sets up the configuration to resume training from a pre-trained
Hugging Face model checkpoint.
Expand All @@ -199,7 +199,7 @@ def hf_resume() -> run.Config[nl.AutoResume]:
Note:
This is particularly useful for fine-tuning scenarios where you want to
start from the pre-trained Mistral-Nemo-Base-2407 model.
start from the pre-trained Mistral-Nemo-Base-12B model.
"""
return run.Config(
nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path="hf://mistralai/Mistral-Nemo-Base-2407")
Expand Down

0 comments on commit a02dbfc

Please sign in to comment.