[Refactor]6/N Extract common code of class AscendMLAImpl (#5314)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629 Reason: Eliminate duplicate code for two file(mla_v1.py mla_cp.py) of IMPL classes. vLLM version: 0.13.0rc3 vLLM main:ad32e3e19c- vLLM version: release/v0.13.0 - vLLM main:5fbfa8d9ef--------- Signed-off-by: wujinyuan1 <wjy9595@qq.com> Co-authored-by: wujinyuan1 <wjy9595@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -254,7 +254,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
@patch('vllm_ascend.attention.mla_cp.get_dcp_group')
|
||||
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
|
||||
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
|
||||
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
|
||||
def test_mla_preprocess_dcp(self, magic_npu_fetch,
|
||||
mock_maybe_all_gather_and_maybe_unpad,
|
||||
mock_get_dcp_group):
|
||||
@@ -339,7 +339,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('vllm_ascend.attention.mla_cp.get_pcp_group')
|
||||
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
|
||||
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
|
||||
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
|
||||
def test_mla_preprocess_pcp(self, magic_npu_fetch,
|
||||
mock_maybe_all_gather_and_maybe_unpad,
|
||||
mock_get_pcp_group,
|
||||
@@ -543,8 +543,8 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.impl._v_up_proj.return_value = torch.randn(
|
||||
B, self.impl.v_head_dim)
|
||||
|
||||
result = self.impl._forward_decode_pcp_dcp(q_nope, q_pe, k_nope, k_pe,
|
||||
BS, attn_metadata)
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
|
||||
attn_metadata)
|
||||
|
||||
self.assertEqual(result.shape[0], B)
|
||||
self.assertEqual(result.shape[1], self.impl.v_head_dim)
|
||||
@@ -578,14 +578,14 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
def mock_reorg_kvcache(allgatered_kv_c_normed: torch.Tensor,
|
||||
allgatered_k_pe: torch.Tensor,
|
||||
padded_local_chunk_seq_lens_lst: list[int],
|
||||
local_context_lens_allranks: list[list[int]],
|
||||
sum_seq_len: int, max_seq_len: int,
|
||||
chunk_size: int, chunk_idx: int, toks: int):
|
||||
return torch.randn(sum_seq_len, allgatered_kv_c_normed.shape[1],
|
||||
allgatered_kv_c_normed.shape[2]), torch.randn(
|
||||
sum_seq_len, allgatered_k_pe.shape[1],
|
||||
allgatered_k_pe.shape[2])
|
||||
chunked_context: CPChunkedContextMetadata,
|
||||
chunk_idx: int, toks: int):
|
||||
return torch.randn(
|
||||
chunked_context.cu_seq_lens_lst[chunk_idx][-1],
|
||||
allgatered_kv_c_normed.shape[1],
|
||||
allgatered_kv_c_normed.shape[2]), torch.randn(
|
||||
chunked_context.cu_seq_lens_lst[chunk_idx][-1],
|
||||
allgatered_k_pe.shape[1], allgatered_k_pe.shape[2])
|
||||
|
||||
# mock proj
|
||||
self.impl.kv_b_proj.side_effect = mock_kv_b_proj
|
||||
@@ -679,10 +679,6 @@ class TestAscendMLAImpl(TestBase):
|
||||
iters * (1 if dcp_size * pcp_size > 1 else 0))
|
||||
self.assertEqual(mock_load.call_count, iters)
|
||||
self.assertEqual(mock_ring.call_count, iters)
|
||||
self.assertEqual(mock_dcp.all_gather.call_count,
|
||||
(1 if dcp_size > 1 else 0))
|
||||
self.assertEqual(mock_pcp.all_gather.call_count,
|
||||
iters * (1 if pcp_size > 1 else 0))
|
||||
mock_reorg.reset_mock()
|
||||
mock_load.reset_mock()
|
||||
mock_ring.reset_mock()
|
||||
@@ -691,7 +687,18 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(out.shape, prefix_out.shape)
|
||||
self.assertEqual(lse.shape, prefix_lse.shape)
|
||||
|
||||
def test_reorg_kvcache_with_dcp_pcp(self):
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
def test_reorg_kvcache_with_dcp_pcp(self, mock_dcp, mock_get_dcp_group,
|
||||
mock_pcp, mock_get_pcp_group):
|
||||
|
||||
def mock_all_gather(ws):
|
||||
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
|
||||
|
||||
BLOCK_SIZE = 128 # fixed
|
||||
max_model_len = 4096
|
||||
max_num_seqs = 25
|
||||
@@ -706,6 +713,12 @@ class TestAscendMLAImpl(TestBase):
|
||||
pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
|
||||
if pcp_size * dcp_size == 1:
|
||||
continue
|
||||
self.impl.dcp_size = dcp_size
|
||||
self.impl.pcp_size = pcp_size
|
||||
mock_dcp.all_gather = MagicMock(
|
||||
side_effect=mock_all_gather(dcp_size))
|
||||
mock_pcp.all_gather = MagicMock(
|
||||
side_effect=mock_all_gather(pcp_size))
|
||||
chunked_prefill_workspace_size = min(
|
||||
max(8 * max_model_len, 4 * max_num_seqs * BLOCK_SIZE),
|
||||
128 * 1024)
|
||||
@@ -723,27 +736,21 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
for i in range(len(chunked_context.seq_tot)):
|
||||
allgatered_kv_c_normed = torch.randn(
|
||||
chunked_context.seq_tot[i] * pcp_size * dcp_size,
|
||||
self.impl.num_heads, self.impl.v_head_dim)
|
||||
allgatered_k_pe = torch.randn(
|
||||
chunked_context.seq_tot[i] * pcp_size * dcp_size,
|
||||
self.impl.num_heads, self.impl.qk_rope_head_dim)
|
||||
chunked_context.seq_tot[i], self.impl.num_heads,
|
||||
self.impl.kv_lora_rank)
|
||||
allgatered_k_pe = torch.randn(chunked_context.seq_tot[i],
|
||||
self.impl.num_heads,
|
||||
self.impl.qk_rope_head_dim)
|
||||
result_kv, result_k_pe = self.impl._reorg_kvcache(
|
||||
allgatered_kv_c_normed,
|
||||
allgatered_k_pe,
|
||||
padded_local_chunk_seq_lens_lst=chunked_context.
|
||||
padded_local_chunk_seq_lens[i],
|
||||
local_context_lens_allranks=chunked_context.
|
||||
local_context_lens_allranks,
|
||||
sum_seq_len=chunked_context.cu_seq_lens_lst[i][-1],
|
||||
max_seq_len=chunked_context.max_seq_lens[i],
|
||||
chunk_size=chunked_context.chunk_size,
|
||||
chunked_context,
|
||||
chunk_idx=i,
|
||||
toks=chunked_context.seq_tot[i],
|
||||
)
|
||||
self.assertEqual(result_kv.shape,
|
||||
(chunked_context.cu_seq_lens_lst[i][-1],
|
||||
self.impl.num_heads, self.impl.v_head_dim))
|
||||
self.impl.num_heads, self.impl.kv_lora_rank))
|
||||
self.assertEqual(
|
||||
result_k_pe.shape,
|
||||
(chunked_context.cu_seq_lens_lst[i][-1],
|
||||
@@ -754,6 +761,11 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(result_k_pe.shape[0],
|
||||
chunked_context.cu_seq_lens_lst[i][-1])
|
||||
|
||||
self.assertEqual(mock_dcp.all_gather.call_count,
|
||||
(1 if dcp_size > 1 else 0))
|
||||
self.assertEqual(mock_pcp.all_gather.call_count,
|
||||
(1 if pcp_size > 1 else 0))
|
||||
|
||||
def test_out_lse_reshape(self):
|
||||
test_cases = [10, 1, 128, 512]
|
||||
for test_case in test_cases:
|
||||
@@ -1052,10 +1064,9 @@ class TestAscendMLAImpl(TestBase):
|
||||
attn_metadata.prefill.pcp_metadata.pcp_prefill_mask = torch.triu(
|
||||
torch.ones(10, 10, dtype=torch.float16), 1)
|
||||
|
||||
output = self.impl._forward_prefill_cp(q_nope, q_pe, k_nope,
|
||||
k_pe, value,
|
||||
kv_c_and_k_pe_cache,
|
||||
attn_metadata)
|
||||
output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe,
|
||||
value, kv_c_and_k_pe_cache,
|
||||
attn_metadata)
|
||||
self.assertEqual(
|
||||
output.shape,
|
||||
(seq_len_q, self.impl.num_heads * self.impl.v_head_dim))
|
||||
|
||||
Reference in New Issue
Block a user