add bucket profiling
kzawora-intel committed Sep 17, 2024
1 parent 690b867 commit 8a90083
Showing 3 changed files with 82 additions and 5 deletions.
56 changes: 56 additions & 0 deletions vllm/hpu/utils.py
@@ -56,3 +56,59 @@ def forward(self, input, cache, block_indices, block_offset):
 
     def fetch_from_cache(self, cache, blocks):
         return cache.index_select(0, blocks)
+
+
+def process_run_characteristics(times, block_size, prefill=False):
+    import statistics
+    import math
+    import pandas as pd
+
+    def summarize_to_dict(d, block_size, prefill=False):
+        ret = {}
+        for k in d:
+            l, graph = d[k]
+            bs, seq = k
+            mean_time = statistics.mean(l)
+            ctx_len = bs * seq
+            if not prefill:
+                ctx_len *= block_size
+            ret[k] = {'n': len(l), 'mean_time': mean_time, 'time_stddev': statistics.stdev(l), 'mean_gen_tput': bs / mean_time, 'mean_in_tput': ctx_len / mean_time, 'hpugraph_captured': int(graph)}
+        return ret
+
+    summary_dict = summarize_to_dict(times, block_size, prefill)
+    run_df = pd.DataFrame.from_dict(summary_dict, orient='index')
+    mode = 'decode' if not prefill else 'prefill'
+    run_df.to_csv(f'vllm_soc_{mode}.csv')
+
+    def plot_server_stats_df(df, prefill=False, scale=4):
+        import matplotlib.pyplot as plt
+        import seaborn as sns
+        mode = 'decode' if not prefill else 'prefill'
+        n = df['n'].iloc[0]
+        plt.rcParams.update({'font.size': 6 * scale})
+        fig, axs = plt.subplots(2, 2, tight_layout=True)
+        fig.suptitle(f"Server operating characteristic ({mode}, n={n})", fontsize=16 * scale)
+        sz = fig.get_size_inches()
+        fig.set_size_inches(sz[0] * scale, sz[1] * scale, forward=True)
+        numel = df['n'].count()
+        annot_kws = {"size": scale * 25 / math.sqrt(numel)}
+        sns.heatmap(df['mean_time'].unstack() * 1000, cmap='RdYlGn_r', ax=axs[0, 0], annot=True, cbar=False, fmt='.3f', square=True, annot_kws=annot_kws)
+        axs[0, 0].set_title('mean time [ms]')
+        sns.heatmap(df['time_stddev'].unstack() * 1000, cmap='RdYlGn_r', ax=axs[0, 1], annot=True, cbar=False, fmt='.2f', square=True, annot_kws=annot_kws, vmin=0)
+        axs[0, 1].set_title('time stddev [ms]')
+        tput_mode = 'gen' if not prefill else 'in'
+        sns.heatmap(df[f'mean_{tput_mode}_tput'].unstack(), cmap='RdYlGn', ax=axs[1, 0], annot=True, cbar=False, fmt='.2f', square=True, annot_kws=annot_kws)
+        axs[1, 0].set_title(f'mean {tput_mode} tput [tps]')
+        sns.heatmap(df['hpugraph_captured'].unstack(), cmap='RdYlGn', ax=axs[1, 1], annot=True, cbar=False, square=True, annot_kws=annot_kws, vmin=0, vmax=1)
+        axs[1, 1].set_title('HPUGraph captured')
+        for ax in axs.flat:
+            ax.set(xlabel='num blocks' if not prefill else 'seq len', ylabel='batch size')
+
+        # Hide x labels and tick labels for top plots and y ticks for right plots.
+        for ax in axs.flat:
+            ax.label_outer()
+        plt.savefig(f'vllm_soc_{mode}.png')
+        plt.close(fig)
+
+    plot_server_stats_df(run_df, prefill=prefill)
+
+    return run_df
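
For orientation: process_run_characteristics expects times to map each (batch_size, seq) bucket to a tuple of per-run wall times and an HPUGraph-captured flag, which is what the modified warmup_scenario below returns. For decode buckets, seq counts KV-cache blocks, so the effective context length per bucket is batch_size * seq * block_size. A minimal standalone sketch follows; the latencies are fabricated and block_size=128 is an assumed value, not one taken from this commit.

# Sketch only: fabricated latencies, assumed block_size; needs pandas,
# matplotlib, and seaborn installed at runtime.
from vllm.hpu.utils import process_run_characteristics

times = {
    # (batch_size, seq) -> ([per-run wall times in seconds], hpugraph_captured)
    (1, 128): ([0.0102, 0.0099, 0.0101], True),
    (4, 128): ([0.0151, 0.0149, 0.0152], True),
    (1, 256): ([0.0121, 0.0119, 0.0120], False),
    (4, 256): ([0.0180, 0.0178, 0.0181], False),
}

# Writes vllm_soc_decode.csv and vllm_soc_decode.png (heatmaps of mean time,
# time stddev, generation throughput, and HPUGraph capture per bucket) into
# the current working directory, then returns the summary DataFrame.
run_df = process_run_characteristics(times, block_size=128, prefill=False)
print(run_df)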
29 changes: 25 additions & 4 deletions vllm/worker/habana_model_runner.py
@@ -1307,7 +1307,8 @@ def warmup_scenario(self,
                          seq_len,
                          is_prompt,
                          kv_caches,
-                         is_profile_run=False) -> None:
+                         is_profile_run=False,
+                         override_n_runs=None) -> None:
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
         scenario_name = ("warmup_"
                          f"{'prompt' if is_prompt else 'decode'}_"
@@ -1340,6 +1341,8 @@ def warmup_scenario(self,
         ]
         self.profiler.start('internal', scenario_name)
         times = 3 if use_graphs or is_profile_run else 1
+        if override_n_runs is not None:
+            times = override_n_runs
         if self.lora_config and not is_profile_run:
             lora_mapping = LoRAMapping(
                 [0] * batch_size * seq_len,
@@ -1371,19 +1374,27 @@ def warmup_scenario(self,
         ]
         torch.hpu.synchronize()
         profiler = None
+        fwd_times = []
         if is_profile_run and self.is_driver_worker:
             profiler = setup_profiler()
             profiler.start()
         for _ in range(times):
+            torch.hpu.synchronize()
+            start = time.perf_counter()
             inputs = self.prepare_model_input(seqs)
-            self.execute_model(inputs, kv_caches, warmup_mode=True)
+            self.execute_model(inputs, kv_caches, warmup_mode=False)
             torch.hpu.synchronize()
+            end = time.perf_counter()
+            elapsed = end - start
+            fwd_times.append(elapsed)
+            print(f'[{batch_size}x{seq_len}x{use_graphs}] tput: {batch_size/elapsed:.3f} tps, time: {elapsed*1000:.3f} ms')
             if profiler:
                 profiler.step()
         if profiler:
             profiler.stop()
         self.profiler.end()
         gc.collect()
+        return fwd_times, use_graphs
 
     def remove_all_loras(self):
         if not self.lora_manager:
@@ -1428,11 +1439,13 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len):
               f"free_mem:{free_mem}")
         logger.info(msg)
 
-    def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
+    def warmup_all_buckets(self, buckets, is_prompt, kv_caches, override_n_runs=None):
+        bucket_times = {}
         for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
             self.log_warmup('Prompt' if is_prompt else 'Decode', i,
                             len(buckets), batch_size, seq_len)
-            self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
+            bucket_times[(batch_size, seq_len)] = self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches, override_n_runs=override_n_runs)
+        return bucket_times
 
     def warmup_graphs(self,
                       strategy,
@@ -1636,6 +1649,14 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             logger.info(msg)
         self.profiler.end()
 
+        if os.environ.get('VLLM_PROFILE_SERVER_CHARACTERISTICS', 'false').lower() == 'true':
+            from vllm.hpu.utils import process_run_characteristics
+            n_runs = int(os.environ.get('VLLM_PROFILE_SERVER_CHARACTERISTICS_N', '5'))
+            decode_times = self.warmup_all_buckets(self.decode_buckets, False, kv_caches, override_n_runs=n_runs)
+            process_run_characteristics(decode_times, block_size=self.cache_config.block_size, prefill=False)
+            prefill_times = self.warmup_all_buckets(self.prompt_buckets, True, kv_caches, override_n_runs=n_runs)
+            process_run_characteristics(prefill_times, block_size=self.cache_config.block_size, prefill=True)
+
     @property
     def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()
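
Taken together, the runner changes add an opt-in profiling pass: after regular warmup, every decode and prompt bucket is re-run n times and summarized via process_run_characteristics. A hedged sketch of enabling it, assuming the Gaudi fork performs warmup while the engine is constructed; the model name is purely illustrative.

# Sketch: the environment variables must be set before the engine is built,
# since warmup_model reads them once warmup finishes.
import os

os.environ['VLLM_PROFILE_SERVER_CHARACTERISTICS'] = 'true'
os.environ['VLLM_PROFILE_SERVER_CHARACTERISTICS_N'] = '10'  # runs per bucket

from vllm import LLM

llm = LLM(model='meta-llama/Llama-2-7b-hf')  # illustrative model choice
# After warmup completes, vllm_soc_decode.{csv,png} and
# vllm_soc_prefill.{csv,png} appear in the working directory.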
2 changes: 1 addition & 1 deletion vllm/worker/habana_worker.py
@@ -150,7 +150,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         cache_block_size = self.get_cache_block_size_bytes()
         graph_reserved_mem = (float(
-            os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.4'))
+            os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.05'))
                               if not self.model_config.enforce_eager else 0)
         graph_headroom = 1 - graph_reserved_mem
         available_hpu_memory = free_hpu_memory * \
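
The habana_worker.py change is a default tweak: the fraction of free device memory held back for HPUGraph capture drops from 0.4 to 0.05, leaving correspondingly more headroom for KV-cache blocks. A back-of-the-envelope sketch of the headroom term alone; the 90 GiB figure is made up, and the real code applies further factors on the continuation line truncated above.

# Illustrative arithmetic only; 90 GiB of free HPU memory is a made-up figure.
free_hpu_memory = 90 * 2**30

for reserved in (0.4, 0.05):  # old default vs. new default
    graph_headroom = 1 - reserved
    usable = free_hpu_memory * graph_headroom
    print(f"VLLM_GRAPH_RESERVED_MEM={reserved}: "
          f"{usable / 2**30:.1f} GiB left for the KV cache")

# Prints 54.0 GiB for the old default and 85.5 GiB for the new one.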
