[Refactor] 4/N Distinguish the branches based on the applicable scenarios of PA and FIA Ops. (#5081)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

Reason:

We distinguish the branches based on the applicable scenarios of
pagedAttention and fusedInferAttention, making the code more clear.

At the same time, it makes subsequent iterations easier: adding
sliding_window and sinks support, and removing the PA ops once FIA is ready.

Todo:

- remove PA ops after FIA is ready
- add sliding_window and sink ops for gpt_oss
- replace FIA with FIA_v2
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-12-17 23:14:02 +08:00
committed by GitHub
parent 7671ce1bf1
commit 98e6e57622
3 changed files with 117 additions and 154 deletions

View File

@@ -236,9 +236,9 @@ class TestAscendAttentionBackendImpl(TestBase):
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_prefill(self, mock_get_forward_context,
mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache):
def test_forward_fused_infer_attention(
self, mock_get_forward_context,
mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):
"""Test forward pass in PrefillCacheHit state"""
query = torch.randn(10, 8, 64)
key = torch.randn(10, 8, 64)
@@ -268,28 +268,31 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_npu_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8, 64)
@patch('vllm_ascend.attention.attention_v1.using_paged_attention')
@patch('torch_npu._npu_paged_attention')
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_decode_only(self, mock_get_forward_context,
mock_npu_reshape_and_cache,
mock_paged_attention):
def test_forward_paged_attention(self, mock_get_forward_context,
mock_npu_reshape_and_cache,
mock_paged_attention,
mock_using_paged_attention):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
query = torch.randn(4, 8 * 64)
key = torch.randn(4, 8 * 64)
value = torch.randn(4, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([4])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_actual_tokens = 4
metadata.slot_mapping = torch.zeros(4, dtype=torch.long)
metadata.num_decodes = 4
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_using_paged_attention.return_value = True
mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -297,7 +300,7 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata, output)
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
assert output.shape == (4, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu.npu_fused_infer_attention_score')
@@ -339,9 +342,9 @@ class TestAscendAttentionBackendImpl(TestBase):
self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
mock_paged_attention, mock_get_forward_context):
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
query = torch.randn(10, 8, 64)
key = torch.randn(10, 8, 64)
value = torch.randn(10, 8, 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
@@ -354,6 +357,7 @@ class TestAscendAttentionBackendImpl(TestBase):
layer = self.layer_no_quant
metadata.num_decodes = 10
metadata.num_prefills = 0
metadata.actual_seq_lengths_q = [10]
mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -363,10 +367,10 @@ class TestAscendAttentionBackendImpl(TestBase):
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
metadata, output)
mock_paged_attention.assert_called_once()
mock_fused_infer_attention_score.assert_not_called()
mock_paged_attention.assert_not_called()
mock_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)
assert output.shape == (10, 8, 64)
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_raise_error(self, mock_paged_attention):

View File

@@ -386,41 +386,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.key_cache = None
self.value_cache = None
def full_graph_attention(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor) -> torch.Tensor:
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.query_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
actual_seq_lengths_kv = attn_metadata.seq_lens_list
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
# Normal V1 situation.
else:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
output: torch.Tensor) -> torch.Tensor:
key, value, block_size, block_table, actual_seq_lengths_kv \
= self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.query_start_loc_list[-1]
graph_params = get_graph_params()
@@ -489,7 +459,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
graph_params.handles[num_tokens].append(handle)
return output, num_tokens
def full_graph_attention_with_pa(
def full_graph_pa(
self,
query: torch.Tensor,
attn_metadata: AscendMetadata,
@@ -550,13 +520,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
graph_params.handles[num_tokens].append(handle)
return output
def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
output: torch.Tensor):
def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor,
attn_metadata: AscendMetadata):
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.query_lens.shape[0]
@@ -567,7 +537,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
actual_seq_lengths_kv = attn_metadata.seq_lens_list
# chunked_prefill.
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
# chunked prefill.
else:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
@@ -576,12 +554,57 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
return key, value, block_size, block_table, actual_seq_lengths_kv
def _forward_fia_slidingwindow(self, query: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor):
# Decode-only sliding-window attention via the fused-infer-attention
# (FIA) op. Each request contributes exactly one query token here
# (actual_seq_lengths is all ones), so the query is reshaped to the
# BSH layout with a sequence length of 1.
# NOTE(review): assumes one decode token per sequence in the batch —
# the caller gates on seq_lens.shape[0] == query.size(0); confirm.
batch_size = attn_metadata.seq_lens.shape[0]
# Fallback block size when no KV cache has been allocated yet.
block_size = 128
query = query.view(batch_size, 1, self.num_heads * self.head_size)
key = self.key_cache
value = self.value_cache
if self.key_cache is not None and self.value_cache is not None:
# Real block size comes from the paged KV cache layout
# (num_blocks, block_size, num_kv_heads, head_size); the head and
# head_size dims are flattened to match the BSH input layout.
block_size = self.key_cache.shape[1]
key = self.key_cache.flatten(2, 3).contiguous()
value = self.value_cache.flatten(2, 3).contiguous()
output, _ = torch_npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSH",
block_size=block_size,
# pre_tokens bounds how far back attention may look — this is
# what implements the sliding window.
pre_tokens=self.sliding_window,
scale=self.scale,
block_table=attn_metadata.block_tables,
# One new query token per sequence; KV lengths are the full
# context lengths tracked in the metadata.
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
actual_seq_lengths_kv=attn_metadata.seq_lens)
# Restore the (tokens, heads, head_size) shape expected by the caller.
output = output.view(batch_size, self.num_heads, self.head_size)
return output
def forward_fused_infer_attention(self, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor):
forward_context: ForwardContext = get_forward_context()
if forward_context.capturing:
attn_output, num_tokens = self.full_graph_fia(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
return output
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
and self.sliding_window is not None
and attn_metadata.seq_lens.shape[0] == query.size(0)):
return self._forward_fia_slidingwindow(query, attn_metadata,
output)
key, value, block_size, block_table, actual_seq_lengths_kv \
= self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
query = query[:num_tokens]
# Prepare tensors for attention output
# TODO: Refactor this to step-level instead of layer-level
# Get workspace from cache or calculate it if not present.
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
@@ -604,83 +627,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
output[:num_tokens] = attn_output[:num_tokens]
return output
def _forward_decode_only_ascend91095(
self,
query: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
) -> torch.Tensor:
batch_size = attn_metadata.query_lens.shape[0]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
actual_seq_lengths_kv = attn_metadata.seq_lens_list
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key,
value=value,
block_table=attn_metadata.block_tables,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
)
output[:batch_size] = attn_output[:batch_size]
return output
def _forward_decode_only(
def forward_paged_attention(
self,
query: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if get_ascend_device_type() == AscendDeviceType.A5:
return self._forward_decode_only_ascend91095(
query, attn_metadata, output)
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
0] == query.size(0):
batch_size = attn_metadata.seq_lens.shape[0]
block_size = 128
query = query.view(batch_size, 1, self.num_heads * self.head_size)
key = self.key_cache
value = self.value_cache
if self.key_cache is not None and self.value_cache is not None:
block_size = self.key_cache.shape[1]
key = self.key_cache.flatten(2, 3).contiguous()
value = self.value_cache.flatten(2, 3).contiguous()
output, _ = torch_npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSH",
block_size=block_size,
pre_tokens=self.sliding_window,
scale=self.scale,
block_table=attn_metadata.block_tables,
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
actual_seq_lengths_kv=attn_metadata.seq_lens)
output = output.view(batch_size, self.num_heads, self.head_size)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
forward_context: ForwardContext = get_forward_context()
if forward_context.capturing:
return self.full_graph_pa(query, attn_metadata, output)
torch_npu._npu_paged_attention(query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
return output
def _forward_encoder_attention(self, query: torch.Tensor,
@@ -757,23 +721,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata: AscendMetadata,
output: torch.Tensor,
):
forward_context: ForwardContext = get_forward_context()
if not forward_context.capturing:
if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
output = self._forward_decode_only(query, attn_metadata,
output)
else:
output = self._forward_prefill(query, key, value,
attn_metadata, output)
num_tokens = query.shape[0]
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
and using_paged_attention(num_tokens)
and self.sliding_window is None):
output = self.forward_paged_attention(query, attn_metadata, output)
else:
num_tokens = query.shape[0]
if using_paged_attention(num_tokens):
output = self.full_graph_attention_with_pa(
query, attn_metadata, output)
else:
attn_output, num_tokens = self.full_graph_attention(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
output = self.forward_fused_infer_attention(
query, key, value, attn_metadata, output)
return output

View File

@@ -10,7 +10,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm_ascend.utils import get_ascend_config
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
get_ascend_device_type)
@lru_cache
@@ -18,8 +19,11 @@ def using_paged_attention(runtime_shape: int) -> bool:
vllm_config = get_current_vllm_config()
if vllm_config.speculative_config is not None:
return False
if get_ascend_device_type() == AscendDeviceType.A5:
return False
from vllm.config.compilation import CUDAGraphMode
if vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
return False
return runtime_shape in get_ascend_config().pa_shape_list