[Refactor] 7/N Extract common code to common_cp (#5490)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

Reason: Eliminate code duplicated across two files (mla_cp.py and attention_cp.py) by moving it into common_cp.py.

- vLLM version: 0.13.0rc3 / vLLM main: ad32e3e19c
- vLLM version: release/v0.13.0 / vLLM main: 5fbfa8d9ef
- vLLM version: v0.13.0 / vLLM main: 5326c89803

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
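For context, a minimal sketch of the extraction pattern this change applies. The names come from the diff below, but the module layout, the dataclass shape, and the helper signatures are assumptions inferred from the updated tests; the helper bodies are elided and live in the real common_cp.py.

# Hypothetical sketch only -- not the actual common_cp.py implementation.
# --- vllm_ascend/attention/context_parallel/common_cp.py ---
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AscendPCPMetadata:
    # Prefill context-parallel metadata, previously nested inside
    # AscendMetadataForPrefill and duplicated per attention backend.
    attn_mask_seqlens: Optional[torch.Tensor] = None
    head_attn_nomask_seqlens: Optional[torch.Tensor] = None
    tail_attn_nomask_seqlens: Optional[torch.Tensor] = None


def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor,
                          batch_seq_mask: torch.Tensor) -> torch.Tensor:
    # Pack the attention output with its log-sum-exp and apply the per-batch mask.
    ...  # body omitted here


def _npu_attention_update(kv_lora_rank: int,
                          attn_out_lse: torch.Tensor) -> torch.Tensor:
    # Merge partial out/LSE chunks gathered across PCP/DCP ranks.
    ...  # body omitted here


# attention_cp.py and mla_cp.py then import the shared pieces instead of
# defining their own copies, e.g.:
#   from vllm_ascend.attention.context_parallel.common_cp import (
#       AscendPCPMetadata, _npu_attention_update, _process_attn_out_lse)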
@@ -5,9 +5,11 @@ import torch

from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
AscendMetadataForPrefill)
from vllm_ascend.attention.attention_v1 import AscendMetadata
from vllm_ascend.attention.context_parallel.attention_cp import \
AscendAttentionCPImpl
from vllm_ascend.attention.context_parallel.common_cp import (
AscendMetadataForPrefill, AscendPCPMetadata)


class TestAscendAttentionCPImpl(TestBase):
@@ -82,25 +84,22 @@ class TestAscendAttentionCPImpl(TestBase):
self.assertEqual(output.shape[1], 4)
self.assertEqual(output.shape[2], 128)

@patch('torch_npu.npu_attention_update')
@patch("torch_npu.npu_fused_infer_attention_score")
@patch('vllm_ascend.attention.attention_cp.get_forward_context')
@patch(
'vllm_ascend.attention.context_parallel.attention_cp.get_forward_context'
)
@patch_distributed_groups(dcp_size=2, pcp_size=2)
def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,
mock_get_forward_context,
mock_npu_fused_infer_attention_score):
query = torch.randn(2, 4, 128)
self.impl.key_cache = torch.randn(100, 128, 1, 128)
self.impl.value_cache = torch.randn(100, 128, 1, 128)
mock_npu_fused_infer_attention_score,
mock_npu_attention_update):
query = torch.randn(2, 4, 64)
self.impl.key_cache = torch.randn(100, 64, 1, 64)
self.impl.value_cache = torch.randn(100, 64, 1, 64)

def mock_npu_attention_update(attn_out_lse_list):
mock_output = torch.randn(
attn_out_lse_list.shape[0] // mock_pcp.world_size,
attn_out_lse_list.shape[1] // mock_dcp.world_size,
attn_out_lse_list.shape[2] - 1)
return mock_output

self.impl._npu_attention_update = MagicMock()
self.impl._npu_attention_update.side_effect = mock_npu_attention_update
# Mock output
mock_npu_attention_update.return_value = (torch.randn(2 * 4, 64), None)

mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -116,12 +115,11 @@ class TestAscendAttentionCPImpl(TestBase):
attn_metadata.decode_meta = MagicMock()
attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
[1, 0], dtype=torch.bool)

output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)

self.assertEqual(output.shape[0], 2)
self.assertEqual(output.shape[1], 4)
self.assertEqual(output.shape[2], 128)
self.assertEqual(output.shape[2], 64)

@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_prefill_query_all_gather(self):
@@ -249,7 +247,7 @@ class TestAscendAttentionCPImpl(TestBase):
attn_metadata.slot_mapping = torch.randn(2)
attn_metadata.num_actual_tokens_pcp_padded = num_tokens * self.impl.pcp_size
attn_metadata.prefill = MagicMock()
attn_metadata.prefill.pcp_allgather_restore_idx = torch.tensor(
attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.tensor(
[0, 3, 1, 2, 0, 0, 0, 0])

key = torch.randn(num_tokens, num_heads, head_size)
@@ -336,7 +334,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
attn_metadata.num_actual_tokens = self.q_total_tokens

prefill_metadata = AscendMetadataForPrefill()
pcp_metadata = AscendMetadataForPrefill.AscendPCPMetadata()
pcp_metadata = AscendPCPMetadata()
pcp_metadata.attn_mask_seqlens = self.kv_seqlens_mask_cumsum
pcp_metadata.head_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum
pcp_metadata.tail_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum
@@ -409,7 +407,7 @@ class TestUpdateNpuAttnOutLse(TestBase):

@patch('torch.ops.npu.npu_fused_infer_attention_score')
@patch(
'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._update_out_and_lse'
'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._update_out_and_lse'
)
def test_attention_with_nomask_and_mask_chunk(
self, mock_update_out_and_lse,
@@ -457,7 +455,7 @@ class TestUpdateNpuAttnOutLse(TestBase):

@patch('torch.ops.npu.npu_fused_infer_attention_score')
@patch(
'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
)
def test_attention_with_nomask_and_mask_nochunk(
self, mock_npu_attn_out_lse_update,
@@ -505,7 +503,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
self.assertEqual(attn_lse, None)

@patch(
'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
)
def test_update_chunk_attn_out_lse_with_current_attn_out_lse(
self, mock_npu_attn_out_lse_update):
@@ -7,8 +7,9 @@ from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.common_cp import CPChunkedContextMetadata
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
from vllm_ascend.attention.context_parallel.common_cp import (
CPChunkedContextMetadata, _npu_attention_update, _process_attn_out_lse)
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata
@@ -441,14 +442,14 @@ class TestAscendMLAImpl(TestBase):
decode_metadata.batch_seq_mask = torch.tensor([True, False],
dtype=torch.bool)

result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
decode_metadata)
result = _process_attn_out_lse(attn_output, softmax_lse,
decode_metadata.batch_seq_mask)

self.assertEqual(result.shape[0], B * self.impl.pcp_size)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)

@patch('vllm_ascend.attention.mla_cp.get_forward_context')
@patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
@patch("torch_npu.atb.npu_multi_head_latent_attention")
@patch('torch_npu.npu_attention_update')
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
@@ -725,7 +726,15 @@ class TestAscendMLAImpl(TestBase):
assert torch.allclose(lse, expected_lse)

@patch('torch_npu.npu_attention_update')
def test_npu_attention_update_with_dcp_pcp(self,
@patch('vllm_ascend.attention.context_parallel.common_cp.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm_ascend.attention.context_parallel.common_cp.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_npu_attention_update_with_dcp_pcp(self, mock_dcp,
mock_get_dcp_group, mock_pcp,
mock_get_pcp_group,
mock_npu_attention_update):
NUM_TOKENS = 10 # fixed
test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (2, 3)]
@@ -752,10 +761,19 @@ class TestAscendMLAImpl(TestBase):
attn_lse_split_cp[0])

mock_npu_attention_update.side_effect = mock_npu_attention_update_effect

mock_pcp_group = MagicMock()
mock_pcp_group.world_size = self.impl.pcp_size
mock_get_pcp_group.return_value = mock_pcp_group

mock_dcp.world_size = self.impl.dcp_size
mock_dcp_group = MagicMock()
# mock_dcp_group.world_size = self.impl.dcp_size
mock_get_dcp_group.return_value = mock_dcp_group
attn_out_lse = torch.randn(self.impl.pcp_size * NUM_TOKENS,
self.impl.dcp_size * num_heads,
head_dim)
out = self.impl._npu_attention_update(attn_out_lse)
out = _npu_attention_update(self.impl.kv_lora_rank, attn_out_lse)
self.impl.dcp_size = 1
self.impl.pcp_size = 1
assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank)
@@ -873,8 +891,8 @@ class TestAscendMLAImpl(TestBase):
decode_meta = MagicMock()
decode_meta.batch_seq_mask = batch_seq_mask

result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
decode_meta)
result = _process_attn_out_lse(attn_output, softmax_lse,
batch_seq_mask)
# [PCP * S, DCP * H, D + 1]
self.assertIsInstance(result, torch.Tensor)
assert result.shape == (B * self.impl.pcp_size, H, D + 1)
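For intuition on what the shared _npu_attention_update step exercised in the tests above has to accomplish, here is a self-contained, pure-torch toy of merging two partial attention results via their log-sum-exp values. It only illustrates the general technique; it is not the torch_npu kernel or the actual helper in common_cp.py, and all names and shapes here are illustrative.

import torch


def merge_attn_partials(out1: torch.Tensor, lse1: torch.Tensor,
                        out2: torch.Tensor, lse2: torch.Tensor):
    # out1/out2: attention outputs computed over disjoint key/value chunks.
    # lse1/lse2: the matching log-sum-exp of the attention scores.
    lse = torch.logaddexp(lse1, lse2)   # combined normalizer in log space
    w1 = torch.exp(lse1 - lse)          # weight of the first chunk
    w2 = torch.exp(lse2 - lse)          # weight of the second chunk
    return w1 * out1 + w2 * out2, lse


# Example shapes: [tokens, heads, head_dim] outputs with [tokens, heads, 1] LSEs.
out1, out2 = torch.randn(2, 4, 64), torch.randn(2, 4, 64)
lse1, lse2 = torch.randn(2, 4, 1), torch.randn(2, 4, 1)
merged, merged_lse = merge_attn_partials(out1, lse1, out2, lse2)
assert merged.shape == (2, 4, 64) and merged_lse.shape == (2, 4, 1)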