From 2bf4b57b7513ecdee5764d6bcdc6ef7c80d10e7f Mon Sep 17 00:00:00 2001
From: archana-ramalingam
Date: Thu, 31 Oct 2024 08:40:56 +0000
Subject: [PATCH] Test on PR

---
 .github/workflows/ci_eval.yaml                   |  3 ++-
 sharktank/sharktank/evaluate/perplexity_torch.py |  6 +++++-
 sharktank/sharktank/evaluate/perplexity_vmfb.py  | 12 ++++++++++++
 sharktank/sharktank/utils/load_llm.py            | 14 ++++++++++++++
 4 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml
index 94e4ad538..6ce12ee9e 100644
--- a/.github/workflows/ci_eval.yaml
+++ b/.github/workflows/ci_eval.yaml
@@ -7,6 +7,7 @@
 name: Evaluation Tests

 on:
+  pull_request:
   workflow_dispatch:
   schedule:
     # Weekdays nightly at 07:00 UTC = 23:00 PST / 00:00 PDT.
@@ -72,7 +73,7 @@ jobs:
           iree-runtime \
           "numpy<2.0"
       - name: Run perplexity test with vmfb
-        run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
+        run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://0' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json

   test_perplexity_torch:
     timeout-minutes: 1000
diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py
index fc3aa5fca..e20ce5bb8 100644
--- a/sharktank/sharktank/evaluate/perplexity_torch.py
+++ b/sharktank/sharktank/evaluate/perplexity_torch.py
@@ -36,7 +36,7 @@
 }

 logger = logging.getLogger("eval")
-logger.setLevel(log_levels["info"])
+logger.setLevel(log_levels["debug"])

 logger.root.handlers[0].setFormatter(
     logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s")
@@ -146,6 +146,10 @@ def get_prompts(self):
             if s != "" and len(s.split()) >= 20 and s.count("=") < 2
         ]

+        test_prompts = [
+            "Robert Boulter is an English film, television and theatre actor."
+        ]
+
         logger.info(f" num_test_prompts: {len(test_prompts)}")

         return test_prompts
diff --git a/sharktank/sharktank/evaluate/perplexity_vmfb.py b/sharktank/sharktank/evaluate/perplexity_vmfb.py
index 03244126c..e0981370d 100644
--- a/sharktank/sharktank/evaluate/perplexity_vmfb.py
+++ b/sharktank/sharktank/evaluate/perplexity_vmfb.py
@@ -183,6 +183,10 @@ def get_prompts(self):
             if s != "" and len(s.split()) >= 20 and s.count("=") < 2
         ]

+        test_prompts = [
+            "Robert Boulter is an English film, television and theatre actor."
+        ]
+
         self.bs = len(test_prompts)

         return test_prompts
@@ -210,6 +214,10 @@ def prefill_vmfb(self, token_batch, i):
             bs=self.bs,
         )

+        print(
+            "prefill cache", len(self.batch.cache_state), len(self.batch.cache_state[0])
+        )
+
         seq_block_ids = self.batch.pad_block_ids()
         prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"](
             token_batch,
@@ -242,6 +250,10 @@ def decode_vmfb(self, token_batch, i):
         self.batch.allocate_seq_block_ids()
         seq_block_ids = self.batch.pad_block_ids()

+        print(
+            "decode cache", len(self.batch.cache_state), len(self.batch.cache_state[0])
+        )
+
         decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"](
             token_batch,
             self.batch.seq_lens,
diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py
index acf56eb1b..294b3fff3 100644
--- a/sharktank/sharktank/utils/load_llm.py
+++ b/sharktank/sharktank/utils/load_llm.py
@@ -148,6 +148,13 @@ def prefill(self):
         attention_mask = model.attention_mask(
             model.input_mask(self.seq_lens, self.token_ids.shape[1])
         )
+
+        print(
+            "prefill cache load_llm",
+            len(self.cache_state),
+            len(self.cache_state[0]),
+            self.cache_state,
+        )
         seq_block_ids_tensor = self.pad_block_ids()
         trace_tensor("prefill.token_ids", self.token_ids)
         trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
@@ -183,6 +190,13 @@ def decode(self, token_batch):
                 seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride,
             )
         )
+        print(
+            "decode cache load_llm",
+            len(self.cache_state),
+            len(self.cache_state[0]),
+            self.cache_state,
+        )
+
         trace_tensor("decode.token_ids", self.token_ids)
         trace_tensor("decode.start_positions", start_positions)
         trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)