diff --git a/sharktank/sharktank/layers/testing.py b/sharktank/sharktank/layers/testing.py
index fb330aadd..e2fc79d78 100644
--- a/sharktank/sharktank/layers/testing.py
+++ b/sharktank/sharktank/layers/testing.py
@@ -12,6 +12,7 @@
 
 def make_llama_attention_block_theta(
     *,
+    block_idx: int,
     head_count: int,
     head_count_kv: int,
     head_dim: int,
@@ -21,25 +22,30 @@ def make_llama_attention_block_theta(
     return Theta(
         {
             "attn_q.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.attn_q.weight",
                 data=make_rand_torch(
                     (head_count * head_dim, embedding_length), dtype=dtype
-                )
+                ),
             ),
             "attn_k.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.attn_k.weight",
                 data=make_rand_torch(
                     (head_count_kv * head_dim, embedding_length), dtype=dtype
-                )
+                ),
             ),
             "attn_v.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.attn_v.weight",
                 data=make_rand_torch(
                     (head_count_kv * head_dim, embedding_length), dtype=dtype
-                )
+                ),
             ),
             "attn_output.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((embedding_length, embedding_length), dtype=dtype)
+                name=f"blk.{block_idx}.attn_output.weight",
+                data=make_rand_torch((embedding_length, embedding_length), dtype=dtype),
             ),
             "attn_norm.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((embedding_length), dtype=dtype)
+                name=f"blk.{block_idx}.attn_norm.weight",
+                data=make_rand_torch((embedding_length), dtype=dtype),
             ),
         }
     )
diff --git a/sharktank/sharktank/models/llama/testing.py b/sharktank/sharktank/models/llama/testing.py
index 079602b28..33424317a 100644
--- a/sharktank/sharktank/models/llama/testing.py
+++ b/sharktank/sharktank/models/llama/testing.py
@@ -57,6 +57,7 @@ def make_attention_block_theta(
 
 def make_attention_block_ffn_theta_v2(
     *,
+    block_idx: int,
     head_count: int,
     head_count_kv: int,
     head_dim: int,
@@ -65,6 +66,7 @@ def make_attention_block_ffn_theta_v2(
     dtype: torch.dtype | None = None,
 ) -> Theta:
     attention_theta = make_llama_attention_block_theta(
+        block_idx=block_idx,
         head_count=head_count,
         head_count_kv=head_count_kv,
         head_dim=head_dim,
@@ -74,22 +76,26 @@ def make_attention_block_ffn_theta_v2(
     ffn_theta = Theta(
         {
             "ffn_norm.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((head_count * head_dim), dtype=dtype)
+                name=f"blk.{block_idx}.ffn_norm.weight",
+                data=make_rand_torch((head_count * head_dim), dtype=dtype),
             ),
             "ffn_gate.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.ffn_gate.weight",
                 data=make_rand_torch(
                     (feed_forward_length, embedding_length), dtype=dtype
-                )
+                ),
             ),
             "ffn_up.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.ffn_up.weight",
                 data=make_rand_torch(
                     (feed_forward_length, embedding_length), dtype=dtype
-                )
+                ),
             ),
             "ffn_down.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.ffn_down.weight",
                 data=make_rand_torch(
                     (embedding_length, feed_forward_length), dtype=dtype
-                )
+                ),
             ),
         }
     )
@@ -102,22 +108,26 @@ def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta
     return Theta(
         {
             "blk.0.ffn_gate_inp.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((num_experts, ffn_dim))
+                name="blk.0.ffn_gate_inp.weight",
+                data=make_rand_torch((num_experts, ffn_dim)),
             ),
             "blk.0.ffn_norm.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((ffn_dim))
+                name="blk.0.ffn_norm.weight", data=make_rand_torch((ffn_dim))
             ),
             "blk.0.layer_output_norm.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((ffn_dim))
+                name="blk.0.layer_output_norm.weight", data=make_rand_torch((ffn_dim))
             ),
             "blk.0.ffn_gate_exps.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim))
+                name="blk.0.ffn_gate_exps.weight",
+                data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)),
             ),
             "blk.0.ffn_up_exps.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim))
+                name="blk.0.ffn_up_exps.weight",
+                data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)),
             ),
             "blk.0.ffn_down_exps.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((num_experts, ffn_dim, feature_dim * num_experts))
+                name="blk.0.ffn_down_exps.weight",
+                data=make_rand_torch((num_experts, ffn_dim, feature_dim * num_experts)),
             ),
         }
     )
@@ -128,11 +138,13 @@ def make_random_llama_theta(
 ) -> Theta:
     res = {
         "token_embd.weight": DefaultPrimitiveTensor(
-            data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype)
+            name="token_embd.weight",
+            data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype),
         )
     }
     for i in range(config.hp.block_count):
         res[f"blk.{i}"] = make_attention_block_ffn_theta_v2(
+            block_idx=i,
             head_count=config.hp.attention_head_count,
             head_count_kv=config.hp.attention_head_count_kv,
             head_dim=config.hp.attn_head_dim,
@@ -142,10 +154,12 @@ def make_random_llama_theta(
         ).tree
 
     res[f"output.weight"] = DefaultPrimitiveTensor(
-        data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype)
+        name="output.weight",
+        data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype),
     )
     res[f"output_norm.weight"] = DefaultPrimitiveTensor(
-        data=make_rand_torch((1, config.hp.embedding_length), dtype=dtype)
+        name="output_norm.weight",
+        data=make_rand_torch((1, config.hp.embedding_length), dtype=dtype),
     )
 
     return Theta(res)
diff --git a/sharktank/sharktank/models/llama/toy_llama.py b/sharktank/sharktank/models/llama/toy_llama.py
new file mode 100644
index 000000000..09ee00455
--- /dev/null
+++ b/sharktank/sharktank/models/llama/toy_llama.py
@@ -0,0 +1,67 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .testing import make_random_llama_theta
+
+from sharktank.layers.configs import LlamaHParams
+from sharktank.models.llama.llama import LlamaModelConfig
+from sharktank.types import Dataset
+
+import argparse
+import torch
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-s", "--seed", type=int, default=12345)
+parser.add_argument("-o", "--output", default="/tmp/toy_llama.irpa")
+
+
+def main():
+    args = parser.parse_args()
+    torch.manual_seed(args.seed)
+
+    dtype = torch.float16
+    block_seq_stride = 16
+    max_blocks = 8
+    attention_head_count = 8
+    attn_head_dim = 32
+    attention_head_count_kv = 4
+    rope_dimension_count = 32
+    vocabulary_size = 256
+
+    config = LlamaModelConfig(
+        hp=LlamaHParams(
+            context_length=block_seq_stride * max_blocks,
+            embedding_length=attention_head_count * attn_head_dim,
+            block_count=3,
+            feed_forward_length=23,
+            rope_dimension_count=rope_dimension_count,
+            rope_freq_base=500000.0,
+            attention_head_count=attention_head_count,
+            attn_head_dim=attn_head_dim,
+            attention_layer_norm_rms_epsilon=0.01,
+            attention_head_count_kv=attention_head_count_kv,
+            expert_count=0,
+            expert_used_count=0,
+            model_arch="llama",
+        ),
+        block_seq_stride=block_seq_stride,
+        activation_dtype=dtype,
+        attention_dtype=dtype,
+    )
+
+    theta = make_random_llama_theta(
+        config=config,
+        vocab_size=vocabulary_size,
+    )
+
+    config_dict = config.hp.to_gguf_props()
+
+    dataset = Dataset(config_dict, theta)
+    dataset.save(args.output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py
index bc0b3b0b6..99917c2d3 100644
--- a/sharktank/sharktank/utils/cli.py
+++ b/sharktank/sharktank/utils/cli.py
@@ -138,6 +138,9 @@ def get_tokenizer(args) -> tokenizer.InferenceTokenizer:
     If the data_files= dict is present and explicit tokenizer options are not set,
     we will try to infer a tokenizer from the data files.
     """
+    if args.tokenizer_type == "fake":
+        return tokenizer.fake_tokenizer()
+
     if args.tokenizer_config_json is not None:
         data_files = {"tokenizer_config.json": args.tokenizer_config_json}
     else:
diff --git a/sharktank/sharktank/utils/tokenizer.py b/sharktank/sharktank/utils/tokenizer.py
index b459c706a..f698c7eb0 100644
--- a/sharktank/sharktank/utils/tokenizer.py
+++ b/sharktank/sharktank/utils/tokenizer.py
@@ -83,6 +83,24 @@ def _decode(self, tokens: list[list[int]]) -> list[str]:
         ...
 
 
+class FakeTokenizer(InferenceTokenizer):
+    def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]:
+        encoded = []
+        for text in texts:
+            encoded.append([int(t) for t in text.split(" ")])
+        return encoded
+
+    def _decode(self, tokens: list[list[int]]) -> list[str]:
+        strings = []
+        for token in tokens:
+            strings.append(" ".join([str(t) for t in token]))
+        return strings
+
+
+def fake_tokenizer():
+    return FakeTokenizer()
+
+
 def load_tokenizer(*posargs, tokenizer_type: str = "transformers", **kwargs):
     if tokenizer_type == "transformers":
         return _create_transformers_tokenizer(*posargs, **kwargs)
diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py
index a55782329..63251c5a9 100644
--- a/sharktank/tests/layers/paged_llama_attention_block_test.py
+++ b/sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -59,6 +59,7 @@ def testExportDecomposed(self):
         cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype)
 
         theta = make_llama_attention_block_theta(
+            block_idx=0,
             head_count=self.attention_head_count,
             head_count_kv=self.head_count_kv,
             head_dim=self.attention_head_dim,
@@ -133,6 +134,7 @@ def testExportNondecomposed(self):
         cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype)
 
         theta = make_llama_attention_block_theta(
+            block_idx=0,
             head_count=self.attention_head_count,
             head_count_kv=self.head_count_kv,
             head_dim=self.attention_head_dim,
diff --git a/sharktank/tests/layers/sharded_paged_llama_attention_block.py b/sharktank/tests/layers/sharded_paged_llama_attention_block.py
index c94fd44ab..11a2d90a7 100644
--- a/sharktank/tests/layers/sharded_paged_llama_attention_block.py
+++ b/sharktank/tests/layers/sharded_paged_llama_attention_block.py
@@ -102,6 +102,7 @@ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
         )
 
         theta = make_llama_attention_block_theta(
+            block_idx=0,
             head_count=self.attention_head_count,
             head_count_kv=self.head_count_kv,
             head_dim=self.attention_head_dim,
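Usage sketch: a minimal example of how the pieces added above fit together, assuming only that the sharktank package from this tree is installed (e.g. `pip install -e sharktank`) so the new module and tokenizer are importable; the seed and output path are just the script defaults. It generates the toy IRPA dataset via the new toy_llama module, then round-trips a whitespace-separated token string through FakeTokenizer.

# toy_llama_usage.py -- sketch exercising the additions in this patch.
import subprocess

from sharktank.utils.tokenizer import fake_tokenizer

# Write a small random llama dataset to an IRPA file; -s/-o are the flags
# defined in toy_llama.py above.
subprocess.run(
    [
        "python",
        "-m",
        "sharktank.models.llama.toy_llama",
        "-s",
        "12345",
        "-o",
        "/tmp/toy_llama.irpa",
    ],
    check=True,
)

# FakeTokenizer treats prompts as whitespace-separated token ids, so tests can
# feed exact token sequences without a real tokenizer model.
tok = fake_tokenizer()
ids = tok._encode(["0 1 2 3"], add_start_token=False)  # -> [[0, 1, 2, 3]]
text = tok._decode(ids)  # -> ["0 1 2 3"]
print(ids, text)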