[Refactor]7/N Extract common code to common_cp (#5490)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
Eliminate the code duplicated between two files (mla_cp.py and
attention_cp.py) by extracting it into common_cp.py.
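
For context, a minimal sketch of the layout this refactor produces, assuming the shared prefill metadata types are plain data containers (class and field names are taken from the diff below; the exact base classes and full field lists are assumptions):

# vllm_ascend/attention/context_parallel/common_cp.py (illustrative sketch)
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AscendPCPMetadata:
    # Prefill context-parallel bookkeeping previously duplicated in
    # mla_cp.py and attention_cp.py (fields abridged).
    pcp_allgather_restore_idx: Optional[torch.Tensor] = None
    attn_mask_seqlens: Optional[torch.Tensor] = None
    head_attn_nomask_seqlens: Optional[torch.Tensor] = None
    tail_attn_nomask_seqlens: Optional[torch.Tensor] = None


@dataclass
class AscendMetadataForPrefill:
    # Holds the shared PCP metadata rather than nesting the class inside it,
    # as the old AscendMetadataForPrefill.AscendPCPMetadata did.
    pcp_metadata: Optional[AscendPCPMetadata] = None


# Both attention backends then import the shared types:
# from vllm_ascend.attention.context_parallel.common_cp import (
#     AscendMetadataForPrefill, AscendPCPMetadata)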

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
Author: wujinyuan1
Date: 2026-01-05 17:41:12 +08:00
Committed by: GitHub
Commit: 4a3663327b (parent: 755caeb06e)

10 changed files with 252 additions and 301 deletions

@@ -5,9 +5,11 @@ import torch
 from tests.ut.attention.utils import patch_distributed_groups
 from tests.ut.base import TestBase
-from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
-from vllm_ascend.attention.attention_v1 import (AscendMetadata,
-                                                AscendMetadataForPrefill)
+from vllm_ascend.attention.attention_v1 import AscendMetadata
+from vllm_ascend.attention.context_parallel.attention_cp import \
+    AscendAttentionCPImpl
+from vllm_ascend.attention.context_parallel.common_cp import (
+    AscendMetadataForPrefill, AscendPCPMetadata)
 
 
 class TestAscendAttentionCPImpl(TestBase):
@@ -82,25 +84,22 @@ class TestAscendAttentionCPImpl(TestBase):
         self.assertEqual(output.shape[1], 4)
         self.assertEqual(output.shape[2], 128)
 
+    @patch('torch_npu.npu_attention_update')
     @patch("torch_npu.npu_fused_infer_attention_score")
-    @patch('vllm_ascend.attention.attention_cp.get_forward_context')
+    @patch(
+        'vllm_ascend.attention.context_parallel.attention_cp.get_forward_context'
+    )
     @patch_distributed_groups(dcp_size=2, pcp_size=2)
     def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,
                                     mock_get_forward_context,
-                                    mock_npu_fused_infer_attention_score):
-        query = torch.randn(2, 4, 128)
-        self.impl.key_cache = torch.randn(100, 128, 1, 128)
-        self.impl.value_cache = torch.randn(100, 128, 1, 128)
+                                    mock_npu_fused_infer_attention_score,
+                                    mock_npu_attention_update):
+        query = torch.randn(2, 4, 64)
+        self.impl.key_cache = torch.randn(100, 64, 1, 64)
+        self.impl.value_cache = torch.randn(100, 64, 1, 64)
 
-        def mock_npu_attention_update(attn_out_lse_list):
-            mock_output = torch.randn(
-                attn_out_lse_list.shape[0] // mock_pcp.world_size,
-                attn_out_lse_list.shape[1] // mock_dcp.world_size,
-                attn_out_lse_list.shape[2] - 1)
-            return mock_output
-
-        self.impl._npu_attention_update = MagicMock()
-        self.impl._npu_attention_update.side_effect = mock_npu_attention_update
+        # Mock output
+        mock_npu_attention_update.return_value = (torch.randn(2 * 4, 64), None)
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -116,12 +115,11 @@ class TestAscendAttentionCPImpl(TestBase):
         attn_metadata.decode_meta = MagicMock()
         attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
             [1, 0], dtype=torch.bool)
 
         output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)
 
         self.assertEqual(output.shape[0], 2)
         self.assertEqual(output.shape[1], 4)
-        self.assertEqual(output.shape[2], 128)
+        self.assertEqual(output.shape[2], 64)
 
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_prefill_query_all_gather(self):
@@ -249,7 +247,7 @@ class TestAscendAttentionCPImpl(TestBase):
         attn_metadata.slot_mapping = torch.randn(2)
         attn_metadata.num_actual_tokens_pcp_padded = num_tokens * self.impl.pcp_size
         attn_metadata.prefill = MagicMock()
-        attn_metadata.prefill.pcp_allgather_restore_idx = torch.tensor(
+        attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.tensor(
             [0, 3, 1, 2, 0, 0, 0, 0])
 
         key = torch.randn(num_tokens, num_heads, head_size)
@@ -336,7 +334,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
         attn_metadata.num_actual_tokens = self.q_total_tokens
         prefill_metadata = AscendMetadataForPrefill()
-        pcp_metadata = AscendMetadataForPrefill.AscendPCPMetadata()
+        pcp_metadata = AscendPCPMetadata()
         pcp_metadata.attn_mask_seqlens = self.kv_seqlens_mask_cumsum
         pcp_metadata.head_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum
         pcp_metadata.tail_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum
@@ -409,7 +407,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
     @patch('torch.ops.npu.npu_fused_infer_attention_score')
     @patch(
-        'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._update_out_and_lse'
+        'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._update_out_and_lse'
     )
     def test_attention_with_nomask_and_mask_chunk(
             self, mock_update_out_and_lse,
@@ -457,7 +455,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
     @patch('torch.ops.npu.npu_fused_infer_attention_score')
     @patch(
-        'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
+        'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
     )
     def test_attention_with_nomask_and_mask_nochunk(
             self, mock_npu_attn_out_lse_update,
@@ -505,7 +503,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
         self.assertEqual(attn_lse, None)
 
     @patch(
-        'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
+        'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
    )
     def test_update_chunk_attn_out_lse_with_current_attn_out_lse(
             self, mock_npu_attn_out_lse_update):