
Commit

revert main changes
archana-ramalingam committed Oct 31, 2024
1 parent 6dbf2f7 commit b7fa50c
Showing 4 changed files with 41 additions and 10 deletions.
.github/workflows/ci_eval.yaml: 3 changes (2 additions, 1 deletion)
@@ -7,6 +7,7 @@
name: Evaluation Tests

on:
  pull_request:
  workflow_dispatch:
  schedule:
    # Weekdays nightly at 07:00 UTC = 23:00 PST / 00:00 PDT.
@@ -72,7 +73,7 @@ jobs:
            iree-runtime \
            "numpy<2.0"
      - name: Run perplexity test with vmfb
-        run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
+        run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://0' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json

  test_perplexity_torch:
    timeout-minutes: 1000
sharktank/sharktank/evaluate/perplexity_torch.py: 6 changes (5 additions, 1 deletion)
@@ -36,7 +36,7 @@
}
logger = logging.getLogger("eval")

logger.setLevel(log_levels["info"])
logger.setLevel(log_levels["debug"])

logger.root.handlers[0].setFormatter(
logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s")
@@ -146,6 +146,10 @@ def get_prompts(self):
if s != "" and len(s.split()) >= 20 and s.count("=") < 2
]

test_prompts = [
"Robert Boulter is an English film, television and theatre actor."
]

logger.info(f" num_test_prompts: {len(test_prompts)}")

return test_prompts
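The get_prompts hunk above adds a hard-coded single test prompt after the dataset filter (non-empty lines of at least 20 words with fewer than two '=' characters). Below is a small, self-contained sketch of that filtering step; the function name filter_test_prompts, the sample data, and the mapping expression inside the comprehension are assumptions for illustration, not the repository's exact code.

```python
# Illustrative sketch of the prompt filter visible in the hunk above;
# not the repository's implementation.
def filter_test_prompts(raw_lines):
    # Keep non-empty lines of at least 20 words that are not heading or
    # table separators (fewer than two "=" characters).
    return [
        s.replace("\n", "").rstrip()
        for s in raw_lines
        if s != "" and len(s.split()) >= 20 and s.count("=") < 2
    ]

# Hypothetical sample input: one short line, one long wikitext-style line.
sample = [
    "too short",
    "Robert Boulter is an English film , television and theatre actor . "
    "He had a guest starring role on the television series The Bill in 2000 .",
]
print(filter_test_prompts(sample))  # only the second, longer line survives

# The debugging change above bypasses the filter entirely with one fixed prompt:
test_prompts = [
    "Robert Boulter is an English film, television and theatre actor."
]
```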
sharktank/sharktank/evaluate/perplexity_vmfb.py: 26 changes (19 additions, 7 deletions)
@@ -183,6 +183,10 @@ def get_prompts(self):
if s != "" and len(s.split()) >= 20 and s.count("=") < 2
]

test_prompts = [
"Robert Boulter is an English film, television and theatre actor."
]

self.bs = len(test_prompts)

return test_prompts
@@ -210,19 +214,23 @@ def prefill_vmfb(self, token_batch, i):
            bs=self.bs,
        )

+        print(
+            "prefill cache", len(self.batch.cache_state), len(self.batch.cache_state[0])
+        )
+
        seq_block_ids = self.batch.pad_block_ids()
        prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"](
            token_batch,
-            self.seq_lens_batch,
+            self.batch.seq_lens,
            seq_block_ids,
-            self.batch.cache_state[0].to(torch.float16),
+            self.batch.cache_state.to(torch.float16),
        )

        prefill_logits = torch.tensor(prefill_logits[:, :, :])

        tokens = torch.tensor(
            self.generator.model.extract_tokens_from_logits(
-                prefill_logits, seq_lens_batch
+                prefill_logits, self.batch.seq_lens
            )
        ).unsqueeze(1)
        self.batch.add_result_token(tokens)
@@ -237,17 +245,21 @@ def decode_vmfb(self, token_batch, i):
logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
logger.debug(f"{token_batch.tolist()}")

start_positions = self.seq_lens_batch.clone()
self.seq_lens_batch.add_(1)
start_positions = self.batch.seq_lens.clone()
self.batch.seq_lens.add_(1)
self.batch.allocate_seq_block_ids()
seq_block_ids = self.batch.pad_block_ids()

print(
"decode cache", len(self.batch.cache_state), len(self.batch.cache_state[0])
)

decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"](
token_batch,
self.seq_lens_batch,
self.batch.seq_lens,
start_positions,
seq_block_ids,
self.batch.cache_state[0].to(torch.float16),
self.batch.cache_state.to(torch.float16),
)

decode_logits = torch.tensor(decode_logits[:, :, :])
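The prefill_vmfb/decode_vmfb hunks above switch between passing self.batch.cache_state[0] (a flat tensor) and self.batch.cache_state (the list wrapping it) into the compiled entry points, and the new print calls report the length of each. A minimal sketch of that distinction, assuming a paged KV cache stored as a one-element list around a flat page-table tensor; the shape below is a placeholder, not sharktank's real cache layout.

```python
# Sketch only: placeholder shapes, not the real sharktank cache layout.
import torch

# A paged KV cache is assumed here to be a list holding one flat page-table tensor.
cache_state = [torch.zeros(128, 1024, dtype=torch.float16)]

# len() of the list counts cache tensors, while len() of the first tensor
# counts its leading (page) dimension -- this is what the added
# "prefill cache" / "decode cache" prints report.
print("prefill cache", len(cache_state), len(cache_state[0]))  # -> prefill cache 1 128

# .to(torch.float16) is a tensor method: it applies to cache_state[0],
# whereas the plain Python list cache_state has no .to() method.
half_pages = cache_state[0].to(torch.float16)
print(half_pages.dtype)  # torch.float16
```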
sharktank/sharktank/utils/load_llm.py: 16 changes (15 additions, 1 deletion)
@@ -31,9 +31,9 @@ def __init__(
        self.tokenizer = tokenizer
        if model.cache.is_paged:
            self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
-            self.free_pages = list(range(1, page_cache_size))
        else:
            self.shared_cache_state = None
+        self.free_pages = list(range(1, 8192))
        self.end_token = end_token

@property
@@ -148,6 +148,13 @@ def prefill(self):
        attention_mask = model.attention_mask(
            model.input_mask(self.seq_lens, self.token_ids.shape[1])
        )
+
+        print(
+            "prefill cache load_llm",
+            len(self.cache_state),
+            len(self.cache_state[0]),
+            self.cache_state,
+        )
        seq_block_ids_tensor = self.pad_block_ids()
        trace_tensor("prefill.token_ids", self.token_ids)
        trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
@@ -183,6 +190,13 @@ def decode(self, token_batch):
                seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride,
            )
        )
+        print(
+            "decode cache load_llm",
+            len(self.cache_state),
+            len(self.cache_state[0]),
+            self.cache_state,
+        )
+
        trace_tensor("decode.token_ids", self.token_ids)
        trace_tensor("decode.start_positions", start_positions)
        trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
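The __init__ hunk above seeds self.free_pages with page indices for the paged KV cache, starting at 1 (page 0 reserved) and, in this debugging change, using a hard-coded 8192 instead of page_cache_size. A minimal free-list sketch of how such a pool can hand out and reclaim pages; PagePool, acquire_pages, and release_pages are invented names for illustration, not load_llm.py's API.

```python
# Illustrative free-list page allocator; a sketch, not load_llm.py's code.
class PagePool:
    def __init__(self, page_cache_size: int):
        # Page 0 is left out of the pool, matching list(range(1, page_cache_size)).
        self.free_pages = list(range(1, page_cache_size))

    def acquire_pages(self, count: int) -> list[int]:
        # Hand out `count` page indices for a growing sequence.
        if count > len(self.free_pages):
            raise RuntimeError("KV cache is out of pages")
        pages, self.free_pages = self.free_pages[:count], self.free_pages[count:]
        return pages

    def release_pages(self, pages: list[int]) -> None:
        # Return pages to the pool once a sequence is finished.
        self.free_pages.extend(pages)

pool = PagePool(page_cache_size=8192)
blocks = pool.acquire_pages(4)   # e.g. [1, 2, 3, 4]
pool.release_pages(blocks)
```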
