[v0.11.0][Perf] Eliminating the zerolike operator through patch (#3632)
### What this PR does / why we need it?

There is a zero-like operator before the attention operation in each decoding stage. After analysis, this operator can be eliminated. The purpose of this PR is to remove this operator and improve performance.

---------

Signed-off-by: ZYang6263 <zy626375@gmail.com>
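The core of the change is how `Attention.forward` allocates its output buffer: instead of `torch.zeros`, which launches a fill kernel on every decode step, the patched forward uses `torch.empty` and relies on the attention backend to write every element it returns. A minimal sketch of the idea (the helper name and shapes are illustrative, not the actual vLLM code):

```python
import torch


def alloc_attention_output(shape, dtype, device, pre_zero: bool = False):
    # Before the patch: the buffer was created with torch.zeros, so an extra
    # zero-fill kernel ran before every attention call even though the backend
    # overwrites the buffer anyway.
    if pre_zero:
        return torch.zeros(shape, dtype=dtype, device=device)
    # After the patch: uninitialized allocation only; no fill kernel is launched.
    return torch.empty(shape, dtype=dtype, device=device)
```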
```diff
@@ -632,7 +632,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         else:
             if attn_metadata is None:
-                return output.view(num_tokens, self.hidden_size)
+                return output.view(num_tokens, self.hidden_size).fill_(0)
             num_actual_tokens = attn_metadata.num_actual_tokens
             assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
             attn_type = self.attn_type
 
```
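The hunks in the backend implementations are the other half of the change: since the caller no longer pre-zeroes the buffer, each backend's profiling path (where `attn_metadata is None` and no attention kernel writes the output) must zero it explicitly. A minimal sketch of that contract, with a simplified signature rather than the real `forward`:

```python
import torch


def backend_forward(output: torch.Tensor, attn_metadata=None) -> torch.Tensor:
    # Sketch only: during a profiling run there is no metadata and no attention
    # kernel runs, so the uninitialized buffer must be zero-filled before returning.
    if attn_metadata is None:
        return output.fill_(0)
    # ... the real implementations compute attention into `output` here ...
    return output
```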
```diff
@@ -1226,7 +1226,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         num_actual_tokens = attn_metadata.num_actual_tokens
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
```
```diff
@@ -808,7 +808,7 @@ class AscendSFAImpl(MLAAttentionImpl):
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         num_actual_tokens = attn_metadata.num_actual_tokens
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
```
```diff
@@ -160,3 +160,15 @@
 #    Future Plan:
 #      Remove this patch when adapted vllm version contains the above PR.
 #
+# ** File: worker/patch_attention_layer.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.attention.layer.Attention.forward`
+#       Why:
+#           There is a zero-like operator before the attention operation in each decoding stage.
+#       How:
+#           Replace this zero-like operator with `torch.empty`.
+#       Related PR (if no, explain why):
+#           - https://github.com/vllm-project/vllm/pull/26680
+#       Future Plan:
+#           Remove this patch once the adapted vLLM version supports this optimization.
+#
```
```diff
@@ -28,3 +28,4 @@ import vllm_ascend.patch.worker.patch_weight_loader # noqa
 import vllm_ascend.patch.worker.patch_multimodal_merge # noqa
 import vllm_ascend.patch.worker.patch_minicpm # noqa
 import vllm_ascend.patch.worker.patch_deepseek_mtp # noqa
+import vllm_ascend.patch.worker.patch_attention_layer # noqa
```
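Like the rest of the worker patches, the new module takes effect purely through this import: its module body reassigns `Attention.forward` at load time. A rough way to sanity-check that the swap happened (the `is_patched` helper is hypothetical, not part of the codebase):

```python
from vllm.attention.layer import Attention

# Importing the patch module runs its module-level assignment and swaps
# Attention.forward in place; nothing else has to call into it.
import vllm_ascend.patch.worker.patch_attention_layer as patch_attention_layer


def is_patched() -> bool:
    # Hypothetical helper: true once the patched forward has been installed.
    return Attention.forward is patch_attention_layer.forward


assert is_patched()
```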
vllm_ascend/patch/worker/patch_attention_layer.py (new file, 92 lines)

```python
from typing import Optional

import torch
import vllm
from vllm.forward_context import ForwardContext, get_forward_context


def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    # For some alternate attention backends like MLA the attention output
    # shape does not match the query shape, so we optionally let the model
    # definition specify the output tensor shape.
    output_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
    """
    The KV cache is stored inside this class and is accessed via
    `self.kv_cache`.

    Attention metadata (`attn_metadata`) is set using a context manager in
    the model runner's `execute_model` method. It is accessed via forward
    context using
    `vllm.forward_context.get_forward_context().attn_metadata`.
    """
    if self.calculate_kv_scales:
        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata.enable_kv_scales_calculation:
            self.calc_kv_scales(query, key, value)

    output_dtype = query.dtype
    if self.query_quant is not None:
        # quantizing with a simple torch operation enables
        # torch.compile to fuse this into previous ops
        # which reduces overheads during decoding.
        # Otherwise queries are quantized using custom ops
        # which causes decoding overheads
        assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
        query, _ = self.query_quant(query, self._q_scale)

    if self.use_output:
        output_shape = (output_shape
                        if output_shape is not None else query.shape)
        output = torch.empty(output_shape,
                             dtype=output_dtype,
                             device=query.device)
        hidden_size = output_shape[-1]
        # We skip reshaping query, key and value tensors for the MLA
        # backend since these tensors have different semantics and are
        # processed differently.
        if not self.use_mla:
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            self.impl.forward(self,
                              query,
                              key,
                              value,
                              self_kv_cache,
                              attn_metadata,
                              output=output)
        else:
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, self.layer_name)
        return output.view(-1, hidden_size)
    else:
        if self.use_direct_call:
            forward_context = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            return self.impl.forward(self, query, key, value, self_kv_cache,
                                     attn_metadata)
        else:
            return torch.ops.vllm.unified_attention(query, key, value,
                                                    self.layer_name)


vllm.attention.layer.Attention.forward = forward
```
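As a rough way to see the per-step cost this removes, one can time the two allocation paths in isolation. This is only a sketch under assumed shapes, run on CPU for simplicity (on an NPU/GPU one would also synchronize the device before reading the clock); it is not a measurement from the PR:

```python
import time

import torch


def time_alloc(alloc_fn, shape, dtype=torch.float16, device="cpu", iters=1000):
    # Time repeated allocations of a decode-sized attention output buffer.
    start = time.perf_counter()
    for _ in range(iters):
        alloc_fn(shape, dtype=dtype, device=device)
    return time.perf_counter() - start


# Hypothetical decode-step output shape: (num_tokens, hidden_size).
shape = (256, 4096)
print("torch.zeros:", time_alloc(torch.zeros, shape))  # allocation + zero fill
print("torch.empty:", time_alloc(torch.empty, shape))  # allocation only
```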
```diff
@@ -350,7 +350,7 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
             return output.view(num_tokens, self.hidden_size)
 
         if attn_metadata is None:
-            return output.view(num_tokens, self.hidden_size)
+            return output.view(num_tokens, self.hidden_size).fill_(0)
 
         output = output.view(-1, self.num_heads, self.head_size)
```
```diff
@@ -1098,7 +1098,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
```
```diff
@@ -982,7 +982,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
 
         if attn_metadata.prefill is not None:
             assert attn_metadata.num_decodes is not None and \
```