[Refactor] Remove redundant attention operator branches. (#4531)

Reason:

We replace the other attention ops with fused_infer_attention_score, except
for the decode-only state. We also clean up the code and remove 310P support.

https://github.com/vllm-project/vllm-ascend/pull/4455
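For reference, the consolidated dispatch now reduces to roughly the shape
below. This is a simplified sketch of the new forward() control flow (graph
capture, KV-cache writes, and output slicing omitted), not verbatim code:

    # Sketch: all prefill flavors share one fused op; decode keeps its own.
    if self.pcp_size * self.dcp_size > 1:
        # prefill/decode context-parallel path
        attn_output = self._forward_pcp_dcp(query, key, value, kv_cache,
                                            attn_metadata, output)
    elif self.attn_type == AttentionType.ENCODER_ONLY:
        # encoder-only attention, via npu_fusion_attention
        attn_output = self._forward_encode(query, key, value, attn_metadata,
                                           output)
    elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
        # the decode_only state keeps the dedicated paged-attention kernel
        attn_output = self._forward_decode_only(query, attn_metadata, output)
    else:
        # PrefillNoCache / PrefillCacheHit / ChunkedPrefill all go through
        # npu_fused_infer_attention_score inside _forward_prefill
        attn_output = self._forward_prefill(query, key, value, attn_metadata,
                                            output)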


- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Author: weijinqian0
Date: 2025-12-02 09:13:26 +08:00
Committed by: GitHub
Parent: 981a14f8d5
Commit: b4bf01ead1
3 changed files with 119 additions and 470 deletions

@@ -41,11 +41,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
aligned_16, get_ascend_device_type, nd_to_nz_2d,
nd_to_nz_spec, prefill_context_parallel_enable,
weak_ref_tensors)
from vllm_ascend.utils import prefill_context_parallel_enable, weak_ref_tensors
# isort: off
if prefill_context_parallel_enable():
@@ -83,9 +79,6 @@ class AscendAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
if get_ascend_device_type() == AscendDeviceType._310P:
return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
16)
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
@@ -351,16 +344,6 @@ class AscendAttentionMetadataBuilder:
query_start_loc = query_start_loc_cpu.to(self.device,
non_blocking=True)
if get_ascend_device_type() == AscendDeviceType._310P:
if attn_state == AscendAttentionState.PrefillNoCache:
mask_nz = nd_to_nz_2d(attn_mask)
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
elif attn_state == AscendAttentionState.ChunkedPrefill:
mask_nz = nd_to_nz_spec(attn_mask)
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
prefill_metadata = None
decode_metadata = None
@@ -585,9 +568,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
output: torch.Tensor,
num_tokens=0):
if self.pcp_size * self.dcp_size > 1:
intermediate_output = self._forward_pcp_dcp(
query, key, value, kv_cache, attn_metadata, output)
return intermediate_output, query.shape[0]
attn_output = self._forward_pcp_dcp(query, key, value, kv_cache,
attn_metadata, output)
return attn_output, query.shape[0]
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
@@ -688,93 +671,58 @@ class AscendAttentionBackendImpl(AttentionImpl):
graph_params.handles[num_tokens].append(handle)
return output, num_tokens
def _forward_prefill_no_cache(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
num_tokens=0,
) -> torch.Tensor:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
if get_ascend_device_type() == AscendDeviceType._310P:
# align q k v output tensors
query = aligned_16(query)
key = aligned_16(key)
value = aligned_16(value)
output = aligned_16(output)
# do reformat in case of broadcasted tensors
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
mask = torch_npu.npu_format_cast(mask.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
assert output is not None
return output[:num_tokens]
def _forward_prefill_cache_hit(
self,
query: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
compress_mask = attn_metadata.attn_mask
batch_size = attn_metadata.query_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
if block_size == 128:
# TODO: The npu_fused_infer_attention_score op is planned to
# be utilized in a wider range in upcoming versions.
def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
output: torch.Tensor):
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.query_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key,
value=value,
atten_mask=compress_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
actual_seq_lengths_kv = attn_metadata.seq_lens_list
# chunked_prefill.
else:
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=block_table,
mask=compress_mask,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
query = query[:num_tokens]
# Prepare tensors for attention output
# TODO: Refactor this to step-level instead of layer-level
# Get workspace from cache or calculate it if not present.
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
attn_output = attn_output.view(num_tokens, self.num_heads,
self.head_size)
output[:num_tokens] = attn_output[:num_tokens]
return output
def _forward_decode_only(
@@ -783,10 +731,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if get_ascend_device_type() == AscendDeviceType._310P:
# seq_lens_tensor needs to be transferred to the device for 310P.
attn_metadata.seq_lens = \
attn_metadata.seq_lens.to(device=query.device)
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
0] == query.size(0):
batch_size = attn_metadata.seq_lens.shape[0]
@@ -827,69 +771,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
out=output)
return output
def _forward_v1_style(
self,
query: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Use chunked prefill for head size 192 scenario, like deepseek
# paged_attention_splitfuse maybe crash at such scenario.
# TODO: vanilla path will be removed after the kernel support
# head_size 192 scenario.
if self.head_size == 192:
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
max_seqlen_q = torch.max(attn_metadata.query_lens)
max_seqlen_k = torch.max(attn_metadata.seq_lens)
vanilla_chunked_prefill(output, query, self.key_cache,
self.value_cache,
attn_metadata.block_tables, cu_seqlen_q,
cu_seqlen_k, max_seqlen_q, max_seqlen_k,
self.scale, None, True)
return output
# Use paged attention.
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
if get_ascend_device_type() == AscendDeviceType._310P:
# Do reformat in case of broadcasted tensors.
attn_metadata.attn_mask = \
torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
attn_metadata.seq_lens = \
attn_metadata.seq_lens.to(device=query.device)
# TODO: The npu_fused_infer_attention_score op is planned to
# be utilized in a wider range in upcoming versions.
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
return output
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
q_seqlens: List[int],
k_nomask: torch.Tensor,
@@ -1464,6 +1345,31 @@ class AscendAttentionBackendImpl(AttentionImpl):
)
return key, value
def _forward_encode(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
) -> torch.Tensor:
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
output = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=self.num_heads,
input_layout="TND",
scale=self.scale,
sparse_mode=4,
atten_mask=attn_metadata.attn_mask,
pre_tockens=attn_metadata.max_query_len,
next_tockens=attn_metadata.max_query_len,
actual_seq_qlen=cum_seq_len,
actual_seq_kvlen=cum_seq_len,
)[0]
return output
def forward(
self,
layer: AttentionLayer,
@@ -1494,24 +1400,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
"fused output quantization is not yet supported"
" for AscendAttentionBackendImpl")
num_tokens = query.shape[0]
if attn_metadata is None:
return output
# NOTE: Currently, we have various attention paths for different
# scenarios, and not all of them are in-place operations. Therefore,
# we need to create a separate tensor to hold the attention result.
# In the future, we may consolidate them into fewer paths, which will
# hopefully allow us to use in-place operation by default.
intermediate_output: torch.Tensor
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
if self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_ONLY:
raise NotImplementedError("Encoder/decoder cross-attention "
"is not implemented for "
"AscendAttentionBackendImpl")
num_tokens = query.shape[0]
if attn_metadata is None:
return output.fill_(0)
num_decode_tokens = attn_metadata.num_decode_tokens
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
@@ -1558,48 +1456,25 @@ class AscendAttentionBackendImpl(AttentionImpl):
forward_context: ForwardContext = get_forward_context()
if not forward_context.capturing:
if self.pcp_size * self.dcp_size > 1:
intermediate_output = self._forward_pcp_dcp(
query, key, value, kv_cache, attn_metadata, output)
elif attn_type == AttentionType.ENCODER_ONLY:
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
intermediate_output = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=self.num_heads,
input_layout="TND",
scale=self.scale,
sparse_mode=4,
atten_mask=attn_metadata.attn_mask,
pre_tockens=attn_metadata.max_query_len,
next_tockens=attn_metadata.max_query_len,
actual_seq_qlen=cum_seq_len,
actual_seq_kvlen=cum_seq_len,
)[0]
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
intermediate_output = self._forward_prefill_no_cache(
query, key, value, attn_metadata, output, num_tokens)
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
intermediate_output = self._forward_prefill_cache_hit(
query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
intermediate_output = self._forward_decode_only(
query, attn_metadata, output)
# Normal V1 situation.
attn_output = self._forward_pcp_dcp(query, key, value,
kv_cache, attn_metadata,
output)
output[:num_tokens] = attn_output[:num_tokens]
return output
if self.attn_type == AttentionType.ENCODER_ONLY:
attn_output = self._forward_encode(query, key, value,
attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
return output
if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
output = self._forward_decode_only(query, attn_metadata,
output)
else:
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
intermediate_output = self._forward_v1_style(
query, attn_metadata, output)
output = self._forward_prefill(query, key, value,
attn_metadata, output)
else:
intermediate_output, num_tokens = self.full_graph_attention(
attn_output, num_tokens = self.full_graph_attention(
query, key, value, kv_cache, attn_metadata, output)
output[:num_tokens] = intermediate_output[:num_tokens]
output[:num_tokens] = attn_output[:num_tokens]
return output