[Refactor] 4/N Distinguish the branches based on the applicable scenarios of PA and FIA Ops. (#5081)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
We distinguish the branches based on the applicable scenarios of
pagedAttention and fusedInferAttention, making the code more clear.
At the same time, it is convenient for the subsequent iterations of
sliding_window and sinks and removePA ops after FIA is ready.
Todo:
remove PA ops after FIA is ready
add slidingwindow and ops for gpt_oss
replace FIA with FIA_v2
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -236,9 +236,9 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
def test_forward_prefill(self, mock_get_forward_context,
|
||||
mock_npu_fused_infer_attention_score,
|
||||
mock_npu_reshape_and_cache):
|
||||
def test_forward_fused_infer_attention(
|
||||
self, mock_get_forward_context,
|
||||
mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):
|
||||
"""Test forward pass in PrefillCacheHit state"""
|
||||
query = torch.randn(10, 8, 64)
|
||||
key = torch.randn(10, 8, 64)
|
||||
@@ -268,28 +268,31 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_npu_fused_infer_attention_score.assert_called_once()
|
||||
assert output.shape == (10, 8, 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.using_paged_attention')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
def test_forward_decode_only(self, mock_get_forward_context,
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_paged_attention):
|
||||
def test_forward_paged_attention(self, mock_get_forward_context,
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_paged_attention,
|
||||
mock_using_paged_attention):
|
||||
"""Test forward pass in DecodeOnly state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
query = torch.randn(4, 8 * 64)
|
||||
key = torch.randn(4, 8 * 64)
|
||||
value = torch.randn(4, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.seq_lens = torch.tensor([4])
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
metadata.num_decodes = 10
|
||||
metadata.num_actual_tokens = 4
|
||||
metadata.slot_mapping = torch.zeros(4, dtype=torch.long)
|
||||
metadata.num_decodes = 4
|
||||
metadata.num_prefills = 0
|
||||
layer = self.layer_no_quant
|
||||
mock_using_paged_attention.return_value = True
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
@@ -297,7 +300,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
metadata, output)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
assert output.shape == (4, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
@@ -339,9 +342,9 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
|
||||
mock_paged_attention, mock_get_forward_context):
|
||||
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
query = torch.randn(10, 8, 64)
|
||||
key = torch.randn(10, 8, 64)
|
||||
value = torch.randn(10, 8, 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
@@ -354,6 +357,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
layer = self.layer_no_quant
|
||||
metadata.num_decodes = 10
|
||||
metadata.num_prefills = 0
|
||||
metadata.actual_seq_lengths_q = [10]
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
@@ -363,10 +367,10 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
mock_fused_infer_attention_score.assert_not_called()
|
||||
mock_paged_attention.assert_not_called()
|
||||
mock_fused_infer_attention_score.assert_called_once()
|
||||
|
||||
assert output.shape == (10, 8 * 64)
|
||||
assert output.shape == (10, 8, 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
def test_forward_raise_error(self, mock_paged_attention):
|
||||
|
||||
@@ -386,41 +386,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
|
||||
def full_graph_attention(self, query: torch.Tensor, key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor) -> torch.Tensor:
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
block_size = 128
|
||||
block_table = None
|
||||
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
batch_size = attn_metadata.query_lens.shape[0]
|
||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
block_table = attn_metadata.block_tables
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
block_table = attn_metadata.block_tables
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
|
||||
value: torch.Tensor, attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor) -> torch.Tensor:
|
||||
key, value, block_size, block_table, actual_seq_lengths_kv \
|
||||
= self._get_fia_params(key, value, attn_metadata)
|
||||
|
||||
num_tokens = attn_metadata.query_start_loc_list[-1]
|
||||
graph_params = get_graph_params()
|
||||
@@ -489,7 +459,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
return output, num_tokens
|
||||
|
||||
def full_graph_attention_with_pa(
|
||||
def full_graph_pa(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
@@ -550,13 +520,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
return output
|
||||
|
||||
def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
|
||||
value: torch.Tensor, attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor):
|
||||
def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata):
|
||||
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
block_size = 128
|
||||
block_table = None
|
||||
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
|
||||
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
batch_size = attn_metadata.query_lens.shape[0]
|
||||
@@ -567,7 +537,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
# chunked_prefill.
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
block_table = attn_metadata.block_tables
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
# chunked prefill.
|
||||
else:
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
@@ -576,12 +554,57 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_block, block_size, -1)
|
||||
block_table = attn_metadata.block_tables
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
return key, value, block_size, block_table, actual_seq_lengths_kv
|
||||
|
||||
def _forward_fia_slidingwindow(self, query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor):
|
||||
batch_size = attn_metadata.seq_lens.shape[0]
|
||||
block_size = 128
|
||||
query = query.view(batch_size, 1, self.num_heads * self.head_size)
|
||||
key = self.key_cache
|
||||
value = self.value_cache
|
||||
if self.key_cache is not None and self.value_cache is not None:
|
||||
block_size = self.key_cache.shape[1]
|
||||
key = self.key_cache.flatten(2, 3).contiguous()
|
||||
value = self.value_cache.flatten(2, 3).contiguous()
|
||||
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="BSH",
|
||||
block_size=block_size,
|
||||
pre_tokens=self.sliding_window,
|
||||
scale=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
||||
|
||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||
return output
|
||||
|
||||
def forward_fused_infer_attention(self, query: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor):
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if forward_context.capturing:
|
||||
attn_output, num_tokens = self.full_graph_fia(
|
||||
query, key, value, attn_metadata, output)
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
and self.sliding_window is not None
|
||||
and attn_metadata.seq_lens.shape[0] == query.size(0)):
|
||||
return self._forward_fia_slidingwindow(query, attn_metadata,
|
||||
output)
|
||||
key, value, block_size, block_table, actual_seq_lengths_kv \
|
||||
= self._get_fia_params(key, value, attn_metadata)
|
||||
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
|
||||
query = query[:num_tokens]
|
||||
# Prepare tensors for attention output
|
||||
# TODO: Refactor this to step-level instead of layer-level
|
||||
|
||||
# Get workspace from cache or calculate it if not present.
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query=query,
|
||||
@@ -604,83 +627,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
|
||||
def _forward_decode_only_ascend91095(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size = attn_metadata.query_lens.shape[0]
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
block_table=attn_metadata.block_tables,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
)
|
||||
output[:batch_size] = attn_output[:batch_size]
|
||||
return output
|
||||
|
||||
def _forward_decode_only(
|
||||
def forward_paged_attention(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if get_ascend_device_type() == AscendDeviceType.A5:
|
||||
return self._forward_decode_only_ascend91095(
|
||||
query, attn_metadata, output)
|
||||
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
|
||||
0] == query.size(0):
|
||||
batch_size = attn_metadata.seq_lens.shape[0]
|
||||
block_size = 128
|
||||
query = query.view(batch_size, 1, self.num_heads * self.head_size)
|
||||
key = self.key_cache
|
||||
value = self.value_cache
|
||||
if self.key_cache is not None and self.value_cache is not None:
|
||||
block_size = self.key_cache.shape[1]
|
||||
key = self.key_cache.flatten(2, 3).contiguous()
|
||||
value = self.value_cache.flatten(2, 3).contiguous()
|
||||
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="BSH",
|
||||
block_size=block_size,
|
||||
pre_tokens=self.sliding_window,
|
||||
scale=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
||||
|
||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||
else:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if forward_context.capturing:
|
||||
return self.full_graph_pa(query, attn_metadata, output)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def _forward_encoder_attention(self, query: torch.Tensor,
|
||||
@@ -757,23 +721,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if not forward_context.capturing:
|
||||
if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
output = self._forward_decode_only(query, attn_metadata,
|
||||
output)
|
||||
else:
|
||||
output = self._forward_prefill(query, key, value,
|
||||
attn_metadata, output)
|
||||
num_tokens = query.shape[0]
|
||||
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
and using_paged_attention(num_tokens)
|
||||
and self.sliding_window is None):
|
||||
output = self.forward_paged_attention(query, attn_metadata, output)
|
||||
else:
|
||||
num_tokens = query.shape[0]
|
||||
if using_paged_attention(num_tokens):
|
||||
output = self.full_graph_attention_with_pa(
|
||||
query, attn_metadata, output)
|
||||
else:
|
||||
attn_output, num_tokens = self.full_graph_attention(
|
||||
query, key, value, attn_metadata, output)
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
output = self.forward_fused_infer_attention(
|
||||
query, key, value, attn_metadata, output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -10,7 +10,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
from vllm_ascend.utils import get_ascend_config
|
||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
|
||||
get_ascend_device_type)
|
||||
|
||||
|
||||
@lru_cache
|
||||
@@ -18,8 +19,11 @@ def using_paged_attention(runtime_shape: int) -> bool:
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.speculative_config is not None:
|
||||
return False
|
||||
if get_ascend_device_type() == AscendDeviceType.A5:
|
||||
return False
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
if vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
|
||||
if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
return False
|
||||
|
||||
return runtime_shape in get_ascend_config().pa_shape_list
|
||||
|
||||
Reference in New Issue
Block a user