Commit 0a37563: Merge branch 'main' into ci-durations

ScottTodd authored Nov 27, 2024
2 parents 398185a + ff5e1d7

Showing 7 changed files with 129 additions and 18 deletions.
16 changes: 11 additions & 5 deletions sharktank/sharktank/layers/testing.py
@@ -12,6 +12,7 @@

 def make_llama_attention_block_theta(
     *,
+    block_idx: int,
     head_count: int,
     head_count_kv: int,
     head_dim: int,
@@ -21,25 +22,30 @@ def make_llama_attention_block_theta(
     return Theta(
         {
             "attn_q.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.attn_q.weight",
                 data=make_rand_torch(
                     (head_count * head_dim, embedding_length), dtype=dtype
-                )
+                ),
             ),
             "attn_k.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.attn_k.weight",
                 data=make_rand_torch(
                     (head_count_kv * head_dim, embedding_length), dtype=dtype
-                )
+                ),
             ),
             "attn_v.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.attn_v.weight",
                 data=make_rand_torch(
                     (head_count_kv * head_dim, embedding_length), dtype=dtype
-                )
+                ),
             ),
             "attn_output.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((embedding_length, embedding_length), dtype=dtype)
+                name=f"blk.{block_idx}.attn_output.weight",
+                data=make_rand_torch((embedding_length, embedding_length), dtype=dtype),
             ),
             "attn_norm.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((embedding_length), dtype=dtype)
+                name=f"blk.{block_idx}.attn_norm.weight",
+                data=make_rand_torch((embedding_length), dtype=dtype),
             ),
         }
     )
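
The sketch below is not part of the commit; it shows how the updated helper is expected to be called. The block_idx, head_count, head_count_kv, and head_dim keyword arguments are visible in the diff above, while embedding_length and dtype are assumed from the function body, and the dimension values are arbitrary. With this change, every tensor in the returned Theta carries a qualified name such as blk.2.attn_q.weight instead of being anonymous.

import torch

from sharktank.layers.testing import make_llama_attention_block_theta

# Random weights for a single attention block; per the diff above, block_idx=2
# should yield tensor names prefixed with "blk.2.".
theta = make_llama_attention_block_theta(
    block_idx=2,
    head_count=8,
    head_count_kv=4,
    head_dim=32,
    embedding_length=8 * 32,
    dtype=torch.float16,
)
print(theta.tree)  # entries should report names like "blk.2.attn_q.weight"
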
40 changes: 27 additions & 13 deletions sharktank/sharktank/models/llama/testing.py
@@ -57,6 +57,7 @@ def make_attention_block_theta(

 def make_attention_block_ffn_theta_v2(
     *,
+    block_idx: int,
     head_count: int,
     head_count_kv: int,
     head_dim: int,
@@ -65,6 +66,7 @@ def make_attention_block_ffn_theta_v2(
     dtype: torch.dtype | None = None,
 ) -> Theta:
     attention_theta = make_llama_attention_block_theta(
+        block_idx=block_idx,
         head_count=head_count,
         head_count_kv=head_count_kv,
         head_dim=head_dim,
@@ -74,22 +76,26 @@ def make_attention_block_ffn_theta_v2(
     ffn_theta = Theta(
         {
             "ffn_norm.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((head_count * head_dim), dtype=dtype)
+                name=f"blk.{block_idx}.ffn_norm.weight",
+                data=make_rand_torch((head_count * head_dim), dtype=dtype),
             ),
             "ffn_gate.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.ffn_gate.weight",
                 data=make_rand_torch(
                     (feed_forward_length, embedding_length), dtype=dtype
-                )
+                ),
             ),
             "ffn_up.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.ffn_up.weight",
                 data=make_rand_torch(
                     (feed_forward_length, embedding_length), dtype=dtype
-                )
+                ),
             ),
             "ffn_down.weight": DefaultPrimitiveTensor(
+                name=f"blk.{block_idx}.ffn_down.weight",
                 data=make_rand_torch(
                     (embedding_length, feed_forward_length), dtype=dtype
-                )
+                ),
             ),
         }
     )
@@ -102,22 +108,26 @@ def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta
     return Theta(
         {
             "blk.0.ffn_gate_inp.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((num_experts, ffn_dim))
+                name="blk.0.ffn_gate_inp.weight",
+                data=make_rand_torch((num_experts, ffn_dim)),
             ),
             "blk.0.ffn_norm.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((ffn_dim))
+                name="blk.0.ffn_norm.weight", data=make_rand_torch((ffn_dim))
             ),
             "blk.0.layer_output_norm.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((ffn_dim))
+                name="blk.0.layer_output_norm.weight", data=make_rand_torch((ffn_dim))
             ),
             "blk.0.ffn_gate_exps.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim))
+                name="blk.0.layer_output_norm.weight",
+                data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)),
             ),
             "blk.0.ffn_up_exps.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim))
+                name="blk.0.ffn_up_exps.weight",
+                data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)),
             ),
             "blk.0.ffn_down_exps.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((num_experts, ffn_dim, feature_dim * num_experts))
+                name="blk.0.ffn_down_exps.weight",
+                data=make_rand_torch((num_experts, ffn_dim, feature_dim * num_experts)),
             ),
         }
     )
@@ -128,11 +138,13 @@ def make_random_llama_theta(
 ) -> Theta:
     res = {
         "token_embd.weight": DefaultPrimitiveTensor(
-            data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype)
+            name="token_embd.weight",
+            data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype),
         )
     }
     for i in range(config.hp.block_count):
         res[f"blk.{i}"] = make_attention_block_ffn_theta_v2(
+            block_idx=i,
             head_count=config.hp.attention_head_count,
             head_count_kv=config.hp.attention_head_count_kv,
             head_dim=config.hp.attn_head_dim,
@@ -142,10 +154,12 @@
         ).tree

     res[f"output.weight"] = DefaultPrimitiveTensor(
-        data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype)
+        name="output.weight",
+        data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype),
     )
     res[f"output_norm.weight"] = DefaultPrimitiveTensor(
-        data=make_rand_torch((1, config.hp.embedding_length), dtype=dtype)
+        name="output_norm.weight",
+        data=make_rand_torch((1, config.hp.embedding_length), dtype=dtype),
     )

     return Theta(res)
67 changes: 67 additions & 0 deletions sharktank/sharktank/models/llama/toy_llama.py
@@ -0,0 +1,67 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .testing import make_random_llama_theta

from sharktank.layers.configs import LlamaHParams
from sharktank.models.llama.llama import LlamaModelConfig
from sharktank.types import Dataset

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--seed", default=12345)
parser.add_argument("-o", "--output", default="/tmp/toy_llama.irpa")


def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    dtype = torch.float16
    block_seq_stride = 16
    max_blocks = 8
    attention_head_count = 8
    attn_head_dim = 32
    attention_head_count_kv = 4
    rope_dimension_count = 32
    vocabulary_size = 256

    config = LlamaModelConfig(
        hp=LlamaHParams(
            context_length=block_seq_stride * max_blocks,
            embedding_length=attention_head_count * attn_head_dim,
            block_count=3,
            feed_forward_length=23,
            rope_dimension_count=rope_dimension_count,
            rope_freq_base=500000.0,
            attention_head_count=attention_head_count,
            attn_head_dim=attn_head_dim,
            attention_layer_norm_rms_epsilon=0.01,
            attention_head_count_kv=attention_head_count_kv,
            expert_count=0,
            expert_used_count=0,
            model_arch="llama",
        ),
        block_seq_stride=block_seq_stride,
        activation_dtype=dtype,
        attention_dtype=dtype,
    )

    theta = make_random_llama_theta(
        config=config,
        vocab_size=vocabulary_size,
    )

    config_dict = config.hp.to_gguf_props()

    dataset = Dataset(config_dict, theta)
    dataset.save(args.output)


if __name__ == "__main__":
    main()
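
A usage sketch, not part of the commit: running the new script as a module should write a small, randomly initialized toy llama IRPA file. The module path is assumed from the file's location, and both flags come from the argparse setup above (the seed default is 12345 and the output default is /tmp/toy_llama.irpa).

import subprocess

# Hypothetical invocation; writes the toy model dataset to /tmp/toy_llama.irpa.
subprocess.run(
    [
        "python",
        "-m",
        "sharktank.models.llama.toy_llama",
        "--seed",
        "12345",
        "--output",
        "/tmp/toy_llama.irpa",
    ],
    check=True,
)
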
3 changes: 3 additions & 0 deletions sharktank/sharktank/utils/cli.py
@@ -138,6 +138,9 @@ def get_tokenizer(args) -> tokenizer.InferenceTokenizer:
     If the data_files= dict is present and explicit tokenizer options are not
     set, we will try to infer a tokenizer from the data files.
     """
+    if args.tokenizer_type == "fake":
+        return tokenizer.fake_tokenizer()
+
     if args.tokenizer_config_json is not None:
         data_files = {"tokenizer_config.json": args.tokenizer_config_json}
     else:
18 changes: 18 additions & 0 deletions sharktank/sharktank/utils/tokenizer.py
@@ -83,6 +83,24 @@ def _decode(self, tokens: list[list[int]]) -> list[str]:
         ...


+class FakeTokenizer(InferenceTokenizer):
+    def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]:
+        encoded = []
+        for text in texts:
+            encoded.append([int(t) for t in text.split(" ")])
+        return encoded
+
+    def _decode(self, tokens: list[list[int]]) -> list[str]:
+        strings = []
+        for token in tokens:
+            strings.append(" ".join([str(t) for t in token]))
+        return strings
+
+
+def fake_tokenizer():
+    return FakeTokenizer()
+
+
 def load_tokenizer(*posargs, tokenizer_type: str = "transformers", **kwargs):
     if tokenizer_type == "transformers":
         return _create_transformers_tokenizer(*posargs, **kwargs)
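
A quick behavioral sketch, not part of the commit: the new FakeTokenizer treats each text as a space-separated list of integer token ids. This exercises the private helpers added above directly; the public encode/decode wrappers on InferenceTokenizer are not shown in this diff and are left out here.

from sharktank.utils.tokenizer import fake_tokenizer

tok = fake_tokenizer()

# "3 1 4 1 5" round-trips unchanged through the integer-id encoding.
ids = tok._encode(["3 1 4 1 5"], add_start_token=False)
assert ids == [[3, 1, 4, 1, 5]]
assert tok._decode(ids) == ["3 1 4 1 5"]
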
2 changes: 2 additions & 0 deletions sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -59,6 +59,7 @@ def testExportDecomposed(self):
         cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype)

         theta = make_llama_attention_block_theta(
+            block_idx=0,
             head_count=self.attention_head_count,
             head_count_kv=self.head_count_kv,
             head_dim=self.attention_head_dim,
@@ -133,6 +134,7 @@ def testExportNondecomposed(self):
         cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype)

         theta = make_llama_attention_block_theta(
+            block_idx=0,
             head_count=self.attention_head_count,
             head_count_kv=self.head_count_kv,
             head_dim=self.attention_head_dim,
1 change: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
         )

         theta = make_llama_attention_block_theta(
+            block_idx=0,
             head_count=self.attention_head_count,
             head_count_kv=self.head_count_kv,
             head_dim=self.attention_head_dim,
