[Refactor]7/N Extract common code to common_cp (#5490)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
Eliminate duplicate code for two file(mla_cp.py attention_cp.py) to
common_cp.py.

vLLM version: 0.13.0rc3
vLLM main:
ad32e3e19c

vLLM version: release/v0.13.0
vLLM main:
5fbfa8d9ef

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
This commit is contained in:
wujinyuan1
2026-01-05 17:41:12 +08:00
committed by GitHub
parent 755caeb06e
commit 4a3663327b
10 changed files with 252 additions and 301 deletions

View File

@@ -5,9 +5,11 @@ import torch
from tests.ut.attention.utils import patch_distributed_groups from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl from vllm_ascend.attention.attention_v1 import AscendMetadata
from vllm_ascend.attention.attention_v1 import (AscendMetadata, from vllm_ascend.attention.context_parallel.attention_cp import \
AscendMetadataForPrefill) AscendAttentionCPImpl
from vllm_ascend.attention.context_parallel.common_cp import (
AscendMetadataForPrefill, AscendPCPMetadata)
class TestAscendAttentionCPImpl(TestBase): class TestAscendAttentionCPImpl(TestBase):
@@ -82,25 +84,22 @@ class TestAscendAttentionCPImpl(TestBase):
self.assertEqual(output.shape[1], 4) self.assertEqual(output.shape[1], 4)
self.assertEqual(output.shape[2], 128) self.assertEqual(output.shape[2], 128)
@patch('torch_npu.npu_attention_update')
@patch("torch_npu.npu_fused_infer_attention_score") @patch("torch_npu.npu_fused_infer_attention_score")
@patch('vllm_ascend.attention.attention_cp.get_forward_context') @patch(
'vllm_ascend.attention.context_parallel.attention_cp.get_forward_context'
)
@patch_distributed_groups(dcp_size=2, pcp_size=2) @patch_distributed_groups(dcp_size=2, pcp_size=2)
def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp, def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,
mock_get_forward_context, mock_get_forward_context,
mock_npu_fused_infer_attention_score): mock_npu_fused_infer_attention_score,
query = torch.randn(2, 4, 128) mock_npu_attention_update):
self.impl.key_cache = torch.randn(100, 128, 1, 128) query = torch.randn(2, 4, 64)
self.impl.value_cache = torch.randn(100, 128, 1, 128) self.impl.key_cache = torch.randn(100, 64, 1, 64)
self.impl.value_cache = torch.randn(100, 64, 1, 64)
def mock_npu_attention_update(attn_out_lse_list): # Mock output
mock_output = torch.randn( mock_npu_attention_update.return_value = (torch.randn(2 * 4, 64), None)
attn_out_lse_list.shape[0] // mock_pcp.world_size,
attn_out_lse_list.shape[1] // mock_dcp.world_size,
attn_out_lse_list.shape[2] - 1)
return mock_output
self.impl._npu_attention_update = MagicMock()
self.impl._npu_attention_update.side_effect = mock_npu_attention_update
mock_get_forward_context.return_value = MagicMock(capturing=False) mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -116,12 +115,11 @@ class TestAscendAttentionCPImpl(TestBase):
attn_metadata.decode_meta = MagicMock() attn_metadata.decode_meta = MagicMock()
attn_metadata.decode_meta.batch_seq_mask = torch.tensor( attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
[1, 0], dtype=torch.bool) [1, 0], dtype=torch.bool)
output = self.impl._forward_decode_pcp_dcp(query, attn_metadata) output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)
self.assertEqual(output.shape[0], 2) self.assertEqual(output.shape[0], 2)
self.assertEqual(output.shape[1], 4) self.assertEqual(output.shape[1], 4)
self.assertEqual(output.shape[2], 128) self.assertEqual(output.shape[2], 64)
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_prefill_query_all_gather(self): def test_prefill_query_all_gather(self):
@@ -249,7 +247,7 @@ class TestAscendAttentionCPImpl(TestBase):
attn_metadata.slot_mapping = torch.randn(2) attn_metadata.slot_mapping = torch.randn(2)
attn_metadata.num_actual_tokens_pcp_padded = num_tokens * self.impl.pcp_size attn_metadata.num_actual_tokens_pcp_padded = num_tokens * self.impl.pcp_size
attn_metadata.prefill = MagicMock() attn_metadata.prefill = MagicMock()
attn_metadata.prefill.pcp_allgather_restore_idx = torch.tensor( attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.tensor(
[0, 3, 1, 2, 0, 0, 0, 0]) [0, 3, 1, 2, 0, 0, 0, 0])
key = torch.randn(num_tokens, num_heads, head_size) key = torch.randn(num_tokens, num_heads, head_size)
@@ -336,7 +334,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
attn_metadata.num_actual_tokens = self.q_total_tokens attn_metadata.num_actual_tokens = self.q_total_tokens
prefill_metadata = AscendMetadataForPrefill() prefill_metadata = AscendMetadataForPrefill()
pcp_metadata = AscendMetadataForPrefill.AscendPCPMetadata() pcp_metadata = AscendPCPMetadata()
pcp_metadata.attn_mask_seqlens = self.kv_seqlens_mask_cumsum pcp_metadata.attn_mask_seqlens = self.kv_seqlens_mask_cumsum
pcp_metadata.head_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum pcp_metadata.head_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum
pcp_metadata.tail_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum pcp_metadata.tail_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum
@@ -409,7 +407,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
@patch('torch.ops.npu.npu_fused_infer_attention_score') @patch('torch.ops.npu.npu_fused_infer_attention_score')
@patch( @patch(
'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._update_out_and_lse' 'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._update_out_and_lse'
) )
def test_attention_with_nomask_and_mask_chunk( def test_attention_with_nomask_and_mask_chunk(
self, mock_update_out_and_lse, self, mock_update_out_and_lse,
@@ -457,7 +455,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
@patch('torch.ops.npu.npu_fused_infer_attention_score') @patch('torch.ops.npu.npu_fused_infer_attention_score')
@patch( @patch(
'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update' 'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
) )
def test_attention_with_nomask_and_mask_nochunk( def test_attention_with_nomask_and_mask_nochunk(
self, mock_npu_attn_out_lse_update, self, mock_npu_attn_out_lse_update,
@@ -505,7 +503,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
self.assertEqual(attn_lse, None) self.assertEqual(attn_lse, None)
@patch( @patch(
'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update' 'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
) )
def test_update_chunk_attn_out_lse_with_current_attn_out_lse( def test_update_chunk_attn_out_lse_with_current_attn_out_lse(
self, mock_npu_attn_out_lse_update): self, mock_npu_attn_out_lse_update):

View File

@@ -7,8 +7,9 @@ from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.common_cp import CPChunkedContextMetadata from vllm_ascend.attention.context_parallel.common_cp import (
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl CPChunkedContextMetadata, _npu_attention_update, _process_attn_out_lse)
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata
@@ -441,14 +442,14 @@ class TestAscendMLAImpl(TestBase):
decode_metadata.batch_seq_mask = torch.tensor([True, False], decode_metadata.batch_seq_mask = torch.tensor([True, False],
dtype=torch.bool) dtype=torch.bool)
result = self.impl._process_attn_out_lse(attn_output, softmax_lse, result = _process_attn_out_lse(attn_output, softmax_lse,
decode_metadata) decode_metadata.batch_seq_mask)
self.assertEqual(result.shape[0], B * self.impl.pcp_size) self.assertEqual(result.shape[0], B * self.impl.pcp_size)
self.assertEqual(result.shape[1], N) self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1) self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
@patch('vllm_ascend.attention.mla_cp.get_forward_context') @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
@patch("torch_npu.atb.npu_multi_head_latent_attention") @patch("torch_npu.atb.npu_multi_head_latent_attention")
@patch('torch_npu.npu_attention_update') @patch('torch_npu.npu_attention_update')
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
@@ -725,7 +726,15 @@ class TestAscendMLAImpl(TestBase):
assert torch.allclose(lse, expected_lse) assert torch.allclose(lse, expected_lse)
@patch('torch_npu.npu_attention_update') @patch('torch_npu.npu_attention_update')
def test_npu_attention_update_with_dcp_pcp(self, @patch('vllm_ascend.attention.context_parallel.common_cp.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm_ascend.attention.context_parallel.common_cp.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_npu_attention_update_with_dcp_pcp(self, mock_dcp,
mock_get_dcp_group, mock_pcp,
mock_get_pcp_group,
mock_npu_attention_update): mock_npu_attention_update):
NUM_TOKENS = 10 # fixed NUM_TOKENS = 10 # fixed
test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (2, 3)] test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (2, 3)]
@@ -752,10 +761,19 @@ class TestAscendMLAImpl(TestBase):
attn_lse_split_cp[0]) attn_lse_split_cp[0])
mock_npu_attention_update.side_effect = mock_npu_attention_update_effect mock_npu_attention_update.side_effect = mock_npu_attention_update_effect
mock_pcp_group = MagicMock()
mock_pcp_group.world_size = self.impl.pcp_size
mock_get_pcp_group.return_value = mock_pcp_group
mock_dcp.world_size = self.impl.dcp_size
mock_dcp_group = MagicMock()
# mock_dcp_group.world_size = self.impl.dcp_size
mock_get_dcp_group.return_value = mock_dcp_group
attn_out_lse = torch.randn(self.impl.pcp_size * NUM_TOKENS, attn_out_lse = torch.randn(self.impl.pcp_size * NUM_TOKENS,
self.impl.dcp_size * num_heads, self.impl.dcp_size * num_heads,
head_dim) head_dim)
out = self.impl._npu_attention_update(attn_out_lse) out = _npu_attention_update(self.impl.kv_lora_rank, attn_out_lse)
self.impl.dcp_size = 1 self.impl.dcp_size = 1
self.impl.pcp_size = 1 self.impl.pcp_size = 1
assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank) assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank)
@@ -873,8 +891,8 @@ class TestAscendMLAImpl(TestBase):
decode_meta = MagicMock() decode_meta = MagicMock()
decode_meta.batch_seq_mask = batch_seq_mask decode_meta.batch_seq_mask = batch_seq_mask
result = self.impl._process_attn_out_lse(attn_output, softmax_lse, result = _process_attn_out_lse(attn_output, softmax_lse,
decode_meta) batch_seq_mask)
# [PCP * S, DCP * H, D + 1] # [PCP * S, DCP * H, D + 1]
self.assertIsInstance(result, torch.Tensor) self.assertIsInstance(result, torch.Tensor)
assert result.shape == (B * self.impl.pcp_size, H, D + 1) assert result.shape == (B * self.impl.pcp_size, H, D + 1)

View File

@@ -34,10 +34,10 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.context_parallel.common_cp import (
AscendMetadataForDecode, AscendMetadataForPrefill)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendMetadataForDecode, enable_cp, split_decodes_and_prefills,
AscendMetadataForPrefill, enable_cp,
split_decodes_and_prefills,
using_paged_attention) using_paged_attention)
from vllm_ascend.compilation.acl_graph import ( from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params, get_graph_params, get_draft_graph_params, get_graph_params,
@@ -63,7 +63,7 @@ class AscendAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]: def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
if enable_cp(): if enable_cp():
from vllm_ascend.attention.attention_cp import \ from vllm_ascend.attention.context_parallel.attention_cp import \
AscendAttentionCPImpl AscendAttentionCPImpl
return AscendAttentionCPImpl return AscendAttentionCPImpl
return AscendAttentionBackendImpl return AscendAttentionBackendImpl
@@ -71,7 +71,7 @@ class AscendAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
if enable_cp(): if enable_cp():
from vllm_ascend.attention.attention_cp import \ from vllm_ascend.attention.context_parallel.attention_cp import \
AscendAttentionCPMetadataBuilder AscendAttentionCPMetadataBuilder
return AscendAttentionCPMetadataBuilder return AscendAttentionCPMetadataBuilder
return AscendAttentionMetadataBuilder return AscendAttentionMetadataBuilder

View File

@@ -1,40 +0,0 @@
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None
@dataclass
class CPChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
# for mla DCP & PCP
padded_chunk_seq_lens_npu: torch.Tensor = None
padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
local_context_lens_allranks: Optional[list[list[int]]] = None
padded_local_cu_seq_lens: torch.Tensor = None
cu_seq_lens_lst: Optional[list[list[int]]] = None
chunk_size: Optional[int] = None

View File

@@ -33,9 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl, from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl,
AscendAttentionMetadataBuilder, AscendAttentionMetadataBuilder,
AscendMetadata) AscendMetadata)
from vllm_ascend.attention.context_parallel.common_cp import (
AscendMetadataForDecode, AscendMetadataForPrefill, AscendPCPMetadata,
_npu_attention_update, _process_attn_out_lse)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendMetadataForDecode,
AscendMetadataForPrefill,
filter_chunked_req_indices, filter_chunked_req_indices,
split_decodes_and_prefills) split_decodes_and_prefills)
from vllm_ascend.compilation.acl_graph import (get_graph_params, from vllm_ascend.compilation.acl_graph import (get_graph_params,
@@ -196,7 +197,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
tail_attn_nomask_seqlens = torch.cumsum( tail_attn_nomask_seqlens = torch.cumsum(
tail_attn_nomask_seqlens[1], dim=0).tolist() tail_attn_nomask_seqlens[1], dim=0).tolist()
pcp_metadata = AscendMetadataForPrefill.AscendPCPMetadata( pcp_metadata = AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor, q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor, q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
kv_with_q_head_nomask_idx=common_long_seq_metadata. kv_with_q_head_nomask_idx=common_long_seq_metadata.
@@ -211,13 +212,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
head_attn_nomask_seqlens=head_attn_nomask_seqlens, head_attn_nomask_seqlens=head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens, tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx, q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask) pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
pcp_allgather_restore_idx=common_long_seq_metadata.
pcp_allgather_restore_idx)
prefill_metadata = AscendMetadataForPrefill( prefill_metadata = AscendMetadataForPrefill(
pcp_metadata=pcp_metadata, pcp_metadata=pcp_metadata,
pcp_allgather_restore_idx=common_long_seq_metadata.
pcp_allgather_restore_idx
if common_long_seq_metadata is not None else None,
chunked_context=chunked_context_metadata, chunked_context=chunked_context_metadata,
block_tables=block_table[num_decodes:], block_tables=block_table[num_decodes:],
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0)) actual_seq_lengths_q=torch.cumsum(query_lens, dim=0))
@@ -460,39 +460,6 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2]) attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
return attn_out, attn_lse return attn_out, attn_lse
def _npu_attention_update(self,
attn_out_lse: torch.Tensor) -> torch.Tensor:
B_total, H_total, D_plus_1 = attn_out_lse.shape
S = B_total // self.pcp_size
H = H_total // self.dcp_size
D = self.head_size
update_type = 0
assert D_plus_1 == D + 1
# [PCP, S, DCP, H, D+1]
x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
# [PCP, DCP, S, H, D+1]
x = x.permute(0, 2, 1, 3, 4).contiguous()
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size
x = x.view(-1, S, H, D_plus_1)
# Split out lse
# [N, S, H, D], [N, S, H, 1]
out_flat, lse_flat = torch.split(x, [D, 1], dim=-1)
# out: [N, S, H, D] -> [N, S*H, D]
# lse: [N, S, H, 1] -> [N, S*H]
out_flat = out_flat.flatten(1, 2)
lse_flat = lse_flat.squeeze(-1).flatten(1)
# unbind to list
# [S*H, D]
out_list = out_flat.unbind(0)
# [S*H]
lse_list = lse_flat.unbind(0)
attn_out, attn_lse = torch_npu.npu_attention_update(
lse_list, out_list, update_type)
attn_out = attn_out.view(S, H, D)
return attn_out
def _forward_decode_pcp_dcp(self, query: torch.Tensor, def _forward_decode_pcp_dcp(self, query: torch.Tensor,
attn_metadata: AscendMetadata) -> torch.Tensor: attn_metadata: AscendMetadata) -> torch.Tensor:
assert self.key_cache is not None assert self.key_cache is not None
@@ -580,33 +547,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
else: else:
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score( attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
query, k_nope, value, **common_kwargs) query, k_nope, value, **common_kwargs)
attn_out_lse = _process_attn_out_lse(
out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None, attn_out, attn_lse, attn_metadata.decode_meta.batch_seq_mask)
None].expand_as( attn_out = _npu_attention_update(self.head_size, attn_out_lse)
attn_out)
attn_out = torch.where(out_mask, 0, attn_out)
lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
None].expand_as(
attn_lse)
attn_lse = torch.where(lse_mask, -torch.inf, attn_lse)
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
if self.dcp_size > 1:
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all,
attn_out_lse,
group=self.dcp_group)
attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
if self.pcp_size > 1:
# AllGather out&lse within CP group
attn_out_lse = get_pcp_group().all_gather(
attn_out_lse.contiguous(), dim=0)
attn_out = self._npu_attention_update(attn_out_lse)
return attn_out return attn_out
def _update_out_and_lse(self, out_list: torch.Tensor, def _update_out_and_lse(self, out_list: torch.Tensor,
@@ -780,7 +723,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
all_kv = get_pcp_group().all_gather( all_kv = get_pcp_group().all_gather(
kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0) kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None assert attn_metadata.prefill is not None
assert attn_metadata.prefill.pcp_metadata is not None
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
all_kv = torch.index_select(all_kv, 0, all_kv = torch.index_select(all_kv, 0,
pcp_allgather_restore_idx) pcp_allgather_restore_idx)
key, value = all_kv.split([self.head_size, self.head_size], key, value = all_kv.split([self.head_size, self.head_size],

View File

@@ -0,0 +1,137 @@
from dataclasses import dataclass
from typing import Optional
import torch
import torch.distributed as dist
import torch_npu
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_world_size,
get_pcp_group)
@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None
@dataclass
class CPChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
# for mla DCP & PCP
padded_chunk_seq_lens_npu: torch.Tensor = None
padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
local_context_lens_allranks: Optional[list[list[int]]] = None
padded_local_cu_seq_lens: torch.Tensor = None
cu_seq_lens_lst: Optional[list[list[int]]] = None
chunk_size: Optional[int] = None
@dataclass
class AscendMetadataForPrefill:
@dataclass
class ChunkedContextMetadata:
actual_chunk_seq_lengths: torch.Tensor
actual_seq_lengths_kv: torch.Tensor
starts: torch.Tensor
chunk_seq_mask_filtered_indices: torch.Tensor
chunked_req_mask: Optional[list[bool]] = None
local_context_lens_allranks: Optional[list[list[int]]] = None
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
kv_inverse_idx_for_chunk: Optional[list[int]] = None
batch_chunk_seq_mask: Optional[list[bool]] = None
local_total_toks: Optional[int] = None
""" Prefill Specific Metadata for Ascend"""
pcp_metadata: Optional[AscendPCPMetadata] = None
chunked_context: Optional[ChunkedContextMetadata] = None
block_tables: torch.Tensor = None
actual_seq_lengths_q: torch.Tensor = None
@dataclass
class AscendMetadataForDecode:
""" Decode Specific Metadata for Ascend"""
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
batch_seq_mask: torch.Tensor = None
block_tables: torch.Tensor = None
def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor,
batch_seq_mask: torch.Tensor) -> torch.Tensor:
pcp_size = get_pcp_group().world_size
dcp_size = get_decode_context_model_parallel_world_size()
dcp_group = get_dcp_group().device_group if dcp_size > 1 else None
out_mask = batch_seq_mask[:, None, None].expand_as(attn_output)
attn_output = torch.where(out_mask, 0, attn_output)
lse_mask = batch_seq_mask[:, None, None].expand_as(softmax_lse)
softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
softmax_lse = softmax_lse.to(torch.float32)
attn_output = attn_output.to(torch.float32)
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
if dcp_size > 1:
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all,
attn_out_lse,
group=dcp_group)
attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
if pcp_size > 1:
# AllGather out&lse within CP group
attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(),
dim=0)
return attn_out_lse
def _npu_attention_update(head_size,
attn_out_lse: torch.Tensor) -> torch.Tensor:
pcp_size = get_pcp_group().world_size
dcp_size = get_decode_context_model_parallel_world_size()
# [PCP * S, DCP * H, D+1]
B_total, H_total, D_plus_1 = attn_out_lse.shape
S = B_total // pcp_size
H = H_total // dcp_size
D = head_size
assert D_plus_1 == D + 1
# [PCP, S, DCP, H, D+1]
x = attn_out_lse.view(pcp_size, S, dcp_size, H, D_plus_1)
# [PCP, DCP, S, H, D+1]
x = x.permute(0, 2, 1, 3, 4).contiguous()
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size
x = x.view(-1, S, H, D_plus_1)
# Split out lse
out_flat, lse_flat = torch.split(x, [D, 1],
dim=-1) # [N, S, H, D], [N, S, H, 1]
# out: [N, S, H, D] -> [N, S*H, D]
# lse: [N, S, H, 1] -> [N, S*H]
out_flat = out_flat.flatten(1, 2) # [N, S*H, D]
lse_flat = lse_flat.flatten(1, -1) # [N, S*H]
# unbind to list
out_list = out_flat.unbind(0) # [S*H, D]
lse_list = lse_flat.unbind(0) # [S*H]
attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
attn_out = attn_out.view(-1, H, D)
return attn_out

View File

@@ -2,7 +2,6 @@ from typing import Optional, Tuple, TypeVar
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist
import torch_npu import torch_npu
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group, from vllm.distributed import (get_dcp_group,
@@ -15,17 +14,17 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
# isort: off # isort: off
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, from vllm_ascend.attention.mla_v1 import (
AscendMLAImpl, AscendMLAMetadata, AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder, AscendMLAMetadataBuilder, AscendMLAPrefillMetadata,
AscendMLAPrefillMetadata, DecodeMLAPreprocessResult, PrefillMLAPreprocessResult,
DecodeMLAPreprocessResult, BUILD_METADATA_STEP_PREFILL)
PrefillMLAPreprocessResult)
#isort: on #isort: on
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata) from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata)
from vllm_ascend.attention.common_cp import (AscendPCPMetadata, from vllm_ascend.attention.context_parallel.common_cp import (
CPChunkedContextMetadata) AscendPCPMetadata, CPChunkedContextMetadata, _process_attn_out_lse,
_npu_attention_update)
from vllm_ascend.compilation.acl_graph import (get_draft_graph_params, from vllm_ascend.compilation.acl_graph import (get_draft_graph_params,
get_graph_params, get_graph_params,
update_graph_params_workspaces) update_graph_params_workspaces)
@@ -89,6 +88,8 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
if long_seq_metadata is None: if long_seq_metadata is None:
raise AssertionError("long_seq_metadata should not be None.") raise AssertionError("long_seq_metadata should not be None.")
# In dcp only spec decode graph padding case,
# num_actual_tokens_pcp_padded may be less than num_actual_tokens
self.num_actual_tokens = max( self.num_actual_tokens = max(
long_seq_metadata.num_actual_tokens_pcp_padded, long_seq_metadata.num_actual_tokens_pcp_padded,
common_attn_metadata.num_actual_tokens) common_attn_metadata.num_actual_tokens)
@@ -187,21 +188,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
) )
return chunked_metadata return chunked_metadata
def set_prefill_block_table( def get_block_table_size(
self, self, common_attn_metadata: AscendCommonAttentionMetadata,
common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
):
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
self.num_decodes_flatten = self.query_lens[:self.num_decodes].sum( self.num_decodes_flatten = self.query_lens[:self.num_decodes].sum(
).item() ).item()
self.block_table = common_attn_metadata.block_table_tensor[:self. if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
num_decodes_flatten # For pcp + spec decode, we flatten seq_lens and block_table
+ self. # to avoid irregular spec_attn_mask shape
num_prefills] return self.num_decodes_flatten + self.num_prefills
else:
def set_decode_block_table(self): return self.num_decodes_flatten
self.block_table = self.block_table[:self.num_decodes_flatten, ...]
def build_prefill_metadata( def build_prefill_metadata(
self, self,
@@ -637,39 +634,11 @@ class AscendMlaCPImpl(AscendMLAImpl):
lse=softmax_lse) lse=softmax_lse)
# Update out&lse # Update out&lse
attn_out_lse = self._process_attn_out_lse(attn_output, softmax_lse, attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse,
decode_meta) decode_meta.batch_seq_mask)
attn_output = self._npu_attention_update(attn_out_lse) attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse)
return self._v_up_proj(attn_output) return self._v_up_proj(attn_output)
def _npu_attention_update(self,
attn_out_lse: torch.Tensor) -> torch.Tensor:
# [PCP * S, DCP * H, D+1]
B_total, H_total, D_plus_1 = attn_out_lse.shape
S = B_total // self.pcp_size
H = H_total // self.dcp_size
D = self.kv_lora_rank
assert D_plus_1 == D + 1
# [PCP, S, DCP, H, D+1]
x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
# [PCP, DCP, S, H, D+1]
x = x.permute(0, 2, 1, 3, 4).contiguous()
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size
x = x.view(-1, S, H, D_plus_1)
# Split out lse
out_flat, lse_flat = torch.split(x, [D, 1],
dim=-1) # [N, S, H, D], [N, S, H, 1]
# out: [N, S, H, D] -> [N, S*H, D]
# lse: [N, S, H, 1] -> [N, S*H]
out_flat = out_flat.flatten(1, 2) # [N, S*H, D]
lse_flat = lse_flat.flatten(1, -1) # [N, S*H]
# unbind to list
out_list = out_flat.unbind(0) # [S*H, D]
lse_list = lse_flat.unbind(0) # [S*H]
attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
attn_out = attn_out.view(-1, H, D)
return attn_out
def _out_lse_reshape(self, attn_out: torch.Tensor, def _out_lse_reshape(self, attn_out: torch.Tensor,
attn_lse: torch.Tensor) -> torch.Tensor: attn_lse: torch.Tensor) -> torch.Tensor:
attn_out = attn_out.contiguous().view( attn_out = attn_out.contiguous().view(
@@ -678,39 +647,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2]) attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
return attn_out, attn_lse return attn_out, attn_lse
def _process_attn_out_lse(
self,
attn_output: torch.Tensor,
softmax_lse: torch.Tensor,
decode_meta: AscendMLADecodeMetadata,
) -> torch.Tensor:
out_mask = decode_meta.batch_seq_mask[:, None,
None].expand_as(attn_output)
attn_output = torch.where(out_mask, 0, attn_output)
lse_mask = decode_meta.batch_seq_mask[:, None,
None].expand_as(softmax_lse)
softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
softmax_lse = softmax_lse.to(torch.float32)
attn_output = attn_output.to(torch.float32)
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
if self.dcp_size > 1:
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all,
attn_out_lse,
group=self.dcp_group)
attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
if self.pcp_size > 1:
# AllGather out&lse within CP group
attn_out_lse = get_pcp_group().all_gather(
attn_out_lse.contiguous(), dim=0)
return attn_out_lse
def _reorg_kvcache( def _reorg_kvcache(
self, self,
kv_c_normed: torch.Tensor, kv_c_normed: torch.Tensor,

View File

@@ -19,8 +19,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend import envs from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.common_cp import (AscendPCPMetadata, from vllm_ascend.attention.context_parallel.common_cp import (
CPChunkedContextMetadata) AscendPCPMetadata, CPChunkedContextMetadata)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
enable_cp, enable_cp,
maybe_save_kv_layer_to_connector, maybe_save_kv_layer_to_connector,
@@ -46,6 +46,8 @@ if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
BUILD_METADATA_STEP_PREFILL = 0
BUILD_METADATA_STEP_DECODE = 1
class AscendMLABackend(AttentionBackend): class AscendMLABackend(AttentionBackend):
@@ -61,7 +63,8 @@ class AscendMLABackend(AttentionBackend):
@staticmethod @staticmethod
def get_builder_cls(): def get_builder_cls():
if enable_cp(): if enable_cp():
from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder from vllm_ascend.attention.context_parallel.mla_cp import \
AscendMlaCPMetadataBuilder
return AscendMlaCPMetadataBuilder return AscendMlaCPMetadataBuilder
return AscendMLAMetadataBuilder return AscendMLAMetadataBuilder
@@ -73,7 +76,8 @@ class AscendMLABackend(AttentionBackend):
@staticmethod @staticmethod
def get_impl_cls() -> Type["MLAAttentionImpl"]: def get_impl_cls() -> Type["MLAAttentionImpl"]:
if enable_cp(): if enable_cp():
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl from vllm_ascend.attention.context_parallel.mla_cp import \
AscendMlaCPImpl
return AscendMlaCPImpl return AscendMlaCPImpl
return AscendMLAImpl return AscendMLAImpl
@@ -418,7 +422,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
self.query_lens = query_seq_lens_cpu[:num_reqs] self.query_lens = query_seq_lens_cpu[:num_reqs]
self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
self.set_prefill_block_table(common_attn_metadata) self.graph_pad_size = common_attn_metadata.graph_pad_size
block_table_size = self.get_block_table_size(
common_attn_metadata, BUILD_METADATA_STEP_PREFILL)
self.block_table = common_attn_metadata.block_table_tensor[:
block_table_size]
prefill_metadata = None prefill_metadata = None
if self.num_prefills > 0: if self.num_prefills > 0:
@@ -499,23 +507,16 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
workspace=self.chunked_prefill_workspace, workspace=self.chunked_prefill_workspace,
) )
def set_prefill_block_table( def get_block_table_size(
self, self, common_attn_metadata: AscendCommonAttentionMetadata,
common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
): if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
# If graph_pad_size > -1, mean is running in fullgraph mode. # If graph_pad_size > -1, mean is running in fullgraph mode.
self.graph_pad_size = common_attn_metadata.graph_pad_size # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch:
if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch: return self.graph_pad_size
self.block_table = ( return common_attn_metadata.num_reqs
common_attn_metadata.block_table_tensor[:self.graph_pad_size]) return self.num_decodes
else:
self.block_table = (
common_attn_metadata.block_table_tensor[:common_attn_metadata.
num_reqs])
def set_decode_block_table(self):
self.block_table = self.block_table[:self.num_decodes, ...]
def build_prefill_metadata( def build_prefill_metadata(
self, self,
@@ -574,7 +575,9 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
self.seq_lens = self.seq_lens[:self.num_decodes] self.seq_lens = self.seq_lens[:self.num_decodes]
input_positions = input_positions[:self.num_decode_tokens] input_positions = input_positions[:self.num_decode_tokens]
self.set_decode_block_table() block_table_size = self.get_block_table_size(
common_attn_metadata, BUILD_METADATA_STEP_DECODE)
self.block_table = self.block_table[:block_table_size]
# NOTE: Currently, MTP-fullgraph is incompatibility pcp # NOTE: Currently, MTP-fullgraph is incompatibility pcp
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.

View File

@@ -35,52 +35,6 @@ def enable_cp():
or prefill_config.decode_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1
@dataclass
class AscendMetadataForPrefill:
@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
@dataclass
class ChunkedContextMetadata:
actual_chunk_seq_lengths: torch.Tensor
actual_seq_lengths_kv: torch.Tensor
starts: torch.Tensor
chunk_seq_mask_filtered_indices: torch.Tensor
chunked_req_mask: Optional[list[bool]] = None
local_context_lens_allranks: Optional[list[list[int]]] = None
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
kv_inverse_idx_for_chunk: Optional[list[int]] = None
batch_chunk_seq_mask: Optional[list[bool]] = None
local_total_toks: Optional[int] = None
""" Prefill Specific Metadata for Ascend"""
pcp_metadata: Optional[AscendPCPMetadata] = None
pcp_allgather_restore_idx: Optional[List[int]] = None
chunked_context: Optional[ChunkedContextMetadata] = None
block_tables: torch.Tensor = None
actual_seq_lengths_q: torch.Tensor = None
@dataclass
class AscendMetadataForDecode:
""" Decode Specific Metadata for Ascend"""
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
batch_seq_mask: torch.Tensor = None
block_tables: torch.Tensor = None
@dataclass @dataclass
# class AscendCommonLongSequenceMetadata: # class AscendCommonLongSequenceMetadata:
class AscendPrefillContextParallelMetadata: class AscendPrefillContextParallelMetadata: