Tie lm_head to embedding weights
adk9 committed Jun 7, 2024
1 parent b29c1df · commit e1accbb
Showing 2 changed files with 4 additions and 4 deletions.
@@ -79,9 +79,10 @@ class Phi3SmallNonTransformerContainer(LayerContainer):
     word_emb: EmbeddingParameter
     final_norm_gamma: NormParameter
     final_norm_beta: NormParameter
+    word_unembed: UnembedParameter
 
     PARAM_MAPPING = {
-        "model.embed_tokens.weight": "word_emb.params",
+        "model.embed_tokens.weight": ["word_emb.params", "word_unembed.params"],
         "model.final_layernorm.weight": "final_norm_gamma.params",
         "model.final_layernorm.bias": "final_norm_beta.params",
     }
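
The one-to-many PARAM_MAPPING entry is what ties the weights: the single checkpoint tensor model.embed_tokens.weight now populates both the embedding parameter and the new unembedding parameter. The sketch below is a hypothetical illustration of how a loader could consume such a mapping; distribute_params and the dict-based parameter store are made-up stand-ins, not the DeepSpeed LayerContainer machinery.

    from typing import Dict, List, Union
    import torch

    def distribute_params(checkpoint: Dict[str, torch.Tensor],
                          mapping: Dict[str, Union[str, List[str]]],
                          params: Dict[str, torch.Tensor]) -> None:
        """Route each checkpoint tensor to every destination slot it is mapped to."""
        for src_name, dests in mapping.items():
            dest_list = [dests] if isinstance(dests, str) else dests
            for dest in dest_list:
                # The same tensor object is reused here, so the destinations are tied;
                # a real loader might copy or shard instead.
                params[dest] = checkpoint[src_name]

    # Toy usage mirroring the mapping above (shapes are made up for illustration):
    ckpt = {"model.embed_tokens.weight": torch.randn(100, 32)}
    mapping = {"model.embed_tokens.weight": ["word_emb.params", "word_unembed.params"]}
    params: Dict[str, torch.Tensor] = {}
    distribute_params(ckpt, mapping, params)
    assert params["word_emb.params"] is params["word_unembed.params"]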
@@ -175,10 +175,9 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
         Performs unembedding of the hidden states to logits. This will only sample the final
         token of each sequence.
         """
-        word_unembed = torch.empty(self.vocab_size, self.model_dim, dtype=hidden_states.dtype, device=hidden_states.device)
-        torch.nn.init.xavier_uniform_(word_unembed)
 
         logits = self.unembed(hidden_states,
-                              word_unembed,
+                              self._non_transformer.word_unembed,
                               ragged_batch_info,
                               gamma=self._non_transformer.final_norm_gamma,
                               beta=self._non_transformer.final_norm_beta)
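
Before this change, _forward_unembed allocated a fresh xavier-initialized projection on every call, so logits were computed against an untrained matrix rather than the loaded weights; after the change, the unembedding reuses the checkpoint's (tied) embedding weights. As a general illustration of weight tying, hypothetical and separate from DeepSpeed's ragged-inference path, a plain PyTorch module can share the storage like this:

    # Minimal sketch of weight tying (illustration only, not the DeepSpeed kernel path).
    import torch
    import torch.nn as nn

    class TinyTiedLM(nn.Module):
        def __init__(self, vocab_size: int = 100, model_dim: int = 32):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, model_dim)
            self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)
            # Tie: lm_head now shares storage with the embedding table.
            self.lm_head.weight = self.embed_tokens.weight

        def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
            hidden = self.embed_tokens(token_ids)  # stand-in for the transformer stack
            return self.lm_head(hidden)            # logits over the vocabulary

    model = TinyTiedLM()
    assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
    logits = model(torch.randint(0, 100, (1, 4)))
    print(logits.shape)  # torch.Size([1, 4, 100])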
