[Refactor] Remove redundant attention operator branches. (#4531)
[Refactor] Remove redundant attention operator branches. Reason: we replace the other attention ops with fused_infer_attention_score, except for the decode-only state; this cleans up the code and removes 310P support. https://github.com/vllm-project/vllm-ascend/pull/4455 - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
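For orientation before the diff: the refactor converges on a dispatch where every prefill-like state funnels into one fused-attention call and only the decode-only state keeps its dedicated path. Below is a minimal sketch of that shape in plain Python, with hypothetical names and stub kernels; it is not the vllm_ascend implementation, which calls torch_npu.npu_fused_infer_attention_score for the prefill paths.

# Minimal sketch of the consolidated dispatch (hypothetical names; the real
# code lives in vllm_ascend/attention/attention_v1.py).
from enum import Enum, auto


class AttnState(Enum):
    DECODE_ONLY = auto()
    PREFILL_NO_CACHE = auto()
    PREFILL_CACHE_HIT = auto()
    CHUNKED_PREFILL = auto()


def fused_prefill(query, key, value):
    # Stand-in for the single fused kernel call that replaces the former
    # per-state _npu_flash_attention / _npu_flash_attention_qlens branches.
    return query  # placeholder result


def decode_only(query):
    # Decode keeps its dedicated paged-attention path.
    return query  # placeholder result


def forward(query, key, value, state):
    if state is AttnState.DECODE_ONLY:
        return decode_only(query)
    # PrefillNoCache, PrefillCacheHit and ChunkedPrefill all share the fused
    # path; only block_table / actual_seq_lengths_kv differ per state.
    return fused_prefill(query, key, value)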
@@ -25,12 +25,6 @@ class TestAscendAttentionBackend(TestBase):
        self.assertEqual(AscendAttentionBackend.get_builder_cls(),
                         AscendAttentionMetadataBuilder)

    @patch('vllm_ascend.attention.attention_v1.get_ascend_device_type',
           return_value=AscendDeviceType._310P)
    def test_get_kv_cache_shape_310p(self, mock_soc_version):
        result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
        self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))

    @patch('vllm_ascend.utils.get_ascend_device_type',
           return_value=AscendDeviceType._910_93)
    def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
@@ -95,76 +89,6 @@ class TestAscendAttentionMetadataBuilder(TestBase):

        self.assertFalse(result)

    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
    @patch('torch_npu.npu_format_cast')
    @patch('vllm_ascend.utils.nd_to_nz_2d')
    @patch('vllm_ascend.utils.get_ascend_device_type',
           return_value=AscendDeviceType._310P)
    def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d,
                                    mock_npu_format_cast,
                                    mock_ascend_metadata):
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 3, 7]),
            query_start_loc_cpu=torch.tensor([0, 3, 7]),
            seq_lens_cpu=torch.tensor([5, 6]),
            num_reqs=2,
            num_actual_tokens=10,
            max_query_len=5,
            decode_token_per_req=torch.tensor([1, 1]),
            block_table_tensor=torch.zeros((10, 10)),
            slot_mapping=torch.tensor(range(20)),
            actual_seq_lengths_q=torch.tensor([0, 1]),
            positions=torch.tensor([10, 10]),
            attn_mask=torch.ones((10, 10)),
            spec_attn_mask=None,
            attn_state=AscendAttentionState.PrefillNoCache,
            num_computed_tokens_cpu=None,
            seq_lens=None)

        mock_nz_tensor = MagicMock()
        mock_model = MagicMock()
        mock_nd_to_nz_2d.return_value = mock_nz_tensor
        mock_npu_format_cast.return_value = mock_nz_tensor

        self.builder.build(1, common_attn_metadata, mock_model)

    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
    @patch('torch_npu.npu_format_cast')
    @patch('vllm_ascend.utils.nd_to_nz_spec')
    @patch('vllm_ascend.utils.get_ascend_device_type',
           return_value=AscendDeviceType._310P)
    @patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
    def test_build_chunked_prefill(self, mock_ascend_attention_state,
                                   mock_soc_version, mock_nd_to_nz_spec,
                                   mock_npu_format_cast, mock_ascend_metadata):
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 2, 5, 9]),
            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
            seq_lens_cpu=torch.tensor([4, 5, 6]),
            num_reqs=3,
            num_actual_tokens=15,
            max_query_len=6,
            decode_token_per_req=torch.tensor([1, 1, 1]),
            block_table_tensor=torch.zeros((10, 10)),
            slot_mapping=torch.tensor(range(20)),
            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
            positions=torch.tensor([10, 10]),
            attn_mask=torch.ones((15, 15)),
            spec_attn_mask=None,
            attn_state=AscendAttentionState.ChunkedPrefill,
            num_computed_tokens_cpu=None,
            seq_lens=None)

        mock_ascend_attention_state = MagicMock()
        mock_ascend_attention_state.PrefillNoCache = 0

        mock_nz_tensor = MagicMock()
        mock_model = MagicMock()
        mock_nd_to_nz_spec.return_value = mock_nz_tensor
        mock_npu_format_cast.return_value = mock_nz_tensor

        self.builder.build(1, common_attn_metadata, mock_model)

    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
    @patch('vllm_ascend.utils.get_ascend_device_type',
           return_value=AscendDeviceType._910_93)
@@ -286,73 +210,40 @@ class TestAscendAttentionBackendImpl(TestBase):

        assert output.shape == (10, 8 * 64)

    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu._npu_flash_attention')
    def test_forward_prefill_no_cache(self, mock_flash_attention,
                                      mock_reshape_cache,
                                      mock_get_forward_context):
        """Test forward pass in PrefillNoCache state"""
        query = torch.randn(10, 8 * 64)
        key = torch.randn(10, 8 * 64)
        value = torch.randn(10, 8 * 64)
        kv_cache = torch.empty(2, 5, 128, 8, 64)
        output = torch.empty_like(query)

        mock_get_forward_context.return_value = MagicMock(capturing=False)

        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.PrefillNoCache
        metadata.attn_mask = torch.randn(1, 1, 10, 10)
        metadata.seq_lens = torch.tensor([10])
        metadata.num_actual_tokens = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        metadata.num_decodes = 0
        metadata.num_prefills = 10
        layer = self.layer_no_quant

        output = self.impl.forward(layer, query, key, value, kv_cache,
                                   metadata, output)

        mock_reshape_cache.assert_called_once()
        mock_flash_attention.assert_called_once()
        assert output.shape == (10, 8 * 64)

    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu.npu_fused_infer_attention_score')
    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
    def test_forward_prefill_cache_hit(self, mock_get_forward_context,
                                       mock_npu_fused_infer_attention_score,
                                       mock_npu_reshape_and_cache):
    def test_forward_prefill(self, mock_get_forward_context,
                             mock_npu_fused_infer_attention_score,
                             mock_npu_reshape_and_cache):
        """Test forward pass in PrefillCacheHit state"""
        query = torch.randn(10, 8 * 64)
        key = torch.randn(10, 8 * 64)
        value = torch.randn(10, 8 * 64)
        query = torch.randn(10, 8, 64)
        key = torch.randn(10, 8, 64)
        value = torch.randn(10, 8, 64)
        kv_cache = torch.empty(2, 5, 128, 8, 64)
        output = torch.empty_like(query)

        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.PrefillCacheHit
        metadata.attn_mask = torch.randn(1, 1, 10, 10)
        metadata.query_lens = torch.tensor([10])
        metadata.seq_lens = torch.tensor([10])
        metadata.actual_seq_lengths_q = [10]
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        metadata.num_decode_tokens = 0
        metadata.num_decodes = 0
        metadata.num_prefills = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        layer = self.layer_no_quant

        mock_get_forward_context.return_value = MagicMock(capturing=False)
        mock_npu_fused_infer_attention_score.return_value = (output,
                                                             torch.ones(
                                                                 10, 8, 64))

        mock_npu_fused_infer_attention_score.return_value = (torch.ones(
            10, 8, 64), torch.ones(10, 8, 64))
        output = self.impl.forward(layer, query, key, value, kv_cache,
                                   metadata, output)

        mock_npu_fused_infer_attention_score.assert_called_once()
        assert output.shape == (10, 8 * 64)
        assert output.shape == (10, 8, 64)

    @patch('torch_npu._npu_paged_attention')
    @patch('torch_npu._npu_reshape_and_cache')
@@ -454,119 +345,6 @@ class TestAscendAttentionBackendImpl(TestBase):

        assert output.shape == (10, 8 * 64)

    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
    @patch('vllm_ascend.utils.get_ascend_device_type',
           return_value=AscendDeviceType._910_93)
    @patch('torch_npu._npu_reshape_and_cache')
    @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
    def test_forward_head_size_192(self, mock_vanilla_prefill,
                                   mock_npu_reshape_and_cache,
                                   mock_soc_version, mock_get_forward_context):
        """Test forward pass when head_size is 192"""

        self.impl.head_size = 192
        query = torch.randn(10, 8 * 192)
        key = torch.randn(10, 8 * 192)
        value = torch.randn(10, 8 * 192)
        kv_cache = torch.empty(2, 5, 128, 8, 192)
        output = torch.empty_like(query)

        mock_get_forward_context.return_value = MagicMock(capturing=False)

        metadata = self.attn_metadata
        metadata.attn_mask = torch.randn(1, 1, 10, 10)
        metadata.query_lens = torch.tensor([10])
        metadata.seq_lens = torch.tensor([10])
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        metadata.num_decodes = 10
        metadata.num_prefills = 0
        layer = self.layer_no_quant
        mock_vanilla_prefill.return_value = MagicMock()

        output = self.impl_192.forward(layer, query, key, value, kv_cache,
                                       metadata, output)

        mock_vanilla_prefill.assert_called_once()
        assert output.shape == (10, 8 * 192)

    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
    @patch('torch_npu.npu_fused_infer_attention_score')
    @patch('torch_npu._npu_reshape_and_cache')
    def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
                                         mock_npu_fused_infer_attention_score,
                                         mock_get_forward_context):
        """Test forward pass in normal V1 situation"""
        query = torch.randn(10, 8 * 64)
        key = torch.randn(10, 8 * 64)
        value = torch.randn(10, 8 * 64)
        kv_cache = torch.empty(2, 5, 128, 8, 64)
        output = torch.empty_like(query)

        metadata = self.attn_metadata
        metadata.attn_mask = torch.randn(1, 1, 10, 10)
        metadata.query_lens = torch.tensor([10])
        metadata.seq_lens = torch.tensor([10])
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        metadata.num_decodes = 0
        metadata.num_prefills = 10
        layer = self.layer_no_quant
        mock_get_forward_context.return_value = MagicMock(capturing=False)
        mock_npu_fused_infer_attention_score.return_value = (output,
                                                             torch.ones(
                                                                 10, 8, 64))

        output = self.impl.forward(layer, query, key, value, kv_cache,
                                   metadata, output)

        mock_npu_fused_infer_attention_score.assert_called_once()
        assert output.shape == (10, 8 * 64)

    @patch('torch_npu.npu_format_cast')
    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu.npu_fused_infer_attention_score')
    @patch('vllm_ascend.utils.get_ascend_device_type',
           return_value=AscendDeviceType._310P)
    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
    def test_forward_310p_device(self, mock_get_forward_context,
                                 mock_soc_version,
                                 mock_npu_fused_infer_attention_score,
                                 mock_npu_reshape_and_cache,
                                 mock_npu_format_cast):
        """Test forward pass on 310P device"""
        query = torch.randn(10, 8 * 64)
        key = torch.randn(10, 8 * 64)
        value = torch.randn(10, 8 * 64)
        kv_cache = torch.empty(2, 5, 128, 8, 64)
        output = torch.empty_like(query)

        metadata = self.attn_metadata
        metadata.attn_mask = torch.randn(1, 1, 10, 10)
        metadata.query_lens = torch.tensor([10])
        metadata.seq_lens = torch.tensor([10])
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        metadata.num_decodes = 0
        metadata.num_prefills = 10
        layer = self.layer_no_quant

        mock_npu_format_cast.return_value = metadata.attn_mask

        mock_get_forward_context.return_value = MagicMock(capturing=False)
        mock_npu_fused_infer_attention_score.return_value = (output,
                                                             torch.ones(
                                                                 10, 8, 64))

        output = self.impl.forward(layer, query, key, value, kv_cache,
                                   metadata, output)

        mock_npu_fused_infer_attention_score.assert_called_once()
        assert output.shape == (10, 8 * 64)

    @patch('torch_npu._npu_reshape_and_cache')
    def test_forward_raise_error(self, mock_paged_attention):
        query = torch.randn(10, 8 * 64)

@@ -41,11 +41,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                         split_decodes_and_prefills)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                               update_graph_params_workspaces)
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
                               aligned_16, get_ascend_device_type, nd_to_nz_2d,
                               nd_to_nz_spec, prefill_context_parallel_enable,
                               weak_ref_tensors)
from vllm_ascend.utils import prefill_context_parallel_enable, weak_ref_tensors

# isort: off
if prefill_context_parallel_enable():
@@ -83,9 +79,6 @@ class AscendAttentionBackend(AttentionBackend):
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        if get_ascend_device_type() == AscendDeviceType._310P:
            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
                    16)
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
@@ -351,16 +344,6 @@ class AscendAttentionMetadataBuilder:
        query_start_loc = query_start_loc_cpu.to(self.device,
                                                 non_blocking=True)

        if get_ascend_device_type() == AscendDeviceType._310P:
            if attn_state == AscendAttentionState.PrefillNoCache:
                mask_nz = nd_to_nz_2d(attn_mask)
                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
                                                      ACL_FORMAT_FRACTAL_NZ)
            elif attn_state == AscendAttentionState.ChunkedPrefill:
                mask_nz = nd_to_nz_spec(attn_mask)
                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
                                                      ACL_FORMAT_FRACTAL_NZ)

        common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
        prefill_metadata = None
        decode_metadata = None
@@ -585,9 +568,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
        output: torch.Tensor,
        num_tokens=0):
        if self.pcp_size * self.dcp_size > 1:
            intermediate_output = self._forward_pcp_dcp(
                query, key, value, kv_cache, attn_metadata, output)
            return intermediate_output, query.shape[0]
            attn_output = self._forward_pcp_dcp(query, key, value, kv_cache,
                                                attn_metadata, output)
            return attn_output, query.shape[0]
        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            block_size = 128
            block_table = None
@@ -688,93 +671,58 @@ class AscendAttentionBackendImpl(AttentionImpl):
            graph_params.handles[num_tokens].append(handle)
        return output, num_tokens

    def _forward_prefill_no_cache(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
        num_tokens=0,
    ) -> torch.Tensor:
        assert attn_metadata is not None
        assert attn_metadata.attn_mask is not None

        mask = attn_metadata.attn_mask

        if get_ascend_device_type() == AscendDeviceType._310P:
            # align q k v output tensors
            query = aligned_16(query)
            key = aligned_16(key)
            value = aligned_16(value)
            output = aligned_16(output)
            # do reformat in case of broadcasted tensors
            mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
            mask = torch_npu.npu_format_cast(mask.contiguous(),
                                             ACL_FORMAT_FRACTAL_NZ)

        torch_npu._npu_flash_attention(query=query,
                                       key=key,
                                       value=value,
                                       mask=mask,
                                       seq_len=attn_metadata.seq_lens,
                                       scale_value=self.scale,
                                       num_heads=self.num_heads,
                                       num_kv_heads=self.num_kv_heads,
                                       out=output)
        assert output is not None
        return output[:num_tokens]

    def _forward_prefill_cache_hit(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert attn_metadata is not None
        assert attn_metadata.attn_mask is not None

        compress_mask = attn_metadata.attn_mask
        batch_size = attn_metadata.query_lens.shape[0]
        block_table = attn_metadata.block_tables[:batch_size, :]
        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore

        if block_size == 128:
            # TODO:The npu_fused_infer_attention_score op is planned to
            # be utilized in a wider range in upcoming versions.
    def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
                         value: torch.Tensor, attn_metadata: AscendMetadata,
                         output: torch.Tensor):
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            block_size = 128
            block_table = None
            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
        elif attn_metadata.attn_state == \
                AscendAttentionState.PrefillCacheHit:
            batch_size = attn_metadata.query_lens.shape[0]
            block_table = attn_metadata.block_tables[:batch_size, :]
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1)
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1)

            output, _ = torch_npu.npu_fused_infer_attention_score(
                query=query,
                key=key,
                value=value,
                atten_mask=compress_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
                actual_seq_lengths_kv=attn_metadata.seq_lens_list,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=3,
            )
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        # chunked_prefill.
        else:
            torch_npu._npu_flash_attention_qlens(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                block_table=block_table,
                mask=compress_mask,
                seq_len=attn_metadata.query_lens,
                context_lens=attn_metadata.seq_lens,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                out=output)
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1)
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1)
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list

        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        query = query[:num_tokens]
        # Prepare tensors for attention output
        # TODO: Refactor this to step-level instead of layer-level

        # Get workspace from cache or calculate it if not present.
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )

        attn_output = attn_output.view(num_tokens, self.num_heads,
                                       self.head_size)
        output[:num_tokens] = attn_output[:num_tokens]
        return output

    def _forward_decode_only(
@@ -783,10 +731,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if get_ascend_device_type() == AscendDeviceType._310P:
            # seq_lens_tensor needs to be transferred to the device for 310P.
            attn_metadata.seq_lens = \
                attn_metadata.seq_lens.to(device=query.device)
        if self.sliding_window is not None and attn_metadata.seq_lens.shape[
                0] == query.size(0):
            batch_size = attn_metadata.seq_lens.shape[0]
@@ -827,69 +771,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
                out=output)
        return output

    def _forward_v1_style(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Use chunked prefill for head size 192 scenario, like deepseek
        # paged_attention_splitfuse maybe crash at such scenario.
        # TODO: vanilla path will be removed after the kernel support
        # head_size 192 scenario.
        if self.head_size == 192:
            cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
            cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
            cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
            cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
            cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
            cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
            max_seqlen_q = torch.max(attn_metadata.query_lens)
            max_seqlen_k = torch.max(attn_metadata.seq_lens)
            vanilla_chunked_prefill(output, query, self.key_cache,
                                    self.value_cache,
                                    attn_metadata.block_tables, cu_seqlen_q,
                                    cu_seqlen_k, max_seqlen_q, max_seqlen_k,
                                    self.scale, None, True)
            return output

        # Use paged attention.
        assert attn_metadata is not None
        assert attn_metadata.attn_mask is not None

        if get_ascend_device_type() == AscendDeviceType._310P:
            # Do reformat in case of broadcasted tensors.
            attn_metadata.attn_mask = \
                torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
                                          ACL_FORMAT_FRACTAL_NZ)
            attn_metadata.seq_lens = \
                attn_metadata.seq_lens.to(device=query.device)

        # TODO:The npu_fused_infer_attention_score op is planned to
        # be utilized in a wider range in upcoming versions.
        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
        key = self.key_cache.view(  # type: ignore
            num_block, block_size, -1)
        value = self.value_cache.view(  # type: ignore
            num_block, block_size, -1)

        output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=attn_metadata.block_tables,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
            actual_seq_lengths_kv=attn_metadata.seq_lens_list,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )
        return output

    def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                        q_seqlens: List[int],
                                        k_nomask: torch.Tensor,
@@ -1464,6 +1345,31 @@ class AscendAttentionBackendImpl(AttentionImpl):
        )
        return key, value

    def _forward_encode(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
        output = torch_npu.npu_fusion_attention(
            query,
            key,
            value,
            head_num=self.num_heads,
            input_layout="TND",
            scale=self.scale,
            sparse_mode=4,
            atten_mask=attn_metadata.attn_mask,
            pre_tockens=attn_metadata.max_query_len,
            next_tockens=attn_metadata.max_query_len,
            actual_seq_qlen=cum_seq_len,
            actual_seq_kvlen=cum_seq_len,
        )[0]
        return output

    def forward(
        self,
        layer: AttentionLayer,
@@ -1494,24 +1400,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
                "fused output quantization is not yet supported"
                " for AscendAttentionBackendImpl")

        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output

        # NOTE: Currently, we have various attention paths for different
        # scenarios, and not all of them are in-place operations. Therefore,
        # we need to create a separate tensor to hold the attention result.
        # In the future, we may consolidate them into fewer paths, which will
        # hopefully allow us to use in-place operation by default.
        intermediate_output: torch.Tensor

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        attn_type = self.attn_type
        if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
        if self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_ONLY:
            raise NotImplementedError("Encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.fill_(0)

        num_decode_tokens = attn_metadata.num_decode_tokens
        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
@@ -1558,48 +1456,25 @@ class AscendAttentionBackendImpl(AttentionImpl):
        forward_context: ForwardContext = get_forward_context()
        if not forward_context.capturing:
            if self.pcp_size * self.dcp_size > 1:
                intermediate_output = self._forward_pcp_dcp(
                    query, key, value, kv_cache, attn_metadata, output)
            elif attn_type == AttentionType.ENCODER_ONLY:
                # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
                cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
                intermediate_output = torch_npu.npu_fusion_attention(
                    query,
                    key,
                    value,
                    head_num=self.num_heads,
                    input_layout="TND",
                    scale=self.scale,
                    sparse_mode=4,
                    atten_mask=attn_metadata.attn_mask,
                    pre_tockens=attn_metadata.max_query_len,
                    next_tockens=attn_metadata.max_query_len,
                    actual_seq_qlen=cum_seq_len,
                    actual_seq_kvlen=cum_seq_len,
                )[0]
            # V0-Style scheduler situation.
            elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                intermediate_output = self._forward_prefill_no_cache(
                    query, key, value, attn_metadata, output, num_tokens)
            elif attn_metadata.attn_state == \
                    AscendAttentionState.PrefillCacheHit:
                intermediate_output = self._forward_prefill_cache_hit(
                    query, attn_metadata, output)
            elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
                intermediate_output = self._forward_decode_only(
                    query, attn_metadata, output)
            # Normal V1 situation.
                attn_output = self._forward_pcp_dcp(query, key, value,
                                                    kv_cache, attn_metadata,
                                                    output)
                output[:num_tokens] = attn_output[:num_tokens]
                return output
            if self.attn_type == AttentionType.ENCODER_ONLY:
                attn_output = self._forward_encode(query, key, value,
                                                   attn_metadata, output)
                output[:num_tokens] = attn_output[:num_tokens]
                return output
            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
                output = self._forward_decode_only(query, attn_metadata,
                                                   output)
            else:
                # npu_fused_infer_attention_score does not support cases
                # where query.shape[0] != attn_metadata.query_start_loc[-1].
                # Thus we need unpad it here.
                num_tokens = attn_metadata.query_start_loc[-1]
                query = query[:num_tokens]
                intermediate_output = self._forward_v1_style(
                    query, attn_metadata, output)
                output = self._forward_prefill(query, key, value,
                                               attn_metadata, output)
        else:
            intermediate_output, num_tokens = self.full_graph_attention(
            attn_output, num_tokens = self.full_graph_attention(
                query, key, value, kv_cache, attn_metadata, output)
            output[:num_tokens] = intermediate_output[:num_tokens]
            output[:num_tokens] = attn_output[:num_tokens]

        return output

@@ -979,25 +979,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        # dcp situation.
        if self.dcp_size > 1:
            return self.attn_mask_builder.get_splitfuse_attn_mask()
        if self.vllm_config.model_config.use_mla:
            return None
        # Pooling situation.
        if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
            return self.attn_mask_builder.get_pooling_mask(self.device)
        # Chunk Prefill situation.
        elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
        # fia prefill situation.
        if attn_state in [
                AscendAttentionState.PrefillNoCache,
                AscendAttentionState.PrefillCacheHit,
                AscendAttentionState.ChunkedPrefill
        ]:
            return self.attn_mask_builder.get_splitfuse_attn_mask()

        # Prefill without cache situation.
        elif attn_state == AscendAttentionState.PrefillNoCache:
            max_seq_len = max(seq_lens.max().item(), 0)
            return self.attn_mask_builder.get_attn_mask(
                max_seq_len, self.dtype, self.device)
        # Prefill with cache hit.
        elif attn_state == AscendAttentionState.PrefillCacheHit:
            return self.attn_mask_builder.get_splitfuse_attn_mask().to(
                torch.bool)
        # Decode-only situation.
        else:
            return None
        return None

    def _make_fia_attention_mask(self) -> torch.Tensor:
        # pcp situation.

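The final hunk above collapses the per-state mask branches in NPUModelRunner into a single membership test: every fia (fused_infer_attention) prefill state now gets the same splitfuse mask, and decode-only needs none. A minimal sketch of that shape, assuming a hypothetical stub mask builder rather than the real attn_mask_builder:

# Sketch of the simplified mask selection (hypothetical names, mirroring
# the _make_attention_mask hunk above).
def make_attention_mask(builder, attn_state, fia_prefill_states):
    # All fia prefill states (PrefillNoCache, PrefillCacheHit,
    # ChunkedPrefill) share one splitfuse mask.
    if attn_state in fia_prefill_states:
        return builder.get_splitfuse_attn_mask()
    # Decode-only and other states need no mask.
    return None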