Skip to content

Commit

Permalink
change data format
Browse files Browse the repository at this point in the history
  • Loading branch information
kzawora-intel committed Sep 30, 2024
1 parent 00831bc commit 086f97a
Showing 1 changed file with 58 additions and 17 deletions.
75 changes: 58 additions & 17 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ def __init__(
if self.calibrate_buckets and self.bucket_cfg_file is None:
vllm_instance_id = get_vllm_instance_id()
self.bucket_cfg_file = f'hpu-buckets-{vllm_instance_id}.yaml'
if self.calibrate_buckets:
msg = f"Calibration results will be saved to {self.bucket_cfg_file}"
logger.info(msg)

Expand Down Expand Up @@ -784,26 +785,70 @@ def bucket_cfg_to_dict(cfg):
self.seen_configs,
columns=['batch_size', 'seq_or_block', 'is_prefill']).sort_values(
['is_prefill', 'batch_size', 'seq_or_block'], ascending=False)
data = {}
data['buckets'] = df.to_dict(orient='records')
df['phase'] = df['is_prefill'].apply(lambda x: 'prefill'
if x else 'decode')
data: Dict[str, Any] = {} # type: ignore
#data['buckets'] = df.to_dict(orient='records')
buckets_dict = df.groupby('phase').apply(lambda dd: sorted(
list([list(a) for a in zip(dd['batch_size'], dd['seq_or_block'])]),
reverse=True)).to_dict()
prefill_df = df[df['is_prefill']]
decode_df = df[~df['is_prefill']]

updated_prompt_bs_bucket_cfg = self.prompt_bs_bucket_cfg
updated_prompt_bs_bucket_cfg[0] = int(prefill_df['batch_size'].min())
updated_prompt_bs_bucket_cfg[2] = int(prefill_df['batch_size'].max())
updated_prompt_seq_bucket_cfg = self.prompt_seq_bucket_cfg
updated_prompt_seq_bucket_cfg[0] = int(
prefill_df['seq_or_block'].min())
updated_prompt_seq_bucket_cfg[2] = int(
prefill_df['seq_or_block'].max())
updated_decode_bs_bucket_cfg = self.decode_bs_bucket_cfg
updated_decode_bs_bucket_cfg[0] = int(decode_df['batch_size'].min())
updated_decode_bs_bucket_cfg[2] = int(decode_df['batch_size'].max())
updated_decode_block_bucket_cfg = self.decode_block_bucket_cfg
updated_decode_block_bucket_cfg[0] = int(
decode_df['seq_or_block'].min())
updated_decode_block_bucket_cfg[2] = int(
decode_df['seq_or_block'].max())

data['bucket_cfg'] = {
'prompt_bs_bucket_cfg':
bucket_cfg_to_dict(self.prompt_bs_bucket_cfg),
bucket_cfg_to_dict(updated_prompt_bs_bucket_cfg),
'prompt_seq_bucket_cfg':
bucket_cfg_to_dict(self.prompt_seq_bucket_cfg),
bucket_cfg_to_dict(updated_prompt_seq_bucket_cfg),
'decode_bs_bucket_cfg':
bucket_cfg_to_dict(self.decode_bs_bucket_cfg),
bucket_cfg_to_dict(updated_decode_bs_bucket_cfg),
'decode_block_bucket_cfg':
bucket_cfg_to_dict(self.decode_block_bucket_cfg),
bucket_cfg_to_dict(updated_decode_block_bucket_cfg),
}
data['buckets'] = buckets_dict

class PSS(str):
pass

csv_df = df[['phase', 'batch_size', 'seq_or_block']]
data['buckets_csv'] = PSS(csv_df.to_csv(index=False))

def pss_representer(dumper, data):
style = '|'
tag = u'tag:yaml.org,2002:str'
return dumper.represent_scalar(tag, data, style=style)

yaml.add_representer(PSS, pss_representer, Dumper=yaml.SafeDumper)
import pdb
pdb.set_trace()
with open(bucket_cfg_file, 'w') as outfile:
yaml.dump(data, outfile, default_flow_style=False)
yaml.safe_dump(data,
outfile,
default_flow_style=None,
sort_keys=False)
msg = f"Bucket calibration settings saved to {bucket_cfg_file}"
logger.info(msg)

def deserialize_bucket_settings(self, bucket_cfg_file):
import pandas as pd
import yaml
bucket_cfg = None
prompt_buckets = None
decode_buckets = None
try:
Expand All @@ -813,19 +858,15 @@ def deserialize_bucket_settings(self, bucket_cfg_file):
bucket_cfg = data['bucket_cfg']
# Load pre-generated buckets, if any
if 'buckets' in data:
df = pd.DataFrame.from_dict(data['buckets'])
prompt_buckets = df[df['is_prefill']][[
'batch_size', 'seq_or_block'
]].values.tolist()
prompt_buckets = data['buckets']['prefill']
prompt_buckets = [tuple(b) for b in prompt_buckets]
decode_buckets = df[~df['is_prefill']][[
'batch_size', 'seq_or_block'
]].values.tolist()
decode_buckets = data['buckets']['decode']
decode_buckets = [tuple(b) for b in decode_buckets]
except (FileNotFoundError, IOError, PermissionError):
msg = "Could not open file specified in VLLM_HPU_BUCKET_CFG: "
f"{bucket_cfg_file}. Falling back to default config."
msg = ("Could not open file specified in VLLM_HPU_BUCKET_CFG: "
f"{bucket_cfg_file}. Falling back to default config.")
logger.error(msg)

return bucket_cfg, (prompt_buckets, decode_buckets)

def _prepare_prompt(
Expand Down

0 comments on commit 086f97a

Please sign in to comment.