[Refactor] 7/N Extract common code to common_cp (#5490)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
Eliminate the code duplicated between two files (mla_cp.py and attention_cp.py)
by extracting it into common_cp.py.
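
At a high level (a hedged sketch of the pattern, not the literal contents of common_cp.py), the duplicated helpers stop being private methods on each CP attention implementation and become module-level functions whose former instance state (batch_seq_mask, kv_lora_rank) is passed explicitly; the call sites in the test diff below reflect this:

```python
# Hypothetical illustration of the extraction; names mirror the call sites in
# the updated tests, bodies are placeholders rather than the real logic.

# Before: each impl (mla_cp.py, attention_cp.py) kept its own copy.
class AscendMlaCPImpl:
    def _process_attn_out_lse(self, attn_output, softmax_lse, decode_metadata):
        ...

    def _npu_attention_update(self, attn_out_lse):
        ...

# After: a single shared copy in attention/context_parallel/common_cp.py,
# taking the pieces of state it needs as explicit arguments.
def _process_attn_out_lse(attn_output, softmax_lse, batch_seq_mask):
    ...

def _npu_attention_update(kv_lora_rank, attn_out_lse):
    ...
```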

- vLLM version: v0.13.0
- vLLM main: 5326c89803

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
Author: wujinyuan1
Date: 2026-01-05 17:41:12 +08:00 (committed by GitHub)
Parent: 755caeb06e
Commit: 4a3663327b
10 changed files with 252 additions and 301 deletions

@@ -7,8 +7,9 @@ from tests.ut.attention.utils import patch_distributed_groups
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.common_cp import CPChunkedContextMetadata
-from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
+from vllm_ascend.attention.context_parallel.common_cp import (
+    CPChunkedContextMetadata, _npu_attention_update, _process_attn_out_lse)
+from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
 from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata
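
The shared helpers and the CPChunkedContextMetadata class now live under a new context_parallel package, so the tests (and presumably mla_cp.py/attention_cp.py themselves) switch to a single canonical import:

```python
# New import path introduced by this commit (taken from the hunk above).
from vllm_ascend.attention.context_parallel.common_cp import (
    CPChunkedContextMetadata, _npu_attention_update, _process_attn_out_lse)
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
```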
@@ -441,14 +442,14 @@ class TestAscendMLAImpl(TestBase):
         decode_metadata.batch_seq_mask = torch.tensor([True, False],
                                                       dtype=torch.bool)
-        result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
-                                                 decode_metadata)
+        result = _process_attn_out_lse(attn_output, softmax_lse,
+                                       decode_metadata.batch_seq_mask)
         self.assertEqual(result.shape[0], B * self.impl.pcp_size)
         self.assertEqual(result.shape[1], N)
         self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)

-    @patch('vllm_ascend.attention.mla_cp.get_forward_context')
+    @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
     @patch("torch_npu.atb.npu_multi_head_latent_attention")
     @patch('torch_npu.npu_attention_update')
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
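
The call site changes from a bound method that received the whole decode_metadata to a free function that takes only the batch_seq_mask it actually uses. Below is a minimal pure-PyTorch sketch of such a helper, assuming the shapes the test asserts ([PCP * B, N, D + 1] result, one mask entry per batch); the masking detail is an assumption, and the real common_cp implementation may differ:

```python
import torch


def process_attn_out_lse_ref(attn_output: torch.Tensor,
                             softmax_lse: torch.Tensor,
                             batch_seq_mask: torch.Tensor) -> torch.Tensor:
    """Concatenate partial attention output with its log-sum-exp (hypothetical).

    Assumed shapes: attn_output [PCP * B, N, D], softmax_lse [PCP * B, N, 1],
    batch_seq_mask [B] (bool). Returns [PCP * B, N, D + 1], with batches whose
    mask entry is False zeroed out.
    """
    out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
    pcp_size = out_lse.shape[0] // batch_seq_mask.shape[0]
    mask = batch_seq_mask.repeat(pcp_size).view(-1, 1, 1)
    return out_lse * mask
```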
@@ -725,7 +726,15 @@
         assert torch.allclose(lse, expected_lse)

     @patch('torch_npu.npu_attention_update')
-    def test_npu_attention_update_with_dcp_pcp(self,
+    @patch('vllm_ascend.attention.context_parallel.common_cp.get_pcp_group')
+    @patch('vllm.distributed.parallel_state._PCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch('vllm_ascend.attention.context_parallel.common_cp.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    def test_npu_attention_update_with_dcp_pcp(self, mock_dcp,
+                                               mock_get_dcp_group, mock_pcp,
+                                               mock_get_pcp_group,
                                                mock_npu_attention_update):
         NUM_TOKENS = 10  # fixed
         test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (2, 3)]
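
Because _npu_attention_update is now a module-level function, it can no longer read pcp_size/dcp_size off the impl instance; it presumably resolves the parallel groups itself, which is why the test patches get_pcp_group/get_dcp_group in common_cp plus the underlying _PCP/_DCP GroupCoordinators. One thing worth remembering when reading the new signature: stacked unittest.mock.patch decorators inject their mocks bottom-up, which is why mock_dcp (from the bottom-most patch) comes first. A standalone illustration:

```python
# Not project code: demonstrates the bottom-up ordering of patch decorators
# that explains the argument order in test_npu_attention_update_with_dcp_pcp.
from unittest import mock


class Demo:
    @mock.patch('os.getcwd')      # third mock argument
    @mock.patch('os.getpid')      # second mock argument
    @mock.patch('os.cpu_count')   # first mock argument (bottom-most patch)
    def run(self, mock_cpu_count, mock_getpid, mock_getcwd):
        return mock_cpu_count, mock_getpid, mock_getcwd
```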
@@ -752,10 +761,19 @@
                 attn_lse_split_cp[0])

         mock_npu_attention_update.side_effect = mock_npu_attention_update_effect
+        mock_pcp_group = MagicMock()
+        mock_pcp_group.world_size = self.impl.pcp_size
+        mock_get_pcp_group.return_value = mock_pcp_group
+        mock_dcp.world_size = self.impl.dcp_size
+        mock_dcp_group = MagicMock()
+        # mock_dcp_group.world_size = self.impl.dcp_size
+        mock_get_dcp_group.return_value = mock_dcp_group
         attn_out_lse = torch.randn(self.impl.pcp_size * NUM_TOKENS,
                                    self.impl.dcp_size * num_heads,
                                    head_dim)
-        out = self.impl._npu_attention_update(attn_out_lse)
+        out = _npu_attention_update(self.impl.kv_lora_rank, attn_out_lse)
+        self.impl.dcp_size = 1
+        self.impl.pcp_size = 1
         assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank)
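
The helper call itself changes from self.impl._npu_attention_update(attn_out_lse) to _npu_attention_update(self.impl.kv_lora_rank, attn_out_lse), with kv_lora_rank now an explicit argument. Conceptually this step merges per-rank partial attention outputs using their log-sum-exp values; the production code appears to delegate to torch_npu.npu_attention_update (it is patched in these tests), so the following pure-torch version is only a shape/semantics sketch under the layout assumed by the test's torch.randn input:

```python
import torch


def merge_attn_out_lse_ref(kv_lora_rank: int, attn_out_lse: torch.Tensor,
                           pcp_size: int, dcp_size: int) -> torch.Tensor:
    """Hypothetical LSE-weighted merge of partial attention results.

    Assumed input layout (from the test): [pcp_size * num_tokens,
    dcp_size * num_heads, kv_lora_rank + 1], where the last channel holds the
    log-sum-exp. Output: [num_tokens, num_heads, kv_lora_rank].
    """
    total_tokens, total_heads, _ = attn_out_lse.shape
    num_tokens = total_tokens // pcp_size
    num_heads = total_heads // dcp_size
    # Recover explicit pcp/dcp rank dimensions, then fold them into one axis.
    x = attn_out_lse.reshape(pcp_size, num_tokens, dcp_size, num_heads,
                             kv_lora_rank + 1)
    x = x.permute(0, 2, 1, 3, 4).reshape(pcp_size * dcp_size, num_tokens,
                                         num_heads, kv_lora_rank + 1)
    out, lse = x[..., :-1], x[..., -1:]
    # Standard log-sum-exp merge: softmax over the rank axis yields
    # exp(lse_i - logsumexp(lse)) weights for each partial output.
    weights = torch.softmax(lse, dim=0)
    return (weights * out).sum(dim=0)
```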
@@ -873,8 +891,8 @@
         decode_meta = MagicMock()
         decode_meta.batch_seq_mask = batch_seq_mask
-        result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
-                                                 decode_meta)
+        result = _process_attn_out_lse(attn_output, softmax_lse,
+                                       batch_seq_mask)

         # [PCP * S, DCP * H, D + 1]
         self.assertIsInstance(result, torch.Tensor)
         assert result.shape == (B * self.impl.pcp_size, H, D + 1)
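
For completeness, the two hedged sketches above compose into the end-to-end shape check these tests exercise (illustrative only, using the hypothetical helpers defined earlier in this write-up):

```python
import torch

B, N, D, PCP, DCP = 2, 4, 8, 2, 2
attn_output = torch.randn(PCP * B, DCP * N, D)
softmax_lse = torch.randn(PCP * B, DCP * N, 1)
batch_seq_mask = torch.tensor([True, False])

out_lse = process_attn_out_lse_ref(attn_output, softmax_lse, batch_seq_mask)
assert out_lse.shape == (PCP * B, DCP * N, D + 1)  # [PCP * S, DCP * H, D + 1]

merged = merge_attn_out_lse_ref(D, out_lse, PCP, DCP)
assert merged.shape == (B, N, D)
```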