Skip to content

Commit

Permalink
Merge branch 'habana_main' into adobrzyniewicz/tflops_main
Browse files Browse the repository at this point in the history
  • Loading branch information
adobrzyniewicz-habana committed Aug 13, 2024
2 parents ea649b3 + d291910 commit 396015b
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 73 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/clang-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ name: clang-format

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
# but only for the habana_main branch
push:
branches:
- main
- habana_main
pull_request:
branches:
- main
- habana_main

jobs:
clang-format:
Expand Down
8 changes: 5 additions & 3 deletions .github/workflows/mypy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ name: mypy

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
# but only for the habana_main branch
push:
branches:
- main
- habana_main
pull_request:
branches:
- main
- habana_main

jobs:
ruff:
Expand Down Expand Up @@ -50,4 +50,6 @@ jobs:
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/hpu --config-file pyproject.toml
6 changes: 3 additions & 3 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ name: ruff

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
# but only for the habana_main branch
push:
branches:
- main
- habana_main
pull_request:
branches:
- main
- habana_main

jobs:
ruff:
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/yapf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ name: yapf

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
# but only for the habana_main branch
push:
branches:
- main
- habana_main
pull_request:
branches:
- main
- habana_main
jobs:
yapf:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/hpu --config-file pyproject.toml


# If git diff returns a file that is in the skip list, the file may be checked anyway:
Expand Down
39 changes: 18 additions & 21 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@

import vllm.hpu.utils as hpu_utils
from vllm.worker.profiler import Profiler
from vllm.logger import init_logger

PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')

logger = init_logger()
HPUFusedRMSNorm = None
try:
from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
HPUFusedRMSNorm = FusedRMSNorm
except ImportError:
logger.warning("Could not import HPU FusedRMSNorm kernel. "
"vLLM will use forward_native implementation of RMSNorm.")

def silu_and_mul(output, input):
d = input.shape[-1] // 2
silu = torch.nn.SiLU().to(input.device)
x, y = torch.split(input, d, dim=-1)
output.copy_(silu(x) * y)
PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')


def fetch_from_cache(cache, blocks, permutations):
Expand Down Expand Up @@ -66,8 +69,7 @@ def paged_attention_v1(query,
keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
mask = mask.unsqueeze(2)

attn_weights = [torch.matmul(query, k) for k in keys]
attn_weights = torch.cat(attn_weights, dim=-1)
attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1)
if alibi_slopes is not None:
attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):,
-attn_weights.size(3):])
Expand Down Expand Up @@ -103,12 +105,9 @@ def paged_attention_v1(query,
return attn_weights.squeeze(-2)


def silu_and_mul_wrapper(x: torch.Tensor) -> torch.Tensor:
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
silu_and_mul(out, x)
return out
return F.silu(x[..., :d]) * x[..., d:]


def static_fused_moe(hidden_states, w1, w2, score, topk):
Expand All @@ -133,13 +132,10 @@ def static_fused_moe(hidden_states, w1, w2, score, topk):
htorch.core.mark_step()

for expert_idx in range(num_experts):
padded_weight = padded_weights[expert_idx]
current_state_static = hidden_states.reshape(-1, D)
w_output = silu_and_mul_wrapper(
torch.matmul(current_state_static, w1[expert_idx].transpose(0, 1)))
w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1))
w_output = silu_and_mul(w_output)
w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
current_hidden_states_static = w_output * padded_weight
final_hidden_states += current_hidden_states_static
final_hidden_states += w_output * padded_weights[expert_idx]
htorch.core.mark_step()

return final_hidden_states.view(-1, D)
Expand All @@ -166,7 +162,8 @@ def prompt_attention(
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
attn_bias = attn_bias.unsqueeze(2)
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(2)
attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
Expand Down
2 changes: 1 addition & 1 deletion vllm/hpu/rotary_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
except ImportError:
logger.warning("Could not import HPU FusedRoPE kernel. "
"vLLM will use forward_native implementation of RoPE.")
FusedRoPE = None
FusedRoPE = None
else:
FusedRoPE = None

Expand Down
12 changes: 1 addition & 11 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,8 @@

from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.utils import is_hpu

logger = init_logger(__name__)
if is_hpu():
try:
from habana_frameworks.torch.hpex.normalization import (FusedRMSNorm as
HPUFusedRMSNorm
)
except ImportError:
logger.warning(
"Could not import HPU FusedRMSNorm kernel. "
"vLLM will use forward_native implementation of RMSNorm.")
HPUFusedRMSNorm = None


class RMSNorm(CustomOp):
Expand Down Expand Up @@ -86,6 +75,7 @@ def forward_hpu(
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm.hpu.ops import HPUFusedRMSNorm
if HPUFusedRMSNorm is None:
return self.forward_native(x, residual)
if residual is not None:
Expand Down
2 changes: 1 addition & 1 deletion vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ def get_summary_string(self):
return (
f"{format_bytes(self.consumed_device_memory)} of device memory "
f"({format_bytes(self.final_device_memory)}/"
f"({format_bytes(HabanaMemoryProfiler.total_device_memory())} used)"
f"{format_bytes(HabanaMemoryProfiler.total_device_memory())} used)"
f" and {format_bytes(self.consumed_host_memory)} of host memory "
f"({format_bytes(self.final_host_memory)}/"
f"{format_bytes(HabanaMemoryProfiler.total_host_memory())} used)")
Expand Down
102 changes: 83 additions & 19 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def __init__(

# Profiler stats
self.profiler_counter_helper = HabanaProfilerCounterHelper()

self._mem_margin: Optional[int] = None
self._setup_buckets()

def load_model(self) -> None:
Expand Down Expand Up @@ -1071,10 +1071,15 @@ def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
len(buckets), batch_size, seq_len)
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)

def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches,
available_mem):
total_batch_seq = 0.001
total_mem = 0
def warmup_graphs(self,
strategy,
buckets,
is_prompt,
kv_caches,
available_mem,
starting_mem=0,
total_batch_seq=0.001):
total_mem = starting_mem
idx = 0
phase = f'Graph/{"Prompt" if is_prompt else "Decode"}'
num_candidates = len(buckets)
Expand All @@ -1088,14 +1093,18 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches,
raise NotImplementedError(
f'Unsupported graph allocation strategy: {strategy}')
buckets = list(sorted(buckets, key=ordering))

captured_all = True
for idx, (batch_size, seq_len) in enumerate(buckets):
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size * seq_len if is_prompt else batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
if mem_estimate >= available_mem:
captured_all = False
continue
graphed_bucket = (batch_size, seq_len, is_prompt)
if graphed_bucket in self.graphed_buckets:
continue
self.graphed_buckets.add((batch_size, seq_len, is_prompt))
self.graphed_buckets.add(graphed_bucket)
self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
Expand All @@ -1104,6 +1113,12 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches,
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq

return total_mem, total_batch_seq, captured_all

def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
num_candidates = len(buckets)
phase = f'Graph/{"Prompt" if is_prompt else "Decode"}'
graphed = list(c[:2] for c in self.graphed_buckets
if c[2] == is_prompt)
msg = (f'{phase} captured:{len(graphed)} '
Expand All @@ -1124,22 +1139,63 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
self.warmup_all_buckets(self.decode_buckets, False, kv_caches)

if not self.enforce_eager and htorch.utils.internal.is_lazy():
mem_margin = 1.0 - float(
os.environ.get('VLLM_GRAPH_MEM_MARGIN', '0.02'))
free_mem = \
mem_margin * HabanaMemoryProfiler.current_free_device_memory()
free_mem = align_workers(free_mem, torch.distributed.ReduceOp.MIN)
assert self.mem_margin is not None, \
("HabanaWorker.determine_num_available_blocks needs "
"to be called before warming up the model.")
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_margin
graph_free_mem = align_workers(graph_free_mem,
torch.distributed.ReduceOp.MIN)
prompt_graph_mem_ratio = float(
os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
prompt_available_memory = prompt_graph_mem_ratio * free_mem
decode_available_memory = free_mem - prompt_available_memory
prompt_strategy = 'min_tokens'
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
msg = (f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})")
logger.info(msg)
prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY',
'min_tokens')
decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY',
'max_bs')
self.warmup_graphs(prompt_strategy, self.prompt_buckets, True,
kv_caches, prompt_available_memory)
self.warmup_graphs(decode_strategy, self.decode_buckets, False,
kv_caches, decode_available_memory)
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs(
prompt_strategy, self.prompt_buckets, True, kv_caches,
prompt_available_memory)
mem_post_decode, decode_batch_seq, decode_captured_all = \
self.warmup_graphs(
decode_strategy, self.decode_buckets, False, kv_caches,
decode_available_memory)

# Not all prompt buckets were captured, but all decode buckets were
# captured and we have some free graph-allocated space left.
# Let's try to use it for capturing more prompt buckets.
if mem_post_decode + mem_post_prompt < graph_free_mem \
and not prompt_captured_all \
and decode_captured_all:
mem_post_prompt, _, prompt_captured_all = self.warmup_graphs(
prompt_strategy, self.prompt_buckets, True, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq)

# Not all decode buckets were captured, but all prompt buckets were
# captured and we have some free graph-allocated space left.
# Let's try to use it for capturing more decode buckets.
if mem_post_decode + mem_post_prompt < graph_free_mem \
and not decode_captured_all \
and prompt_captured_all:
mem_post_decode, _, _ = self.warmup_graphs(
decode_strategy, self.decode_buckets, False, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq)

self.log_graph_warmup_summary(self.prompt_buckets, True,
mem_post_prompt)
self.log_graph_warmup_summary(self.decode_buckets, False,
mem_post_decode)

end_time = time.perf_counter()
end_mem = HabanaMemoryProfiler.current_device_memory_usage()
Expand All @@ -1154,6 +1210,14 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()

@property
def mem_margin(self) -> Optional[int]:
return self._mem_margin

@mem_margin.setter
def mem_margin(self, value):
self._mem_margin = value


def _maybe_wrap_in_hpu_graph(*args, **kwargs):
return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(
Expand Down
Loading

0 comments on commit 396015b

Please sign in to comment.