[model] add support for mixtral moe model #128
Status: Open
936187425 wants to merge 14 commits into vectorch-ai:main from 936187425:Mixtral
Commits (14)
4342b65  [feat] add mixtral.h
4e88395  [feat] add some classes in mixtral.h
7172786  [feat] construct modules from mixtral model except Mixtral moe impl
c13ebd3  [feat] add the replicated_linear
f2ffe46  Merge branch 'vectorch-ai:main' into Mixtral
13da555  [format]
bdb529c  [refactor] add the wrapper of fused_moe_layer and the fused_moe_kernel
1a8d37d  Merge branch 'vectorch-ai:main' into Mixtral
3eb45f0  [feat] add the load_state_dict in fused_moe.cpp
51f3d56  Merge branch 'vectorch-ai:main' into Mixtral
d649ac9  [feat] add MixtralBlockExpert using torch version
046cf63  [format]
11a4a8a  [bug] remove third_party/pybind11/ from submodules
9a5ac5b  [fix] remove the third_party/pybind11
mixtral.h (new file, +354 lines)
#pragma once

#include <glog/logging.h>  // DCHECK_EQ
#include <torch/torch.h>

#include <memory>
#include <string>
#include <vector>

#include "chat_template/coded_chat_template.h"
#include "layers/activation.h"
#include "layers/attention/attention.h"
#include "layers/attention/handler.h"
#include "layers/embedding.h"
#include "layers/linear.h"
#include "layers/normalization.h"
#include "memory/kv_cache.h"
#include "models/model_args.h"
#include "models/model_registry.h"
#include "models/parameters.h"

namespace llm::hf {

// Placeholder: the block-sparse MoE forward is implemented in follow-up
// commits (see the fused_moe_layer wrapper and fused_moe.cpp); this stub
// only pins down the module interface.
class MixtralMoEImpl : public torch::nn::Module {
 public:
  MixtralMoEImpl(const ModelArgs& args,
                 const QuantArgs& quant_args,
                 const ParallelArgs& parallel_args,
                 const torch::TensorOptions& options) {}

  torch::Tensor forward(torch::Tensor x) { return torch::Tensor(); }

  void load_state_dict(const StateDict& state_dict) {}

  void verify_loaded_weights(const std::string& prefix) const {}
};
TORCH_MODULE(MixtralMoE);
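
Review note: since the MoE forward is stubbed out above, here is a minimal, non-fused sketch of the block-sparse routing math it will eventually need (softmax router, top-k experts per token, renormalized weights). All names here (gate, experts, top_k) are illustrative, not this PR's API:

// Sketch only: dense (non-fused) Mixtral MoE routing, for reference.
// gate: Linear(hidden_size, n_experts); experts: per-expert MLPs.
inline torch::Tensor moe_forward_sketch(
    const torch::Tensor& x,
    torch::nn::Linear& gate,
    std::vector<torch::nn::Sequential>& experts,
    int64_t top_k) {
  // router weights: (num_tokens, n_experts)
  auto weights = torch::softmax(gate->forward(x), /*dim=*/-1);
  // keep the top-k experts per token and renormalize their weights
  auto [topk_w, topk_idx] = weights.topk(top_k, /*dim=*/-1);
  topk_w = topk_w / topk_w.sum(/*dim=*/-1, /*keepdim=*/true);
  auto out = torch::zeros_like(x);
  for (int64_t e = 0; e < static_cast<int64_t>(experts.size()); ++e) {
    // tokens that routed to expert e in any of their top-k slots
    auto slot_mask = (topk_idx == e);
    auto token_idx = slot_mask.any(/*dim=*/-1).nonzero().squeeze(-1);
    if (token_idx.numel() == 0) continue;
    auto w = (topk_w * slot_mask.to(topk_w.dtype()))
                 .sum(/*dim=*/-1)
                 .index_select(0, token_idx);
    out.index_add_(0, token_idx,
                   experts[e]->forward(x.index_select(0, token_idx)) *
                       w.unsqueeze(-1));
  }
  return out;
}
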
class MixtralAttentionImpl : public torch::nn::Module {
 public:
  MixtralAttentionImpl(const ModelArgs& args,
                       const QuantArgs& quant_args,
                       const ParallelArgs& parallel_args,
                       const torch::TensorOptions& options,
                       AttentionHandler* handler) {
    const int32_t world_size = parallel_args.world_size();
    const int64_t hidden_size = args.hidden_size();
    const int64_t n_heads = args.n_heads();
    const int64_t head_dim = args.head_dim();
    const int64_t n_kv_heads = args.n_kv_heads().value_or(n_heads);
    const int64_t n_local_heads = n_heads / world_size;
    const int64_t n_local_kv_heads = n_kv_heads / world_size;

    // sizes for q, k, v
    qkv_sizes_ = {n_local_heads * head_dim,
                  n_local_kv_heads * head_dim,
                  n_local_kv_heads * head_dim};

    // register submodules
    qkv_proj_ = register_module(
        "qkv_proj",
        ColumnParallelLinear(hidden_size,
                             (n_heads + 2 * n_kv_heads) * head_dim,
                             /*bias=*/false,
                             /*gather_output=*/false,
                             quant_args,
                             parallel_args,
                             options));

    o_proj_ = register_module("o_proj",
                              RowParallelLinear(n_heads * head_dim,
                                                hidden_size,
                                                /*bias=*/false,
                                                /*input_is_parallelized=*/true,
                                                quant_args,
                                                parallel_args,
                                                options));

    // initialize attention
    atten_ = register_module(
        "atten", Attention(n_local_heads, n_local_kv_heads, head_dim, handler));
  }

  torch::Tensor forward(torch::Tensor x,
                        torch::Tensor positions,
                        KVCache& kv_cache,
                        const InputParameters& input_params) {
    // (num_tokens, hidden_size)
    // => (num_tokens, (n_local_heads + 2 * n_local_kv_heads) * head_dim),
    // then split into q, k, v along the last dim
    auto qkv = qkv_proj_(x).split(/*split_size=*/qkv_sizes_, /*dim=*/-1);
    DCHECK_EQ(qkv.size(), 3);

    // calculate attention,
    // output: (num_tokens, n_local_heads * head_dim)
    auto output =
        atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params);
    return o_proj_(output);
  }

  void load_state_dict(const StateDict& state_dict) {
    // call each submodule's load_state_dict function
    qkv_proj_->load_state_dict(state_dict, {"q_proj.", "k_proj.", "v_proj."});
    o_proj_->load_state_dict(state_dict.select("o_proj."));
  }

  void verify_loaded_weights(const std::string& prefix) const {
    qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj].");
    o_proj_->verify_loaded_weights(prefix + "o_proj.");
  }

 private:
  // parameter members, must be registered
  ColumnParallelLinear qkv_proj_{nullptr};

  RowParallelLinear o_proj_{nullptr};

  // module members without parameters
  Attention atten_{nullptr};

  // sizes for q, k, v
  std::vector<int64_t> qkv_sizes_;
};
TORCH_MODULE(MixtralAttention);
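
A quick sanity check on the split arithmetic above: for the default Mixtral-8x7B shapes (hidden_size 4096, 32 query heads, 8 KV heads, head_dim 128) on a single rank, qkv_proj_ emits (32 + 2·8)·128 = 6144 features per token, and qkv_sizes_ splits them as {4096, 1024, 1024}. Note that the n_local_* math implicitly assumes world_size divides both n_heads and n_kv_heads.
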
class MixtralDecoderLayerImpl : public torch::nn::Module {
 public:
  MixtralDecoderLayerImpl(const ModelArgs& args,
                          const QuantArgs& quant_args,
                          const ParallelArgs& parallel_args,
                          const torch::TensorOptions& options,
                          AttentionHandler* handler) {
    // register submodules
    self_attn_ = register_module(
        "self_attn",
        MixtralAttention(args, quant_args, parallel_args, options, handler));

    moe_ = register_module(
        "moe", MixtralMoE(args, quant_args, parallel_args, options));

    input_layernorm_ = register_module(
        "input_layernorm",
        RMSNormResidual(args.hidden_size(), args.rms_norm_eps(), options));

    post_attention_layernorm_ = register_module(
        "post_attention_layernorm",
        RMSNormResidual(args.hidden_size(), args.rms_norm_eps(), options));
  }

  torch::Tensor forward(torch::Tensor x,
                        torch::Tensor positions,
                        KVCache& kv_cache,
                        const InputParameters& input_params,
                        torch::Tensor& residual) {
    auto hidden_states = input_layernorm_(x, residual);

    hidden_states =
        self_attn_(hidden_states, positions, kv_cache, input_params);

    // block-sparse MoE replaces the dense MLP of a Llama-style layer
    hidden_states = post_attention_layernorm_(hidden_states, residual);

    return moe_(hidden_states);
  }

  void load_state_dict(const StateDict& state_dict) {
    self_attn_->load_state_dict(state_dict.select("self_attn."));
    input_layernorm_->load_state_dict(state_dict.select("input_layernorm."));
    post_attention_layernorm_->load_state_dict(
        state_dict.select("post_attention_layernorm."));
    moe_->load_state_dict(state_dict.select("block_sparse_moe."));
  }

  void verify_loaded_weights(const std::string& prefix) const {
    self_attn_->verify_loaded_weights(prefix + "self_attn.");
    input_layernorm_->verify_loaded_weights(prefix + "input_layernorm.");
    post_attention_layernorm_->verify_loaded_weights(
        prefix + "post_attention_layernorm.");
    moe_->verify_loaded_weights(prefix + "block_sparse_moe.");
  }

 private:
  MixtralAttention self_attn_{nullptr};

  MixtralMoE moe_{nullptr};

  RMSNormResidual input_layernorm_{nullptr};

  RMSNormResidual post_attention_layernorm_{nullptr};
};
TORCH_MODULE(MixtralDecoderLayer);
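
For reviewers unfamiliar with the fused-norm calling convention: the layer above assumes RMSNormResidual(x, residual) adds the incoming residual into x, stores the pre-norm sum back into residual, and returns the normed tensor (the usual fused add+rmsnorm contract; I have not re-checked the exact signature here). Under that assumption, the forward is equivalent to this unfused sketch, with the per-step ops passed in as hypothetical std::function callables:

// Unfused equivalent of MixtralDecoderLayer::forward, assuming the fused
// add+rmsnorm contract described above. Sketch only; not part of the PR.
// Requires <functional>.
inline torch::Tensor decoder_layer_unfused(
    torch::Tensor x,
    torch::Tensor& residual,
    const std::function<torch::Tensor(const torch::Tensor&)>& input_norm,
    const std::function<torch::Tensor(const torch::Tensor&)>& attn,
    const std::function<torch::Tensor(const torch::Tensor&)>& post_norm,
    const std::function<torch::Tensor(const torch::Tensor&)>& moe) {
  residual = residual.defined() ? x + residual : x;  // first residual add
  auto h = input_norm(residual);                     // input_layernorm
  h = attn(h);                                       // self-attention
  residual = h + residual;                           // second residual add
  h = post_norm(residual);                           // post_attention_layernorm
  return moe(h);                                     // MoE replaces dense MLP
}
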
class MixtralModelImpl : public torch::nn::Module {
 public:
  MixtralModelImpl(const ModelArgs& args,
                   const QuantArgs& quant_args,
                   const ParallelArgs& parallel_args,
                   const torch::TensorOptions& options) {
    model_args_ = args;

    // TODO: once LoRA is supported, the vocab_size may need to be adjusted.
    embed_tokens_ = register_module(
        "embed_tokens",
        ParallelEmbedding(
            args.vocab_size(), args.hidden_size(), parallel_args, options));

    handler_ = AttentionHandler::create_handler_with_rope(
        args, /*interleaved=*/false, options);

    blocks_ = register_module("layers", torch::nn::ModuleList());
    layers_.reserve(args.n_layers());
    for (int32_t i = 0; i < args.n_layers(); i++) {
      auto block = MixtralDecoderLayer(
          args, quant_args, parallel_args, options, handler_.get());
      layers_.push_back(block);
      blocks_->push_back(block);
    }

    norm_ = register_module(
        "norm",
        RMSNormResidual(args.hidden_size(), args.rms_norm_eps(), options));
  }

  torch::Tensor forward(torch::Tensor tokens,
                        torch::Tensor positions,
                        std::vector<KVCache>& kv_caches,
                        const InputParameters& input_params) {
    auto h = embed_tokens_(tokens);

    torch::Tensor residual;
    for (int32_t i = 0; i < model_args_.n_layers(); i++) {
      auto& layer = layers_[i];
      h = layer(h, positions, kv_caches[i], input_params, residual);
    }

    return norm_(h, residual);
  }

  void load_state_dict(const StateDict& state_dict) {
    embed_tokens_->load_state_dict(state_dict.select("embed_tokens."));

    for (size_t i = 0; i < layers_.size(); i++) {
      layers_[i]->load_state_dict(
          state_dict.select("layers." + std::to_string(i) + "."));
    }
    norm_->load_state_dict(state_dict.select("norm."));
  }

  void verify_loaded_weights(const std::string& prefix) const {
    embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");

    for (size_t i = 0; i < layers_.size(); i++) {
      layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
                                        ".");
    }

    norm_->verify_loaded_weights(prefix + "norm.");
  }

 private:
  ModelArgs model_args_;
  // parameter members, must be registered
  // embedding module
  ParallelEmbedding embed_tokens_{nullptr};

  RMSNormResidual norm_{nullptr};

  // attention handler
  std::unique_ptr<AttentionHandler> handler_{nullptr};

  torch::nn::ModuleList blocks_{nullptr};
  // holds the same modules as blocks_ but with a concrete type,
  // to avoid type casts on the forward path
  std::vector<MixtralDecoderLayer> layers_;
};
TORCH_MODULE(MixtralModel);

class MixtralForCausalLMImpl : public torch::nn::Module {
 public:
  MixtralForCausalLMImpl(const ModelArgs& args,
                         const QuantArgs& quant_args,
                         const ParallelArgs& parallel_args,
                         const torch::TensorOptions& options) {
    model_ = register_module(
        "model", MixtralModel(args, quant_args, parallel_args, options));

    // TODO: we may need a lora flag here in the future
    lm_head_ = register_module("lm_head",
                               ColumnParallelLinear(args.hidden_size(),
                                                    args.vocab_size(),
                                                    /*bias=*/false,
                                                    /*gather_output=*/true,
                                                    parallel_args,
                                                    options));
  }

  torch::Tensor forward(const torch::Tensor& tokens,
                        const torch::Tensor& positions,
                        std::vector<KVCache>& kv_caches,
                        const InputParameters& input_params) {
    return model_(tokens, positions, kv_caches, input_params);
  }

  torch::Tensor logits(const torch::Tensor& hidden_states,
                       const torch::Tensor& selected_idxes) {
    // select tokens if provided
    auto h = hidden_states;
    if (selected_idxes.defined()) {
      h = h.index_select(/*dim=*/0, selected_idxes);
    }
    return lm_head_(h);
  }

  void load_state_dict(const StateDict& state_dict) {
    model_->load_state_dict(state_dict.select("model."));

    lm_head_->load_state_dict(state_dict.select("lm_head."));
  }

  void verify_loaded_weights() const {
    model_->verify_loaded_weights("model.");
    lm_head_->verify_loaded_weights("lm_head.");
  }

 private:
  MixtralModel model_{nullptr};

  ColumnParallelLinear lm_head_{nullptr};
};
TORCH_MODULE(MixtralForCausalLM);
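
For context, a rough sketch of how this wrapper is expected to be driven by the engine; all surrounding objects (args, state_dict, kv_caches, input_params, selected_idxes) are assumed to come from the existing loader/runtime, as with the other registered models:

// Illustrative driver code only; mirrors how other registered models are used.
MixtralForCausalLM model(args, quant_args, parallel_args, options);
model->load_state_dict(state_dict);  // weights loaded from the checkpoint
model->verify_loaded_weights();      // every tensor accounted for

auto hidden = model(tokens, positions, kv_caches, input_params);
// compute logits only for the sampled positions (e.g. last token per sequence)
auto logits = model->logits(hidden, selected_idxes);
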
// register the model to make it available
REGISTER_CAUSAL_MODEL(mixtral, MixtralForCausalLM);

REGISTER_MODEL_ARGS(mixtral, [&] {
  // example config from huggingface
  // https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
  LOAD_ARG_OR(model_type, "model_type", "mixtral");
  LOAD_ARG_OR(bos_token_id, "bos_token_id", 1);
  LOAD_ARG_OR(eos_token_id, "eos_token_id", 2);
  LOAD_ARG_OR(hidden_size, "hidden_size", 4096);
  LOAD_ARG_OR(intermediate_size, "intermediate_size", 14336);
  LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 32768);
  LOAD_ARG_OR(n_heads, "num_attention_heads", 32);
  LOAD_ARG_OR(n_experts_per_tok, "num_experts_per_tok", 2);
  LOAD_ARG_OR(n_layers, "num_hidden_layers", 32);
  LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 8);
  LOAD_ARG_OR(n_local_experts, "num_local_experts", 8);
  LOAD_ARG_OR(out_router_logits, "output_router_logits", false);
  LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-5);
  LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f);
  LOAD_ARG_OR(router_aux_loss_coef, "router_aux_loss_coef", 0.02);
  LOAD_ARG_OR(dtype, "torch_dtype", "bfloat16");
  LOAD_ARG_OR(vocab_size, "vocab_size", 32000);

  // the huggingface config key is "hidden_act"
  LOAD_ARG_OR(hidden_act, "hidden_act", "silu");

  LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
    return args->hidden_size() / args->n_heads();
  });
});

}  // namespace llm::hf
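
Worked out, the registered defaults give head_dim = 4096 / 32 = 128, and num_key_value_heads = 8 means each KV head is shared by 32 / 8 = 4 query heads, i.e. grouped-query attention rather than full multi-head attention.
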

Review comment:
Just a heads up: I added support for MQA and GQA; please also include that support in your change. FYI: dff774e
You can learn about MQA and GQA from this blog post: https://iamshobhitagarwal.medium.com/navigating-the-attention-landscape-mha-mqa-and-gqa-decoded-288217d0a7d1
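
(For anyone following along: GQA shares each KV head across a group of query heads, and MQA is the n_kv_heads == 1 extreme. A minimal sketch of the head-expansion view, with hypothetical names; the actual support referenced is commit dff774e:)

// GQA sketch: expand k or v from n_kv_heads to n_heads by repeating each
// KV head across its query-head group.
// Input shape: (num_tokens, n_kv_heads, head_dim).
torch::Tensor expand_kv_for_gqa(const torch::Tensor& kv, int64_t n_heads) {
  const int64_t n_kv_heads = kv.size(1);
  // group_size == 1 => MHA; group_size == n_heads => MQA
  const int64_t group_size = n_heads / n_kv_heads;
  // result: (num_tokens, n_heads, head_dim)
  return kv.repeat_interleave(group_size, /*dim=*/1);
}
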