From e1accbb30ff913b35258d86aaa8e3adf0d77534e Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Fri, 7 Jun 2024 21:18:05 +0000 Subject: [PATCH] Tie lm_head to embedding weights --- .../v2/model_implementations/phi3small/containers.py | 3 ++- .../inference/v2/model_implementations/phi3small/model.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepspeed/inference/v2/model_implementations/phi3small/containers.py b/deepspeed/inference/v2/model_implementations/phi3small/containers.py index deb31d311627..fcdc17e6cd4d 100644 --- a/deepspeed/inference/v2/model_implementations/phi3small/containers.py +++ b/deepspeed/inference/v2/model_implementations/phi3small/containers.py @@ -79,9 +79,10 @@ class Phi3SmallNonTransformerContainer(LayerContainer): word_emb: EmbeddingParameter final_norm_gamma: NormParameter final_norm_beta: NormParameter + word_unembed: UnembedParameter PARAM_MAPPING = { - "model.embed_tokens.weight": "word_emb.params", + "model.embed_tokens.weight": ["word_emb.params", "word_unembed.params"], "model.final_layernorm.weight": "final_norm_gamma.params", "model.final_layernorm.bias": "final_norm_beta.params", } diff --git a/deepspeed/inference/v2/model_implementations/phi3small/model.py b/deepspeed/inference/v2/model_implementations/phi3small/model.py index e8c22a108611..9d5e6a599365 100644 --- a/deepspeed/inference/v2/model_implementations/phi3small/model.py +++ b/deepspeed/inference/v2/model_implementations/phi3small/model.py @@ -175,10 +175,9 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge Performs unembedding of the hidden states to logits. This will only sample the final token of each sequence. """ - word_unembed = torch.empty(self.vocab_size, self.model_dim, dtype=hidden_states.dtype, device=hidden_states.device) - torch.nn.init.xavier_uniform_(word_unembed) + logits = self.unembed(hidden_states, - word_unembed, + self._non_transformer.word_unembed, ragged_batch_info, gamma=self._non_transformer.final_norm_gamma, beta=self._non_transformer.final_norm_beta)