[Refactor] Refactor Ascend attention implementation forward (#3714)

### What this PR does / why we need it?
This PR refactors the Ascend attention implementation to align with
vLLM's core interfaces, simplifying the code and improving
maintainability.

### Key Changes:

* **Align with vLLM's Attention Interface**: The `forward` method
signature in `AscendAttentionBackendImpl` now matches the base
`AttentionImpl` in vLLM, removing the custom `trace_flag`.

* **Enable Opaque Attention Operator**: By adding `opaque_attention_op`
to `AscendPlatform`, we allow vLLM to wrap our attention kernel in its
standard `vllm.unified_attention_with_output` operator. This avoids the
need for a custom call path.

*   **Remove Obsolete Code**:
* The custom op `vllm.unified_ascend_attention_with_output` has been
deleted as it is now redundant.
* The `trace_flag` and its associated logic were removed, reducing code
complexity.
* An outdated quantization branch within the attention implementation
was cleaned up.

* **Improve Readability**: Renamed output variables (`output` vs.
`intermediate_output`) and added comments to clarify the in-place nature
of the attention output.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
No extra tests needed.

- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2025-10-25 08:58:35 +08:00
committed by GitHub
parent 0b1da24742
commit 3158742a97
9 changed files with 191 additions and 349 deletions

View File

@@ -32,31 +32,28 @@ from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
# isort: off
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
wait_for_kv_layer_from_connector)
split_decodes_and_prefills)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec,
prefill_context_parallel_enable, version_check)
from ..utils import weak_ref_tensors
prefill_context_parallel_enable, version_check,
weak_ref_tensors)
# isort: off
if prefill_context_parallel_enable():
from vllm.distributed import (get_pcp_group,
get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort:on
# isort: on
class AscendAttentionBackend(AttentionBackend):
@@ -484,7 +481,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_kv_heads=self.num_kv_heads,
out=output)
assert output is not None
return output[:num_tokens, :, :]
return output[:num_tokens]
def _forward_prefill_cache_hit(
self,
@@ -937,169 +934,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_lse_allgather)
return attn_out
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor],
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
trace_flag: bool = True,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache: shape = [key_cache, value_cache]
key_cache = [num_blocks, block_size,
num_kv_heads, head_size]
value_cache = [num_blocks, block_size,
num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
use_kv_cache_int8 = len(
kv_cache) > 0 and kv_cache[0].dtype == torch.int8
if output is None:
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
ori_output = output
if trace_flag:
torch.ops.vllm.unified_ascend_attention_with_output(
query=query,
key=key,
value=value,
output=output,
layer_name=layer.layer_name)
elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
output = layer.quant_method.apply(layer, query, key, value,
kv_cache, attn_metadata,
self.attn_type, self.scale,
output)
else:
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size).fill_(0)
num_decode_tokens = attn_metadata.num_decode_tokens
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
raise NotImplementedError("Encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()
if len(kv_cache) > 1:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
if has_decode:
slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
torch_npu._npu_reshape_and_cache(
key=key[:num_decode_tokens],
value=value[:num_decode_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slot_mapping)
if has_prefill:
if self.pcp_size > 1:
kv = torch.cat([key, value], dim=-1)
all_kv = get_pcp_group().all_gather(kv, dim=0)
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
all_kv = torch.index_select(all_kv, 0,
pcp_allgather_restore_idx)
key, value = all_kv.split(
[self.head_size, self.head_size], dim=-1)
torch_npu._npu_reshape_and_cache(
key=key[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded],
value=value[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=attn_metadata.
slot_mapping[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded])
if self.pcp_size * self.dcp_size > 1:
output = self._forward_pcp_dcp(query, key, value,
attn_metadata, output)
elif attn_type == AttentionType.ENCODER_ONLY:
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
attn_out = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=self.num_heads,
input_layout="TND",
scale=self.scale,
sparse_mode=4,
atten_mask=attn_metadata.attn_mask,
next_tockens=attn_metadata.max_query_len,
actual_seq_qlen=cum_seq_len,
actual_seq_kvlen=cum_seq_len,
)
output = attn_out[0]
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
output = self._forward_prefill_no_cache(
query, key, value, attn_metadata, output, num_tokens)
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
output = self._forward_prefill_cache_hit(
query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
output = self._forward_decode_only(query, attn_metadata,
output)
# Normal V1 situation.
else:
if torch.version.cann.startswith("8.3"):
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
output = self._forward_v1_style(query, attn_metadata, output)
# to make in-place change to the output tensor
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
output = output.view(num_tokens, self.num_heads, self.head_size)
ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
return output.view(num_tokens, self.hidden_size)
def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
output: Optional[torch.Tensor]) -> torch.Tensor:
output: torch.Tensor) -> torch.Tensor:
assert attn_metadata is not None
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
if output is None:
raise ValueError("Output buffer is required")
if has_decode:
decode_query = query[:num_decode_tokens]
output_decode = self._forward_decode_pcp_dcp(
@@ -1131,47 +973,137 @@ class AscendAttentionBackendImpl(AttentionImpl):
output[num_decode_tokens:] = output_prefill
return output
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor],
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
def unified_ascend_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
key,
value,
kv_cache,
attn_metadata,
output,
trace_flag=False)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for AscendAttentionBackendImpl")
num_tokens = query.shape[0]
if attn_metadata is None:
return output
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return
# NOTE: Currently, we have various attention paths for different
# scenarios, and not all of them are in-place operations. Therefore,
# we need to create a separate tensor to hold the attention result.
# In the future, we may consolidate them into fewer paths, which will
# hopefully allow us to use in-place operation by default.
intermediate_output: torch.Tensor
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
raise NotImplementedError("Encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
direct_register_custom_op(
op_name="unified_ascend_attention_with_output",
op_func=unified_ascend_attention_with_output,
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key="PrivateUse1",
)
num_decode_tokens = attn_metadata.num_decode_tokens
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
if len(kv_cache) > 1:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
if has_decode:
slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
torch_npu._npu_reshape_and_cache(
key=key[:num_decode_tokens],
value=value[:num_decode_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slot_mapping)
if has_prefill:
if self.pcp_size > 1:
kv = torch.cat([key, value], dim=-1)
all_kv = get_pcp_group().all_gather(kv, dim=0)
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
all_kv = torch.index_select(all_kv, 0,
pcp_allgather_restore_idx)
key, value = all_kv.split([self.head_size, self.head_size],
dim=-1)
torch_npu._npu_reshape_and_cache(
key=key[self.pcp_size * num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded],
value=value[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=attn_metadata.
slot_mapping[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded])
if self.pcp_size * self.dcp_size > 1:
intermediate_output = self._forward_pcp_dcp(
query, key, value, attn_metadata, output)
elif attn_type == AttentionType.ENCODER_ONLY:
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
intermediate_output = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=self.num_heads,
input_layout="TND",
scale=self.scale,
sparse_mode=4,
atten_mask=attn_metadata.attn_mask,
pre_tockens=attn_metadata.max_query_len,
next_tockens=attn_metadata.max_query_len,
actual_seq_qlen=cum_seq_len,
actual_seq_kvlen=cum_seq_len,
)[0]
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
intermediate_output = self._forward_prefill_no_cache(
query, key, value, attn_metadata, output, num_tokens)
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
intermediate_output = self._forward_prefill_cache_hit(
query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
intermediate_output = self._forward_decode_only(
query, attn_metadata, output)
# Normal V1 situation.
else:
if torch.version.cann.startswith("8.3"):
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
intermediate_output = self._forward_v1_style(
query, attn_metadata, output)
output[:num_tokens] = intermediate_output[:num_tokens]
return output