From b7fa50c35ab859f5a156d347e9d896efe9c9d0d8 Mon Sep 17 00:00:00 2001
From: archana-ramalingam
Date: Thu, 31 Oct 2024 08:43:41 +0000
Subject: [PATCH] revert main changes

---
 .github/workflows/ci_eval.yaml               |  3 ++-
 .../sharktank/evaluate/perplexity_torch.py   |  6 ++++-
 .../sharktank/evaluate/perplexity_vmfb.py    | 26 ++++++++++++++-----
 sharktank/sharktank/utils/load_llm.py        | 16 +++++++++++-
 4 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml
index 94e4ad538..6ce12ee9e 100644
--- a/.github/workflows/ci_eval.yaml
+++ b/.github/workflows/ci_eval.yaml
@@ -7,6 +7,7 @@
 name: Evaluation Tests
 
 on:
+  pull_request:
   workflow_dispatch:
   schedule:
     # Weekdays nightly at 07:00 UTC = 23:00 PST / 00:00 PDT.
@@ -72,7 +73,7 @@ jobs:
             iree-runtime \
             "numpy<2.0"
       - name: Run perplexity test with vmfb
-        run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
+        run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://0' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
 
   test_perplexity_torch:
     timeout-minutes: 1000
diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py
index fc3aa5fca..e20ce5bb8 100644
--- a/sharktank/sharktank/evaluate/perplexity_torch.py
+++ b/sharktank/sharktank/evaluate/perplexity_torch.py
@@ -36,7 +36,7 @@
 }
 
 logger = logging.getLogger("eval")
-logger.setLevel(log_levels["info"])
+logger.setLevel(log_levels["debug"])
 
 logger.root.handlers[0].setFormatter(
     logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s")
@@ -146,6 +146,10 @@ def get_prompts(self):
             if s != "" and len(s.split()) >= 20 and s.count("=") < 2
         ]
 
+        test_prompts = [
+            "Robert Boulter is an English film, television and theatre actor."
+        ]
+
         logger.info(f" num_test_prompts: {len(test_prompts)}")
 
         return test_prompts
diff --git a/sharktank/sharktank/evaluate/perplexity_vmfb.py b/sharktank/sharktank/evaluate/perplexity_vmfb.py
index fedf7c1c9..e0981370d 100644
--- a/sharktank/sharktank/evaluate/perplexity_vmfb.py
+++ b/sharktank/sharktank/evaluate/perplexity_vmfb.py
@@ -183,6 +183,10 @@ def get_prompts(self):
             if s != "" and len(s.split()) >= 20 and s.count("=") < 2
         ]
 
+        test_prompts = [
+            "Robert Boulter is an English film, television and theatre actor."
+        ]
+
         self.bs = len(test_prompts)
 
         return test_prompts
@@ -210,19 +214,23 @@ def prefill_vmfb(self, token_batch, i):
             bs=self.bs,
         )
 
+        print(
+            "prefill cache", len(self.batch.cache_state), len(self.batch.cache_state[0])
+        )
+
         seq_block_ids = self.batch.pad_block_ids()
 
         prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"](
             token_batch,
-            self.seq_lens_batch,
+            self.batch.seq_lens,
             seq_block_ids,
-            self.batch.cache_state[0].to(torch.float16),
+            self.batch.cache_state.to(torch.float16),
         )
 
         prefill_logits = torch.tensor(prefill_logits[:, :, :])
         tokens = torch.tensor(
             self.generator.model.extract_tokens_from_logits(
-                prefill_logits, seq_lens_batch
+                prefill_logits, self.batch.seq_lens
             )
         ).unsqueeze(1)
         self.batch.add_result_token(tokens)
@@ -237,17 +245,21 @@ def decode_vmfb(self, token_batch, i):
         logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
         logger.debug(f"{token_batch.tolist()}")
 
-        start_positions = self.seq_lens_batch.clone()
-        self.seq_lens_batch.add_(1)
+        start_positions = self.batch.seq_lens.clone()
+        self.batch.seq_lens.add_(1)
         self.batch.allocate_seq_block_ids()
         seq_block_ids = self.batch.pad_block_ids()
 
+        print(
+            "decode cache", len(self.batch.cache_state), len(self.batch.cache_state[0])
+        )
+
         decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"](
             token_batch,
-            self.seq_lens_batch,
+            self.batch.seq_lens,
             start_positions,
             seq_block_ids,
-            self.batch.cache_state[0].to(torch.float16),
+            self.batch.cache_state.to(torch.float16),
         )
 
         decode_logits = torch.tensor(decode_logits[:, :, :])
diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py
index 558653d9b..294b3fff3 100644
--- a/sharktank/sharktank/utils/load_llm.py
+++ b/sharktank/sharktank/utils/load_llm.py
@@ -31,9 +31,9 @@ def __init__(
         self.tokenizer = tokenizer
         if model.cache.is_paged:
             self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
+            self.free_pages = list(range(1, page_cache_size))
         else:
             self.shared_cache_state = None
-        self.free_pages = list(range(1, 8192))
         self.end_token = end_token
 
     @property
@@ -148,6 +148,13 @@ def prefill(self):
         attention_mask = model.attention_mask(
             model.input_mask(self.seq_lens, self.token_ids.shape[1])
         )
+
+        print(
+            "prefill cache load_llm",
+            len(self.cache_state),
+            len(self.cache_state[0]),
+            self.cache_state,
+        )
         seq_block_ids_tensor = self.pad_block_ids()
         trace_tensor("prefill.token_ids", self.token_ids)
         trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
@@ -183,6 +190,13 @@ def decode(self, token_batch):
                 seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride,
             )
         )
+        print(
+            "decode cache load_llm",
+            len(self.cache_state),
+            len(self.cache_state[0]),
+            self.cache_state,
+        )
+
         trace_tensor("decode.token_ids", self.token_ids)
         trace_tensor("decode.start_positions", start_positions)
         trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)