Skip to content

Commit

Permalink
Rename Phi to Phi3
Browse files Browse the repository at this point in the history
  • Loading branch information
adk9 committed May 21, 2024
1 parent 13260fb commit e79d4b4
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
'''


class PhiTransformerContainer(LayerContainer):
class Phi3TransformerContainer(LayerContainer):
"""
Transformer layer container for the Phi model.
"""
Expand All @@ -60,7 +60,7 @@ class PhiTransformerContainer(LayerContainer):
}


class PhiNonTransformerContainer(LayerContainer):
class Phi3NonTransformerContainer(LayerContainer):
"""
Non-Transformer layer container for the Phi model.
"""
Expand Down
8 changes: 4 additions & 4 deletions deepspeed/inference/v2/model_implementations/phi3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@
from ...modules.interfaces import *
from ...ragged import RaggedBatchWrapper

from .containers import PhiNonTransformerContainer, PhiTransformerContainer
from .containers import Phi3NonTransformerContainer, Phi3TransformerContainer


class PhiInferenceModel(DSTransformerModelBase):
class Phi3InferenceModel(DSTransformerModelBase):
"""
Inference model implementation for ragged batching for Phi-3 models.
"""

_non_transformer: Optional[PhiNonTransformerContainer]
_non_transformer: Optional[Phi3NonTransformerContainer]
"""
Embed + unembed container. Specializing the type annotation.
"""

_transformer: Optional[Iterable[PhiTransformerContainer]]
_transformer: Optional[Iterable[Phi3TransformerContainer]]
"""
Per-layer transformer container. Specializing the type annotation.
"""
Expand Down
12 changes: 6 additions & 6 deletions deepspeed/inference/v2/model_implementations/phi3/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@

from ...config_v2 import RaggedInferenceEngineConfig
from ..inference_policy_base import ContainerMap, InferenceV2Policy
from .containers import PhiNonTransformerContainer, PhiTransformerContainer
from .model import PhiInferenceModel
from .containers import Phi3NonTransformerContainer, Phi3TransformerContainer
from .model import Phi3InferenceModel


class Phi3Policy(InferenceV2Policy):

def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> PhiInferenceModel:
return PhiInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)
def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Phi3InferenceModel:
return Phi3InferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)

def build_container_map(self) -> ContainerMap:
map = ContainerMap()

transformer_containers = [PhiTransformerContainer(self.model) for _ in range(self.model.num_layers)]
transformer_containers = [Phi3TransformerContainer(self.model) for _ in range(self.model.num_layers)]

map.set_transformer_params(['model.layers'], transformer_containers)

map.set_non_transformer_params(PhiNonTransformerContainer(self.model))
map.set_non_transformer_params(Phi3NonTransformerContainer(self.model))

map.set_unmapped_params(
[f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)])
Expand Down

0 comments on commit e79d4b4

Please sign in to comment.