[Refactor] 4/N Distinguish the branches based on the applicable scenarios of PA and FIA Ops. (#5081)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
We distinguish the branches based on the applicable scenarios of
pagedAttention and fusedInferAttention, making the code more clear.
At the same time, it is convenient for the subsequent iterations of
sliding_window and sinks and removePA ops after FIA is ready.
Todo:
remove PA ops after FIA is ready
add slidingwindow and ops for gpt_oss
replace FIA with FIA_v2
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -236,9 +236,9 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
def test_forward_prefill(self, mock_get_forward_context,
|
def test_forward_fused_infer_attention(
|
||||||
mock_npu_fused_infer_attention_score,
|
self, mock_get_forward_context,
|
||||||
mock_npu_reshape_and_cache):
|
mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):
|
||||||
"""Test forward pass in PrefillCacheHit state"""
|
"""Test forward pass in PrefillCacheHit state"""
|
||||||
query = torch.randn(10, 8, 64)
|
query = torch.randn(10, 8, 64)
|
||||||
key = torch.randn(10, 8, 64)
|
key = torch.randn(10, 8, 64)
|
||||||
@@ -268,28 +268,31 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
mock_npu_fused_infer_attention_score.assert_called_once()
|
mock_npu_fused_infer_attention_score.assert_called_once()
|
||||||
assert output.shape == (10, 8, 64)
|
assert output.shape == (10, 8, 64)
|
||||||
|
|
||||||
|
@patch('vllm_ascend.attention.attention_v1.using_paged_attention')
|
||||||
@patch('torch_npu._npu_paged_attention')
|
@patch('torch_npu._npu_paged_attention')
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
def test_forward_decode_only(self, mock_get_forward_context,
|
def test_forward_paged_attention(self, mock_get_forward_context,
|
||||||
mock_npu_reshape_and_cache,
|
mock_npu_reshape_and_cache,
|
||||||
mock_paged_attention):
|
mock_paged_attention,
|
||||||
|
mock_using_paged_attention):
|
||||||
"""Test forward pass in DecodeOnly state"""
|
"""Test forward pass in DecodeOnly state"""
|
||||||
query = torch.randn(10, 8 * 64)
|
query = torch.randn(4, 8 * 64)
|
||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(4, 8 * 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(4, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||||
metadata.seq_lens = torch.tensor([10])
|
metadata.seq_lens = torch.tensor([4])
|
||||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||||
metadata.num_actual_tokens = 10
|
metadata.num_actual_tokens = 4
|
||||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
metadata.slot_mapping = torch.zeros(4, dtype=torch.long)
|
||||||
metadata.num_decodes = 10
|
metadata.num_decodes = 4
|
||||||
metadata.num_prefills = 0
|
metadata.num_prefills = 0
|
||||||
layer = self.layer_no_quant
|
layer = self.layer_no_quant
|
||||||
|
mock_using_paged_attention.return_value = True
|
||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
|
||||||
@@ -297,7 +300,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
metadata, output)
|
metadata, output)
|
||||||
|
|
||||||
mock_paged_attention.assert_called_once()
|
mock_paged_attention.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (4, 8 * 64)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||||
@@ -339,9 +342,9 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
|
self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
|
||||||
mock_paged_attention, mock_get_forward_context):
|
mock_paged_attention, mock_get_forward_context):
|
||||||
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
|
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
|
||||||
query = torch.randn(10, 8 * 64)
|
query = torch.randn(10, 8, 64)
|
||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8, 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8, 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
@@ -354,6 +357,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
layer = self.layer_no_quant
|
layer = self.layer_no_quant
|
||||||
metadata.num_decodes = 10
|
metadata.num_decodes = 10
|
||||||
metadata.num_prefills = 0
|
metadata.num_prefills = 0
|
||||||
|
metadata.actual_seq_lengths_q = [10]
|
||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
|
||||||
@@ -363,10 +367,10 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
|
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
|
||||||
metadata, output)
|
metadata, output)
|
||||||
|
|
||||||
mock_paged_attention.assert_called_once()
|
mock_paged_attention.assert_not_called()
|
||||||
mock_fused_infer_attention_score.assert_not_called()
|
mock_fused_infer_attention_score.assert_called_once()
|
||||||
|
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8, 64)
|
||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
def test_forward_raise_error(self, mock_paged_attention):
|
def test_forward_raise_error(self, mock_paged_attention):
|
||||||
|
|||||||
@@ -386,41 +386,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
self.key_cache = None
|
self.key_cache = None
|
||||||
self.value_cache = None
|
self.value_cache = None
|
||||||
|
|
||||||
def full_graph_attention(self, query: torch.Tensor, key: torch.Tensor,
|
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor, attn_metadata: AscendMetadata,
|
||||||
attn_metadata: AscendMetadata,
|
output: torch.Tensor) -> torch.Tensor:
|
||||||
output: torch.Tensor) -> torch.Tensor:
|
key, value, block_size, block_table, actual_seq_lengths_kv \
|
||||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
= self._get_fia_params(key, value, attn_metadata)
|
||||||
block_size = 128
|
|
||||||
block_table = None
|
|
||||||
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
|
|
||||||
elif attn_metadata.attn_state == \
|
|
||||||
AscendAttentionState.PrefillCacheHit:
|
|
||||||
batch_size = attn_metadata.query_lens.shape[0]
|
|
||||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
|
||||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
|
||||||
key = self.key_cache.view( # type: ignore
|
|
||||||
num_block, block_size, -1)
|
|
||||||
value = self.value_cache.view( # type: ignore
|
|
||||||
num_block, block_size, -1)
|
|
||||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
|
||||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
|
||||||
key = self.key_cache.view( # type: ignore
|
|
||||||
num_block, block_size, -1)
|
|
||||||
value = self.value_cache.view( # type: ignore
|
|
||||||
num_block, block_size, -1)
|
|
||||||
block_table = attn_metadata.block_tables
|
|
||||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
|
||||||
# Normal V1 situation.
|
|
||||||
else:
|
|
||||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
|
||||||
key = self.key_cache.view( # type: ignore
|
|
||||||
num_block, block_size, -1)
|
|
||||||
value = self.value_cache.view( # type: ignore
|
|
||||||
num_block, block_size, -1)
|
|
||||||
block_table = attn_metadata.block_tables
|
|
||||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
|
||||||
|
|
||||||
num_tokens = attn_metadata.query_start_loc_list[-1]
|
num_tokens = attn_metadata.query_start_loc_list[-1]
|
||||||
graph_params = get_graph_params()
|
graph_params = get_graph_params()
|
||||||
@@ -489,7 +459,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
graph_params.handles[num_tokens].append(handle)
|
graph_params.handles[num_tokens].append(handle)
|
||||||
return output, num_tokens
|
return output, num_tokens
|
||||||
|
|
||||||
def full_graph_attention_with_pa(
|
def full_graph_pa(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
attn_metadata: AscendMetadata,
|
attn_metadata: AscendMetadata,
|
||||||
@@ -550,13 +520,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
graph_params.handles[num_tokens].append(handle)
|
graph_params.handles[num_tokens].append(handle)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
|
def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor,
|
||||||
value: torch.Tensor, attn_metadata: AscendMetadata,
|
attn_metadata: AscendMetadata):
|
||||||
output: torch.Tensor):
|
|
||||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
block_size = 128
|
block_size = 128
|
||||||
block_table = None
|
block_table = None
|
||||||
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
|
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
|
||||||
elif attn_metadata.attn_state == \
|
elif attn_metadata.attn_state == \
|
||||||
AscendAttentionState.PrefillCacheHit:
|
AscendAttentionState.PrefillCacheHit:
|
||||||
batch_size = attn_metadata.query_lens.shape[0]
|
batch_size = attn_metadata.query_lens.shape[0]
|
||||||
@@ -567,7 +537,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
value = self.value_cache.view( # type: ignore
|
value = self.value_cache.view( # type: ignore
|
||||||
num_block, block_size, -1)
|
num_block, block_size, -1)
|
||||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||||
# chunked_prefill.
|
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
|
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||||
|
key = self.key_cache.view( # type: ignore
|
||||||
|
num_block, block_size, -1)
|
||||||
|
value = self.value_cache.view( # type: ignore
|
||||||
|
num_block, block_size, -1)
|
||||||
|
block_table = attn_metadata.block_tables
|
||||||
|
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||||
|
# chunked prefill.
|
||||||
else:
|
else:
|
||||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||||
key = self.key_cache.view( # type: ignore
|
key = self.key_cache.view( # type: ignore
|
||||||
@@ -576,12 +554,57 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
num_block, block_size, -1)
|
num_block, block_size, -1)
|
||||||
block_table = attn_metadata.block_tables
|
block_table = attn_metadata.block_tables
|
||||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||||
|
return key, value, block_size, block_table, actual_seq_lengths_kv
|
||||||
|
|
||||||
|
def _forward_fia_slidingwindow(self, query: torch.Tensor,
|
||||||
|
attn_metadata: AscendMetadata,
|
||||||
|
output: torch.Tensor):
|
||||||
|
batch_size = attn_metadata.seq_lens.shape[0]
|
||||||
|
block_size = 128
|
||||||
|
query = query.view(batch_size, 1, self.num_heads * self.head_size)
|
||||||
|
key = self.key_cache
|
||||||
|
value = self.value_cache
|
||||||
|
if self.key_cache is not None and self.value_cache is not None:
|
||||||
|
block_size = self.key_cache.shape[1]
|
||||||
|
key = self.key_cache.flatten(2, 3).contiguous()
|
||||||
|
value = self.value_cache.flatten(2, 3).contiguous()
|
||||||
|
|
||||||
|
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_key_value_heads=self.num_kv_heads,
|
||||||
|
input_layout="BSH",
|
||||||
|
block_size=block_size,
|
||||||
|
pre_tokens=self.sliding_window,
|
||||||
|
scale=self.scale,
|
||||||
|
block_table=attn_metadata.block_tables,
|
||||||
|
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
|
||||||
|
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
||||||
|
|
||||||
|
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward_fused_infer_attention(self, query: torch.Tensor,
|
||||||
|
key: torch.Tensor, value: torch.Tensor,
|
||||||
|
attn_metadata: AscendMetadata,
|
||||||
|
output: torch.Tensor):
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
if forward_context.capturing:
|
||||||
|
attn_output, num_tokens = self.full_graph_fia(
|
||||||
|
query, key, value, attn_metadata, output)
|
||||||
|
output[:num_tokens] = attn_output[:num_tokens]
|
||||||
|
return output
|
||||||
|
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||||
|
and self.sliding_window is not None
|
||||||
|
and attn_metadata.seq_lens.shape[0] == query.size(0)):
|
||||||
|
return self._forward_fia_slidingwindow(query, attn_metadata,
|
||||||
|
output)
|
||||||
|
key, value, block_size, block_table, actual_seq_lengths_kv \
|
||||||
|
= self._get_fia_params(key, value, attn_metadata)
|
||||||
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
|
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
|
||||||
query = query[:num_tokens]
|
query = query[:num_tokens]
|
||||||
# Prepare tensors for attention output
|
|
||||||
# TODO: Refactor this to step-level instead of layer-level
|
|
||||||
|
|
||||||
# Get workspace from cache or calculate it if not present.
|
# Get workspace from cache or calculate it if not present.
|
||||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||||
query=query,
|
query=query,
|
||||||
@@ -604,83 +627,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
output[:num_tokens] = attn_output[:num_tokens]
|
output[:num_tokens] = attn_output[:num_tokens]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _forward_decode_only_ascend91095(
|
def forward_paged_attention(
|
||||||
self,
|
|
||||||
query: torch.Tensor,
|
|
||||||
attn_metadata: AscendMetadata,
|
|
||||||
output: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
batch_size = attn_metadata.query_lens.shape[0]
|
|
||||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
|
||||||
key = self.key_cache.view( # type: ignore
|
|
||||||
num_block, block_size, -1)
|
|
||||||
value = self.value_cache.view( # type: ignore
|
|
||||||
num_block, block_size, -1)
|
|
||||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
|
||||||
|
|
||||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
|
||||||
query=query,
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
block_table=attn_metadata.block_tables,
|
|
||||||
input_layout="TND",
|
|
||||||
block_size=block_size,
|
|
||||||
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
|
|
||||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
|
||||||
num_key_value_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale=self.scale,
|
|
||||||
)
|
|
||||||
output[:batch_size] = attn_output[:batch_size]
|
|
||||||
return output
|
|
||||||
|
|
||||||
def _forward_decode_only(
|
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
attn_metadata: AscendMetadata,
|
attn_metadata: AscendMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if get_ascend_device_type() == AscendDeviceType.A5:
|
forward_context: ForwardContext = get_forward_context()
|
||||||
return self._forward_decode_only_ascend91095(
|
if forward_context.capturing:
|
||||||
query, attn_metadata, output)
|
return self.full_graph_pa(query, attn_metadata, output)
|
||||||
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
|
torch_npu._npu_paged_attention(query=query,
|
||||||
0] == query.size(0):
|
key_cache=self.key_cache,
|
||||||
batch_size = attn_metadata.seq_lens.shape[0]
|
value_cache=self.value_cache,
|
||||||
block_size = 128
|
num_kv_heads=self.num_kv_heads,
|
||||||
query = query.view(batch_size, 1, self.num_heads * self.head_size)
|
num_heads=self.num_heads,
|
||||||
key = self.key_cache
|
scale_value=self.scale,
|
||||||
value = self.value_cache
|
block_table=attn_metadata.block_tables,
|
||||||
if self.key_cache is not None and self.value_cache is not None:
|
context_lens=attn_metadata.seq_lens,
|
||||||
block_size = self.key_cache.shape[1]
|
out=output)
|
||||||
key = self.key_cache.flatten(2, 3).contiguous()
|
|
||||||
value = self.value_cache.flatten(2, 3).contiguous()
|
|
||||||
|
|
||||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
num_key_value_heads=self.num_kv_heads,
|
|
||||||
input_layout="BSH",
|
|
||||||
block_size=block_size,
|
|
||||||
pre_tokens=self.sliding_window,
|
|
||||||
scale=self.scale,
|
|
||||||
block_table=attn_metadata.block_tables,
|
|
||||||
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
|
|
||||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
|
||||||
|
|
||||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
|
||||||
else:
|
|
||||||
torch_npu._npu_paged_attention(
|
|
||||||
query=query,
|
|
||||||
key_cache=self.key_cache,
|
|
||||||
value_cache=self.value_cache,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale_value=self.scale,
|
|
||||||
block_table=attn_metadata.block_tables,
|
|
||||||
context_lens=attn_metadata.seq_lens,
|
|
||||||
out=output)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _forward_encoder_attention(self, query: torch.Tensor,
|
def _forward_encoder_attention(self, query: torch.Tensor,
|
||||||
@@ -757,23 +721,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
attn_metadata: AscendMetadata,
|
attn_metadata: AscendMetadata,
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
):
|
):
|
||||||
forward_context: ForwardContext = get_forward_context()
|
num_tokens = query.shape[0]
|
||||||
if not forward_context.capturing:
|
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||||
if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
and using_paged_attention(num_tokens)
|
||||||
output = self._forward_decode_only(query, attn_metadata,
|
and self.sliding_window is None):
|
||||||
output)
|
output = self.forward_paged_attention(query, attn_metadata, output)
|
||||||
else:
|
|
||||||
output = self._forward_prefill(query, key, value,
|
|
||||||
attn_metadata, output)
|
|
||||||
else:
|
else:
|
||||||
num_tokens = query.shape[0]
|
output = self.forward_fused_infer_attention(
|
||||||
if using_paged_attention(num_tokens):
|
query, key, value, attn_metadata, output)
|
||||||
output = self.full_graph_attention_with_pa(
|
|
||||||
query, attn_metadata, output)
|
|
||||||
else:
|
|
||||||
attn_output, num_tokens = self.full_graph_attention(
|
|
||||||
query, key, value, attn_metadata, output)
|
|
||||||
output[:num_tokens] = attn_output[:num_tokens]
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
|||||||
is_v1_kv_transfer_group)
|
is_v1_kv_transfer_group)
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
|
|
||||||
from vllm_ascend.utils import get_ascend_config
|
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
|
||||||
|
get_ascend_device_type)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
@@ -18,8 +19,11 @@ def using_paged_attention(runtime_shape: int) -> bool:
|
|||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
if vllm_config.speculative_config is not None:
|
if vllm_config.speculative_config is not None:
|
||||||
return False
|
return False
|
||||||
|
if get_ascend_device_type() == AscendDeviceType.A5:
|
||||||
|
return False
|
||||||
from vllm.config.compilation import CUDAGraphMode
|
from vllm.config.compilation import CUDAGraphMode
|
||||||
if vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
|
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
|
||||||
|
if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return runtime_shape in get_ascend_config().pa_shape_list
|
return runtime_shape in get_ascend_config().pa_shape_list
|
||||||
|
|||||||
Reference in New Issue
Block a user