[Refactor] Remove redundant attention operator branches. (#4531)
[Refactor] Remove redundant attention operator branches. Reason: We replace other attention ops with fused_infer_attention_score expect decode_only state. clean code and remove 310P support. https://github.com/vllm-project/vllm-ascend/pull/4455 - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -25,12 +25,6 @@ class TestAscendAttentionBackend(TestBase):
|
|||||||
self.assertEqual(AscendAttentionBackend.get_builder_cls(),
|
self.assertEqual(AscendAttentionBackend.get_builder_cls(),
|
||||||
AscendAttentionMetadataBuilder)
|
AscendAttentionMetadataBuilder)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_ascend_device_type',
|
|
||||||
return_value=AscendDeviceType._310P)
|
|
||||||
def test_get_kv_cache_shape_310p(self, mock_soc_version):
|
|
||||||
result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
|
|
||||||
self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
|
|
||||||
|
|
||||||
@patch('vllm_ascend.utils.get_ascend_device_type',
|
@patch('vllm_ascend.utils.get_ascend_device_type',
|
||||||
return_value=AscendDeviceType._910_93)
|
return_value=AscendDeviceType._910_93)
|
||||||
def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
|
def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
|
||||||
@@ -95,76 +89,6 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
|||||||
|
|
||||||
self.assertFalse(result)
|
self.assertFalse(result)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
|
||||||
@patch('torch_npu.npu_format_cast')
|
|
||||||
@patch('vllm_ascend.utils.nd_to_nz_2d')
|
|
||||||
@patch('vllm_ascend.utils.get_ascend_device_type',
|
|
||||||
return_value=AscendDeviceType._310P)
|
|
||||||
def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d,
|
|
||||||
mock_npu_format_cast,
|
|
||||||
mock_ascend_metadata):
|
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
|
||||||
query_start_loc=torch.tensor([0, 3, 7]),
|
|
||||||
query_start_loc_cpu=torch.tensor([0, 3, 7]),
|
|
||||||
seq_lens_cpu=torch.tensor([5, 6]),
|
|
||||||
num_reqs=2,
|
|
||||||
num_actual_tokens=10,
|
|
||||||
max_query_len=5,
|
|
||||||
decode_token_per_req=torch.tensor([1, 1]),
|
|
||||||
block_table_tensor=torch.zeros((10, 10)),
|
|
||||||
slot_mapping=torch.tensor(range(20)),
|
|
||||||
actual_seq_lengths_q=torch.tensor([0, 1]),
|
|
||||||
positions=torch.tensor([10, 10]),
|
|
||||||
attn_mask=torch.ones((10, 10)),
|
|
||||||
spec_attn_mask=None,
|
|
||||||
attn_state=AscendAttentionState.PrefillNoCache,
|
|
||||||
num_computed_tokens_cpu=None,
|
|
||||||
seq_lens=None)
|
|
||||||
|
|
||||||
mock_nz_tensor = MagicMock()
|
|
||||||
mock_model = MagicMock()
|
|
||||||
mock_nd_to_nz_2d.return_value = mock_nz_tensor
|
|
||||||
mock_npu_format_cast.return_value = mock_nz_tensor
|
|
||||||
|
|
||||||
self.builder.build(1, common_attn_metadata, mock_model)
|
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
|
||||||
@patch('torch_npu.npu_format_cast')
|
|
||||||
@patch('vllm_ascend.utils.nd_to_nz_spec')
|
|
||||||
@patch('vllm_ascend.utils.get_ascend_device_type',
|
|
||||||
return_value=AscendDeviceType._310P)
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
|
|
||||||
def test_build_chunked_prefill(self, mock_ascend_attention_state,
|
|
||||||
mock_soc_version, mock_nd_to_nz_spec,
|
|
||||||
mock_npu_format_cast, mock_ascend_metadata):
|
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
|
||||||
query_start_loc=torch.tensor([0, 2, 5, 9]),
|
|
||||||
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
|
|
||||||
seq_lens_cpu=torch.tensor([4, 5, 6]),
|
|
||||||
num_reqs=3,
|
|
||||||
num_actual_tokens=15,
|
|
||||||
max_query_len=6,
|
|
||||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
|
||||||
block_table_tensor=torch.zeros((10, 10)),
|
|
||||||
slot_mapping=torch.tensor(range(20)),
|
|
||||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
|
||||||
positions=torch.tensor([10, 10]),
|
|
||||||
attn_mask=torch.ones((15, 15)),
|
|
||||||
spec_attn_mask=None,
|
|
||||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
|
||||||
num_computed_tokens_cpu=None,
|
|
||||||
seq_lens=None)
|
|
||||||
|
|
||||||
mock_ascend_attention_state = MagicMock()
|
|
||||||
mock_ascend_attention_state.PrefillNoCache = 0
|
|
||||||
|
|
||||||
mock_nz_tensor = MagicMock()
|
|
||||||
mock_model = MagicMock()
|
|
||||||
mock_nd_to_nz_spec.return_value = mock_nz_tensor
|
|
||||||
mock_npu_format_cast.return_value = mock_nz_tensor
|
|
||||||
|
|
||||||
self.builder.build(1, common_attn_metadata, mock_model)
|
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
||||||
@patch('vllm_ascend.utils.get_ascend_device_type',
|
@patch('vllm_ascend.utils.get_ascend_device_type',
|
||||||
return_value=AscendDeviceType._910_93)
|
return_value=AscendDeviceType._910_93)
|
||||||
@@ -286,73 +210,40 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
|
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
|
||||||
@patch('torch_npu._npu_flash_attention')
|
|
||||||
def test_forward_prefill_no_cache(self, mock_flash_attention,
|
|
||||||
mock_reshape_cache,
|
|
||||||
mock_get_forward_context):
|
|
||||||
"""Test forward pass in PrefillNoCache state"""
|
|
||||||
query = torch.randn(10, 8 * 64)
|
|
||||||
key = torch.randn(10, 8 * 64)
|
|
||||||
value = torch.randn(10, 8 * 64)
|
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
|
||||||
output = torch.empty_like(query)
|
|
||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
|
||||||
metadata.attn_state = AscendAttentionState.PrefillNoCache
|
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
|
||||||
metadata.seq_lens = torch.tensor([10])
|
|
||||||
metadata.num_actual_tokens = 10
|
|
||||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
|
||||||
metadata.num_decodes = 0
|
|
||||||
metadata.num_prefills = 10
|
|
||||||
layer = self.layer_no_quant
|
|
||||||
|
|
||||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
|
||||||
metadata, output)
|
|
||||||
|
|
||||||
mock_reshape_cache.assert_called_once()
|
|
||||||
mock_flash_attention.assert_called_once()
|
|
||||||
assert output.shape == (10, 8 * 64)
|
|
||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
def test_forward_prefill_cache_hit(self, mock_get_forward_context,
|
def test_forward_prefill(self, mock_get_forward_context,
|
||||||
mock_npu_fused_infer_attention_score,
|
mock_npu_fused_infer_attention_score,
|
||||||
mock_npu_reshape_and_cache):
|
mock_npu_reshape_and_cache):
|
||||||
"""Test forward pass in PrefillCacheHit state"""
|
"""Test forward pass in PrefillCacheHit state"""
|
||||||
query = torch.randn(10, 8 * 64)
|
query = torch.randn(10, 8, 64)
|
||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8, 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8, 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
metadata.query_lens = torch.tensor([10])
|
metadata.query_lens = torch.tensor([10])
|
||||||
metadata.seq_lens = torch.tensor([10])
|
metadata.seq_lens = torch.tensor([10])
|
||||||
|
metadata.actual_seq_lengths_q = [10]
|
||||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||||
metadata.num_actual_tokens = 10
|
metadata.num_actual_tokens = 10
|
||||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
metadata.num_decode_tokens = 0
|
||||||
metadata.num_decodes = 0
|
metadata.num_decodes = 0
|
||||||
metadata.num_prefills = 10
|
metadata.num_prefills = 10
|
||||||
|
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||||
layer = self.layer_no_quant
|
layer = self.layer_no_quant
|
||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
mock_npu_fused_infer_attention_score.return_value = (output,
|
mock_npu_fused_infer_attention_score.return_value = (torch.ones(
|
||||||
torch.ones(
|
10, 8, 64), torch.ones(10, 8, 64))
|
||||||
10, 8, 64))
|
|
||||||
|
|
||||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
metadata, output)
|
metadata, output)
|
||||||
|
|
||||||
mock_npu_fused_infer_attention_score.assert_called_once()
|
mock_npu_fused_infer_attention_score.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8, 64)
|
||||||
|
|
||||||
@patch('torch_npu._npu_paged_attention')
|
@patch('torch_npu._npu_paged_attention')
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@@ -454,119 +345,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
|
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
|
||||||
@patch('vllm_ascend.utils.get_ascend_device_type',
|
|
||||||
return_value=AscendDeviceType._910_93)
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
|
|
||||||
def test_forward_head_size_192(self, mock_vanilla_prefill,
|
|
||||||
mock_npu_reshape_and_cache,
|
|
||||||
mock_soc_version, mock_get_forward_context):
|
|
||||||
"""Test forward pass when head_size is 192"""
|
|
||||||
|
|
||||||
self.impl.head_size = 192
|
|
||||||
query = torch.randn(10, 8 * 192)
|
|
||||||
key = torch.randn(10, 8 * 192)
|
|
||||||
value = torch.randn(10, 8 * 192)
|
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 192)
|
|
||||||
output = torch.empty_like(query)
|
|
||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
|
||||||
metadata.query_lens = torch.tensor([10])
|
|
||||||
metadata.seq_lens = torch.tensor([10])
|
|
||||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
|
||||||
metadata.num_actual_tokens = 10
|
|
||||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
|
||||||
metadata.num_decodes = 10
|
|
||||||
metadata.num_prefills = 0
|
|
||||||
layer = self.layer_no_quant
|
|
||||||
mock_vanilla_prefill.return_value = MagicMock()
|
|
||||||
|
|
||||||
output = self.impl_192.forward(layer, query, key, value, kv_cache,
|
|
||||||
metadata, output)
|
|
||||||
|
|
||||||
mock_vanilla_prefill.assert_called_once()
|
|
||||||
assert output.shape == (10, 8 * 192)
|
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
|
||||||
def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
|
|
||||||
mock_npu_fused_infer_attention_score,
|
|
||||||
mock_get_forward_context):
|
|
||||||
"""Test forward pass in normal V1 situation"""
|
|
||||||
query = torch.randn(10, 8 * 64)
|
|
||||||
key = torch.randn(10, 8 * 64)
|
|
||||||
value = torch.randn(10, 8 * 64)
|
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
|
||||||
output = torch.empty_like(query)
|
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
|
||||||
metadata.query_lens = torch.tensor([10])
|
|
||||||
metadata.seq_lens = torch.tensor([10])
|
|
||||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
|
||||||
metadata.num_actual_tokens = 10
|
|
||||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
|
||||||
metadata.num_decodes = 0
|
|
||||||
metadata.num_prefills = 10
|
|
||||||
layer = self.layer_no_quant
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
|
||||||
mock_npu_fused_infer_attention_score.return_value = (output,
|
|
||||||
torch.ones(
|
|
||||||
10, 8, 64))
|
|
||||||
|
|
||||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
|
||||||
metadata, output)
|
|
||||||
|
|
||||||
mock_npu_fused_infer_attention_score.assert_called_once()
|
|
||||||
assert output.shape == (10, 8 * 64)
|
|
||||||
|
|
||||||
@patch('torch_npu.npu_format_cast')
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
|
||||||
@patch('vllm_ascend.utils.get_ascend_device_type',
|
|
||||||
return_value=AscendDeviceType._310P)
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
|
||||||
def test_forward_310p_device(self, mock_get_forward_context,
|
|
||||||
mock_soc_version,
|
|
||||||
mock_npu_fused_infer_attention_score,
|
|
||||||
mock_npu_reshape_and_cache,
|
|
||||||
mock_npu_format_cast):
|
|
||||||
"""Test forward pass on 310P device"""
|
|
||||||
query = torch.randn(10, 8 * 64)
|
|
||||||
key = torch.randn(10, 8 * 64)
|
|
||||||
value = torch.randn(10, 8 * 64)
|
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
|
||||||
output = torch.empty_like(query)
|
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
|
||||||
metadata.query_lens = torch.tensor([10])
|
|
||||||
metadata.seq_lens = torch.tensor([10])
|
|
||||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
|
||||||
metadata.num_actual_tokens = 10
|
|
||||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
|
||||||
metadata.num_decodes = 0
|
|
||||||
metadata.num_prefills = 10
|
|
||||||
layer = self.layer_no_quant
|
|
||||||
|
|
||||||
mock_npu_format_cast.return_value = metadata.attn_mask
|
|
||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
|
||||||
mock_npu_fused_infer_attention_score.return_value = (output,
|
|
||||||
torch.ones(
|
|
||||||
10, 8, 64))
|
|
||||||
|
|
||||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
|
||||||
metadata, output)
|
|
||||||
|
|
||||||
mock_npu_fused_infer_attention_score.assert_called_once()
|
|
||||||
assert output.shape == (10, 8 * 64)
|
|
||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
def test_forward_raise_error(self, mock_paged_attention):
|
def test_forward_raise_error(self, mock_paged_attention):
|
||||||
query = torch.randn(10, 8 * 64)
|
query = torch.randn(10, 8 * 64)
|
||||||
|
|||||||
@@ -41,11 +41,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
|||||||
split_decodes_and_prefills)
|
split_decodes_and_prefills)
|
||||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||||
update_graph_params_workspaces)
|
update_graph_params_workspaces)
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
from vllm_ascend.utils import prefill_context_parallel_enable, weak_ref_tensors
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
|
||||||
aligned_16, get_ascend_device_type, nd_to_nz_2d,
|
|
||||||
nd_to_nz_spec, prefill_context_parallel_enable,
|
|
||||||
weak_ref_tensors)
|
|
||||||
|
|
||||||
# isort: off
|
# isort: off
|
||||||
if prefill_context_parallel_enable():
|
if prefill_context_parallel_enable():
|
||||||
@@ -83,9 +79,6 @@ class AscendAttentionBackend(AttentionBackend):
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
) -> Tuple[int, ...]:
|
) -> Tuple[int, ...]:
|
||||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
|
||||||
return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
|
|
||||||
16)
|
|
||||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -351,16 +344,6 @@ class AscendAttentionMetadataBuilder:
|
|||||||
query_start_loc = query_start_loc_cpu.to(self.device,
|
query_start_loc = query_start_loc_cpu.to(self.device,
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
|
||||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
|
||||||
if attn_state == AscendAttentionState.PrefillNoCache:
|
|
||||||
mask_nz = nd_to_nz_2d(attn_mask)
|
|
||||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
elif attn_state == AscendAttentionState.ChunkedPrefill:
|
|
||||||
mask_nz = nd_to_nz_spec(attn_mask)
|
|
||||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
|
|
||||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||||
prefill_metadata = None
|
prefill_metadata = None
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
@@ -585,9 +568,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
num_tokens=0):
|
num_tokens=0):
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
intermediate_output = self._forward_pcp_dcp(
|
attn_output = self._forward_pcp_dcp(query, key, value, kv_cache,
|
||||||
query, key, value, kv_cache, attn_metadata, output)
|
attn_metadata, output)
|
||||||
return intermediate_output, query.shape[0]
|
return attn_output, query.shape[0]
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
block_size = 128
|
block_size = 128
|
||||||
block_table = None
|
block_table = None
|
||||||
@@ -688,93 +671,58 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
graph_params.handles[num_tokens].append(handle)
|
graph_params.handles[num_tokens].append(handle)
|
||||||
return output, num_tokens
|
return output, num_tokens
|
||||||
|
|
||||||
def _forward_prefill_no_cache(
|
def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
|
||||||
self,
|
value: torch.Tensor, attn_metadata: AscendMetadata,
|
||||||
query: torch.Tensor,
|
output: torch.Tensor):
|
||||||
key: torch.Tensor,
|
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
value: torch.Tensor,
|
block_size = 128
|
||||||
attn_metadata: AscendMetadata,
|
block_table = None
|
||||||
output: Optional[torch.Tensor] = None,
|
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
|
||||||
num_tokens=0,
|
elif attn_metadata.attn_state == \
|
||||||
) -> torch.Tensor:
|
AscendAttentionState.PrefillCacheHit:
|
||||||
assert attn_metadata is not None
|
batch_size = attn_metadata.query_lens.shape[0]
|
||||||
assert attn_metadata.attn_mask is not None
|
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||||
|
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||||
mask = attn_metadata.attn_mask
|
|
||||||
|
|
||||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
|
||||||
# align q k v output tensors
|
|
||||||
query = aligned_16(query)
|
|
||||||
key = aligned_16(key)
|
|
||||||
value = aligned_16(value)
|
|
||||||
output = aligned_16(output)
|
|
||||||
# do reformat in case of broadcasted tensors
|
|
||||||
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
|
|
||||||
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
|
|
||||||
torch_npu._npu_flash_attention(query=query,
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
mask=mask,
|
|
||||||
seq_len=attn_metadata.seq_lens,
|
|
||||||
scale_value=self.scale,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
out=output)
|
|
||||||
assert output is not None
|
|
||||||
return output[:num_tokens]
|
|
||||||
|
|
||||||
def _forward_prefill_cache_hit(
|
|
||||||
self,
|
|
||||||
query: torch.Tensor,
|
|
||||||
attn_metadata: AscendMetadata,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert attn_metadata is not None
|
|
||||||
assert attn_metadata.attn_mask is not None
|
|
||||||
|
|
||||||
compress_mask = attn_metadata.attn_mask
|
|
||||||
batch_size = attn_metadata.query_lens.shape[0]
|
|
||||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
|
||||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
|
||||||
|
|
||||||
if block_size == 128:
|
|
||||||
# TODO:The npu_fused_infer_attention_score op is planned to
|
|
||||||
# be utilized in a wider range in upcoming versions.
|
|
||||||
key = self.key_cache.view( # type: ignore
|
key = self.key_cache.view( # type: ignore
|
||||||
num_block, block_size, -1)
|
num_block, block_size, -1)
|
||||||
value = self.value_cache.view( # type: ignore
|
value = self.value_cache.view( # type: ignore
|
||||||
num_block, block_size, -1)
|
num_block, block_size, -1)
|
||||||
|
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
# chunked_prefill.
|
||||||
query=query,
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
atten_mask=compress_mask,
|
|
||||||
block_table=block_table,
|
|
||||||
input_layout="TND",
|
|
||||||
block_size=block_size,
|
|
||||||
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
|
|
||||||
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
|
|
||||||
num_key_value_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale=self.scale,
|
|
||||||
sparse_mode=3,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
torch_npu._npu_flash_attention_qlens(
|
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||||
query=query,
|
key = self.key_cache.view( # type: ignore
|
||||||
key_cache=self.key_cache,
|
num_block, block_size, -1)
|
||||||
value_cache=self.value_cache,
|
value = self.value_cache.view( # type: ignore
|
||||||
block_table=block_table,
|
num_block, block_size, -1)
|
||||||
mask=compress_mask,
|
block_table = attn_metadata.block_tables
|
||||||
seq_len=attn_metadata.query_lens,
|
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||||
context_lens=attn_metadata.seq_lens,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
|
||||||
num_heads=self.num_heads,
|
query = query[:num_tokens]
|
||||||
scale_value=self.scale,
|
# Prepare tensors for attention output
|
||||||
out=output)
|
# TODO: Refactor this to step-level instead of layer-level
|
||||||
|
|
||||||
|
# Get workspace from cache or calculate it if not present.
|
||||||
|
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
atten_mask=attn_metadata.attn_mask,
|
||||||
|
block_table=block_table,
|
||||||
|
input_layout="TND",
|
||||||
|
block_size=block_size,
|
||||||
|
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
|
||||||
|
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||||
|
num_key_value_heads=self.num_kv_heads,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
scale=self.scale,
|
||||||
|
sparse_mode=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(num_tokens, self.num_heads,
|
||||||
|
self.head_size)
|
||||||
|
output[:num_tokens] = attn_output[:num_tokens]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _forward_decode_only(
|
def _forward_decode_only(
|
||||||
@@ -783,10 +731,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
attn_metadata: AscendMetadata,
|
attn_metadata: AscendMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
|
||||||
# seq_lens_tensor needs to be transferred to the device for 310P.
|
|
||||||
attn_metadata.seq_lens = \
|
|
||||||
attn_metadata.seq_lens.to(device=query.device)
|
|
||||||
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
|
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
|
||||||
0] == query.size(0):
|
0] == query.size(0):
|
||||||
batch_size = attn_metadata.seq_lens.shape[0]
|
batch_size = attn_metadata.seq_lens.shape[0]
|
||||||
@@ -827,69 +771,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
out=output)
|
out=output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _forward_v1_style(
|
|
||||||
self,
|
|
||||||
query: torch.Tensor,
|
|
||||||
attn_metadata: AscendMetadata,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# Use chunked prefill for head size 192 scenario, like deepseek
|
|
||||||
# paged_attention_splitfuse maybe crash at such scenario.
|
|
||||||
# TODO: vanilla path will be removed after the kernel support
|
|
||||||
# head_size 192 scenario.
|
|
||||||
if self.head_size == 192:
|
|
||||||
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
|
|
||||||
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
|
|
||||||
cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
|
|
||||||
cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
|
|
||||||
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
|
|
||||||
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
|
|
||||||
max_seqlen_q = torch.max(attn_metadata.query_lens)
|
|
||||||
max_seqlen_k = torch.max(attn_metadata.seq_lens)
|
|
||||||
vanilla_chunked_prefill(output, query, self.key_cache,
|
|
||||||
self.value_cache,
|
|
||||||
attn_metadata.block_tables, cu_seqlen_q,
|
|
||||||
cu_seqlen_k, max_seqlen_q, max_seqlen_k,
|
|
||||||
self.scale, None, True)
|
|
||||||
return output
|
|
||||||
|
|
||||||
# Use paged attention.
|
|
||||||
assert attn_metadata is not None
|
|
||||||
assert attn_metadata.attn_mask is not None
|
|
||||||
|
|
||||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
|
||||||
# Do reformat in case of broadcasted tensors.
|
|
||||||
attn_metadata.attn_mask = \
|
|
||||||
torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
|
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
attn_metadata.seq_lens = \
|
|
||||||
attn_metadata.seq_lens.to(device=query.device)
|
|
||||||
|
|
||||||
# TODO:The npu_fused_infer_attention_score op is planned to
|
|
||||||
# be utilized in a wider range in upcoming versions.
|
|
||||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
|
||||||
key = self.key_cache.view( # type: ignore
|
|
||||||
num_block, block_size, -1)
|
|
||||||
value = self.value_cache.view( # type: ignore
|
|
||||||
num_block, block_size, -1)
|
|
||||||
|
|
||||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
|
||||||
query=query,
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
atten_mask=attn_metadata.attn_mask,
|
|
||||||
block_table=attn_metadata.block_tables,
|
|
||||||
input_layout="TND",
|
|
||||||
block_size=block_size,
|
|
||||||
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
|
|
||||||
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
|
|
||||||
num_key_value_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale=self.scale,
|
|
||||||
sparse_mode=3,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
|
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
|
||||||
q_seqlens: List[int],
|
q_seqlens: List[int],
|
||||||
k_nomask: torch.Tensor,
|
k_nomask: torch.Tensor,
|
||||||
@@ -1464,6 +1345,31 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
return key, value
|
return key, value
|
||||||
|
|
||||||
|
def _forward_encode(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attn_metadata: AscendMetadata,
|
||||||
|
output: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
||||||
|
output = torch_npu.npu_fusion_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
head_num=self.num_heads,
|
||||||
|
input_layout="TND",
|
||||||
|
scale=self.scale,
|
||||||
|
sparse_mode=4,
|
||||||
|
atten_mask=attn_metadata.attn_mask,
|
||||||
|
pre_tockens=attn_metadata.max_query_len,
|
||||||
|
next_tockens=attn_metadata.max_query_len,
|
||||||
|
actual_seq_qlen=cum_seq_len,
|
||||||
|
actual_seq_kvlen=cum_seq_len,
|
||||||
|
)[0]
|
||||||
|
return output
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
layer: AttentionLayer,
|
layer: AttentionLayer,
|
||||||
@@ -1494,24 +1400,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for AscendAttentionBackendImpl")
|
" for AscendAttentionBackendImpl")
|
||||||
|
|
||||||
num_tokens = query.shape[0]
|
|
||||||
if attn_metadata is None:
|
|
||||||
return output
|
|
||||||
|
|
||||||
# NOTE: Currently, we have various attention paths for different
|
|
||||||
# scenarios, and not all of them are in-place operations. Therefore,
|
|
||||||
# we need to create a separate tensor to hold the attention result.
|
|
||||||
# In the future, we may consolidate them into fewer paths, which will
|
|
||||||
# hopefully allow us to use in-place operation by default.
|
|
||||||
intermediate_output: torch.Tensor
|
|
||||||
|
|
||||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||||
attn_type = self.attn_type
|
if self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_ONLY:
|
||||||
if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
|
|
||||||
raise NotImplementedError("Encoder/decoder cross-attention "
|
raise NotImplementedError("Encoder/decoder cross-attention "
|
||||||
"are not implemented for "
|
"are not implemented for "
|
||||||
"PallasAttentionBackendImpl")
|
"PallasAttentionBackendImpl")
|
||||||
|
|
||||||
|
num_tokens = query.shape[0]
|
||||||
|
if attn_metadata is None:
|
||||||
|
return output.fill_(0)
|
||||||
|
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
has_decode = attn_metadata.num_decodes > 0
|
has_decode = attn_metadata.num_decodes > 0
|
||||||
has_prefill = attn_metadata.num_prefills > 0
|
has_prefill = attn_metadata.num_prefills > 0
|
||||||
@@ -1558,48 +1456,25 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
if not forward_context.capturing:
|
if not forward_context.capturing:
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
intermediate_output = self._forward_pcp_dcp(
|
attn_output = self._forward_pcp_dcp(query, key, value,
|
||||||
query, key, value, kv_cache, attn_metadata, output)
|
kv_cache, attn_metadata,
|
||||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
output)
|
||||||
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
|
output[:num_tokens] = attn_output[:num_tokens]
|
||||||
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
return output
|
||||||
intermediate_output = torch_npu.npu_fusion_attention(
|
if self.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
query,
|
attn_output = self._forward_encode(query, key, value,
|
||||||
key,
|
attn_metadata, output)
|
||||||
value,
|
output[:num_tokens] = attn_output[:num_tokens]
|
||||||
head_num=self.num_heads,
|
return output
|
||||||
input_layout="TND",
|
if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
scale=self.scale,
|
output = self._forward_decode_only(query, attn_metadata,
|
||||||
sparse_mode=4,
|
output)
|
||||||
atten_mask=attn_metadata.attn_mask,
|
|
||||||
pre_tockens=attn_metadata.max_query_len,
|
|
||||||
next_tockens=attn_metadata.max_query_len,
|
|
||||||
actual_seq_qlen=cum_seq_len,
|
|
||||||
actual_seq_kvlen=cum_seq_len,
|
|
||||||
)[0]
|
|
||||||
# V0-Style scheduler situation.
|
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
|
||||||
intermediate_output = self._forward_prefill_no_cache(
|
|
||||||
query, key, value, attn_metadata, output, num_tokens)
|
|
||||||
elif attn_metadata.attn_state == \
|
|
||||||
AscendAttentionState.PrefillCacheHit:
|
|
||||||
intermediate_output = self._forward_prefill_cache_hit(
|
|
||||||
query, attn_metadata, output)
|
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
|
||||||
intermediate_output = self._forward_decode_only(
|
|
||||||
query, attn_metadata, output)
|
|
||||||
# Normal V1 situation.
|
|
||||||
else:
|
else:
|
||||||
# npu_fused_infer_attention_score does not support cases
|
output = self._forward_prefill(query, key, value,
|
||||||
# where query.shape[0] != attn_metadata.query_start_loc[-1].
|
attn_metadata, output)
|
||||||
# Thus we need unpad it here.
|
|
||||||
num_tokens = attn_metadata.query_start_loc[-1]
|
|
||||||
query = query[:num_tokens]
|
|
||||||
intermediate_output = self._forward_v1_style(
|
|
||||||
query, attn_metadata, output)
|
|
||||||
else:
|
else:
|
||||||
intermediate_output, num_tokens = self.full_graph_attention(
|
attn_output, num_tokens = self.full_graph_attention(
|
||||||
query, key, value, kv_cache, attn_metadata, output)
|
query, key, value, kv_cache, attn_metadata, output)
|
||||||
output[:num_tokens] = intermediate_output[:num_tokens]
|
output[:num_tokens] = attn_output[:num_tokens]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -979,25 +979,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# dcp situation.
|
# dcp situation.
|
||||||
if self.dcp_size > 1:
|
if self.dcp_size > 1:
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||||
|
if self.vllm_config.model_config.use_mla:
|
||||||
|
return None
|
||||||
# Pooling situation.
|
# Pooling situation.
|
||||||
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
|
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
|
||||||
return self.attn_mask_builder.get_pooling_mask(self.device)
|
return self.attn_mask_builder.get_pooling_mask(self.device)
|
||||||
# Chunk Prefill situation.
|
# fia prefill situation.
|
||||||
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
|
if attn_state in [
|
||||||
|
AscendAttentionState.PrefillNoCache,
|
||||||
|
AscendAttentionState.PrefillCacheHit,
|
||||||
|
AscendAttentionState.ChunkedPrefill
|
||||||
|
]:
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||||
|
|
||||||
# Prefill without cache situation.
|
|
||||||
elif attn_state == AscendAttentionState.PrefillNoCache:
|
|
||||||
max_seq_len = max(seq_lens.max().item(), 0)
|
|
||||||
return self.attn_mask_builder.get_attn_mask(
|
|
||||||
max_seq_len, self.dtype, self.device)
|
|
||||||
# Prefill with cache hit.
|
|
||||||
elif attn_state == AscendAttentionState.PrefillCacheHit:
|
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask().to(
|
|
||||||
torch.bool)
|
|
||||||
# Decode-only situation.
|
# Decode-only situation.
|
||||||
else:
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
def _make_fia_attention_mask(self) -> torch.Tensor:
|
def _make_fia_attention_mask(self) -> torch.Tensor:
|
||||||
# pcp situation.
|
# pcp situation.
|
||||||
|
|||||||
Reference in New Issue
Block a user