[Refactor]6/N Extract common code of class AscendMLAImpl (#5314)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
Eliminate duplicate code in the IMPL classes of two files
(mla_v1.py, mla_cp.py).

vLLM version: 0.13.0rc3
vLLM main:
ad32e3e19c


- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
wujinyuan1
2025-12-28 10:40:45 +08:00
committed by GitHub
parent dbe4c338f2
commit 23169021d9
3 changed files with 268 additions and 514 deletions

View File

@@ -254,7 +254,7 @@ class TestAscendMLAImpl(TestBase):
@patch('vllm_ascend.attention.mla_cp.get_dcp_group')
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess_dcp(self, magic_npu_fetch,
mock_maybe_all_gather_and_maybe_unpad,
mock_get_dcp_group):
@@ -339,7 +339,7 @@ class TestAscendMLAImpl(TestBase):
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.mla_cp.get_pcp_group')
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess_pcp(self, magic_npu_fetch,
mock_maybe_all_gather_and_maybe_unpad,
mock_get_pcp_group,
@@ -543,8 +543,8 @@ class TestAscendMLAImpl(TestBase):
self.impl._v_up_proj.return_value = torch.randn(
B, self.impl.v_head_dim)
result = self.impl._forward_decode_pcp_dcp(q_nope, q_pe, k_nope, k_pe,
BS, attn_metadata)
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
attn_metadata)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], self.impl.v_head_dim)
@@ -578,14 +578,14 @@ class TestAscendMLAImpl(TestBase):
def mock_reorg_kvcache(allgatered_kv_c_normed: torch.Tensor,
allgatered_k_pe: torch.Tensor,
padded_local_chunk_seq_lens_lst: list[int],
local_context_lens_allranks: list[list[int]],
sum_seq_len: int, max_seq_len: int,
chunk_size: int, chunk_idx: int, toks: int):
return torch.randn(sum_seq_len, allgatered_kv_c_normed.shape[1],
allgatered_kv_c_normed.shape[2]), torch.randn(
sum_seq_len, allgatered_k_pe.shape[1],
allgatered_k_pe.shape[2])
chunked_context: CPChunkedContextMetadata,
chunk_idx: int, toks: int):
return torch.randn(
chunked_context.cu_seq_lens_lst[chunk_idx][-1],
allgatered_kv_c_normed.shape[1],
allgatered_kv_c_normed.shape[2]), torch.randn(
chunked_context.cu_seq_lens_lst[chunk_idx][-1],
allgatered_k_pe.shape[1], allgatered_k_pe.shape[2])
# mock proj
self.impl.kv_b_proj.side_effect = mock_kv_b_proj
@@ -679,10 +679,6 @@ class TestAscendMLAImpl(TestBase):
iters * (1 if dcp_size * pcp_size > 1 else 0))
self.assertEqual(mock_load.call_count, iters)
self.assertEqual(mock_ring.call_count, iters)
self.assertEqual(mock_dcp.all_gather.call_count,
(1 if dcp_size > 1 else 0))
self.assertEqual(mock_pcp.all_gather.call_count,
iters * (1 if pcp_size > 1 else 0))
mock_reorg.reset_mock()
mock_load.reset_mock()
mock_ring.reset_mock()
@@ -691,7 +687,18 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
def test_reorg_kvcache_with_dcp_pcp(self):
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_reorg_kvcache_with_dcp_pcp(self, mock_dcp, mock_get_dcp_group,
mock_pcp, mock_get_pcp_group):
def mock_all_gather(ws):
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
BLOCK_SIZE = 128 # fixed
max_model_len = 4096
max_num_seqs = 25
@@ -706,6 +713,12 @@ class TestAscendMLAImpl(TestBase):
pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
if pcp_size * dcp_size == 1:
continue
self.impl.dcp_size = dcp_size
self.impl.pcp_size = pcp_size
mock_dcp.all_gather = MagicMock(
side_effect=mock_all_gather(dcp_size))
mock_pcp.all_gather = MagicMock(
side_effect=mock_all_gather(pcp_size))
chunked_prefill_workspace_size = min(
max(8 * max_model_len, 4 * max_num_seqs * BLOCK_SIZE),
128 * 1024)
@@ -723,27 +736,21 @@ class TestAscendMLAImpl(TestBase):
for i in range(len(chunked_context.seq_tot)):
allgatered_kv_c_normed = torch.randn(
chunked_context.seq_tot[i] * pcp_size * dcp_size,
self.impl.num_heads, self.impl.v_head_dim)
allgatered_k_pe = torch.randn(
chunked_context.seq_tot[i] * pcp_size * dcp_size,
self.impl.num_heads, self.impl.qk_rope_head_dim)
chunked_context.seq_tot[i], self.impl.num_heads,
self.impl.kv_lora_rank)
allgatered_k_pe = torch.randn(chunked_context.seq_tot[i],
self.impl.num_heads,
self.impl.qk_rope_head_dim)
result_kv, result_k_pe = self.impl._reorg_kvcache(
allgatered_kv_c_normed,
allgatered_k_pe,
padded_local_chunk_seq_lens_lst=chunked_context.
padded_local_chunk_seq_lens[i],
local_context_lens_allranks=chunked_context.
local_context_lens_allranks,
sum_seq_len=chunked_context.cu_seq_lens_lst[i][-1],
max_seq_len=chunked_context.max_seq_lens[i],
chunk_size=chunked_context.chunk_size,
chunked_context,
chunk_idx=i,
toks=chunked_context.seq_tot[i],
)
self.assertEqual(result_kv.shape,
(chunked_context.cu_seq_lens_lst[i][-1],
self.impl.num_heads, self.impl.v_head_dim))
self.impl.num_heads, self.impl.kv_lora_rank))
self.assertEqual(
result_k_pe.shape,
(chunked_context.cu_seq_lens_lst[i][-1],
@@ -754,6 +761,11 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(result_k_pe.shape[0],
chunked_context.cu_seq_lens_lst[i][-1])
self.assertEqual(mock_dcp.all_gather.call_count,
(1 if dcp_size > 1 else 0))
self.assertEqual(mock_pcp.all_gather.call_count,
(1 if pcp_size > 1 else 0))
def test_out_lse_reshape(self):
test_cases = [10, 1, 128, 512]
for test_case in test_cases:
@@ -1052,10 +1064,9 @@ class TestAscendMLAImpl(TestBase):
attn_metadata.prefill.pcp_metadata.pcp_prefill_mask = torch.triu(
torch.ones(10, 10, dtype=torch.float16), 1)
output = self.impl._forward_prefill_cp(q_nope, q_pe, k_nope,
k_pe, value,
kv_c_and_k_pe_cache,
attn_metadata)
output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe,
value, kv_c_and_k_pe_cache,
attn_metadata)
self.assertEqual(
output.shape,
(seq_len_q, self.impl.num_heads * self.impl.v_head_dim))