diff --git a/deepspeed/inference/v2/model_implementations/phi3/containers.py b/deepspeed/inference/v2/model_implementations/phi3/containers.py
index 5e51a9131ef9..1cb52a75ae0b 100644
--- a/deepspeed/inference/v2/model_implementations/phi3/containers.py
+++ b/deepspeed/inference/v2/model_implementations/phi3/containers.py
@@ -39,7 +39,7 @@
 '''
 
 
-class PhiTransformerContainer(LayerContainer):
+class Phi3TransformerContainer(LayerContainer):
     """
         Transformer layer container for the Phi model.
     """
@@ -60,7 +60,7 @@ class PhiTransformerContainer(LayerContainer):
     }
 
 
-class PhiNonTransformerContainer(LayerContainer):
+class Phi3NonTransformerContainer(LayerContainer):
     """
         Non-Transformer layer container for the Phi model.
     """
diff --git a/deepspeed/inference/v2/model_implementations/phi3/model.py b/deepspeed/inference/v2/model_implementations/phi3/model.py
index 5aa20a67b8fd..abac8868894e 100644
--- a/deepspeed/inference/v2/model_implementations/phi3/model.py
+++ b/deepspeed/inference/v2/model_implementations/phi3/model.py
@@ -16,20 +16,20 @@
 from ...modules.interfaces import *
 from ...ragged import RaggedBatchWrapper
 
-from .containers import PhiNonTransformerContainer, PhiTransformerContainer
+from .containers import Phi3NonTransformerContainer, Phi3TransformerContainer
 
 
-class PhiInferenceModel(DSTransformerModelBase):
+class Phi3InferenceModel(DSTransformerModelBase):
     """
     Inference model implementation for ragged batching for Llama-2 models.
     """
 
-    _non_transformer: Optional[PhiNonTransformerContainer]
+    _non_transformer: Optional[Phi3NonTransformerContainer]
     """
     Embed + unembed container. Specializing the type annotation.
     """
 
-    _transformer: Optional[Iterable[PhiTransformerContainer]]
+    _transformer: Optional[Iterable[Phi3TransformerContainer]]
     """
     Per-layer transformer container. Specializing the type annotation.
     """
diff --git a/deepspeed/inference/v2/model_implementations/phi3/policy.py b/deepspeed/inference/v2/model_implementations/phi3/policy.py
index 7db145029286..4ced0272ea47 100644
--- a/deepspeed/inference/v2/model_implementations/phi3/policy.py
+++ b/deepspeed/inference/v2/model_implementations/phi3/policy.py
@@ -7,23 +7,23 @@
 
 from ...config_v2 import RaggedInferenceEngineConfig
 from ..inference_policy_base import ContainerMap, InferenceV2Policy
-from .containers import PhiNonTransformerContainer, PhiTransformerContainer
-from .model import PhiInferenceModel
+from .containers import Phi3NonTransformerContainer, Phi3TransformerContainer
+from .model import Phi3InferenceModel
 
 
 class Phi3Policy(InferenceV2Policy):
 
-    def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> PhiInferenceModel:
-        return PhiInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)
+    def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Phi3InferenceModel:
+        return Phi3InferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)
 
     def build_container_map(self) -> ContainerMap:
         map = ContainerMap()
 
-        transformer_containers = [PhiTransformerContainer(self.model) for _ in range(self.model.num_layers)]
+        transformer_containers = [Phi3TransformerContainer(self.model) for _ in range(self.model.num_layers)]
 
         map.set_transformer_params(['model.layers'], transformer_containers)
 
-        map.set_non_transformer_params(PhiNonTransformerContainer(self.model))
+        map.set_non_transformer_params(Phi3NonTransformerContainer(self.model))
 
         map.set_unmapped_params(
            [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)])