[Refactor]5/N Extract common code of mla_v1.py & extract mla_cp (#5097)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629 Reason: The functions related to Cp differ significantly from those of normal MLA-Attention, but the coupling is quite severe. Steps: 1)Extract common code AscendMLAMetadataBuilder.build to 4 functions: build_prefill_metadata, build_decode_metadata,build_cp_metadata, build_chunked_metadata todo: 1)refactor function _compute_prefill_context; 2)refactor function _mla_preprocess,_mla_decode_preprocess 3)Extract public data and processing functions from the attention_cp.py and mla_cp.py files to the common_cp file. vLLM version: 0.13.0rc3 vLLM main:ad32e3e19c- vLLM version: 0.13.0rc3 - vLLM main:ad32e3e19c--------- Signed-off-by: wujinyuan1 <wjy9595@qq.com> Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com> Co-authored-by: wujinyuan1 <wjy9595@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -6,8 +6,9 @@ from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.common_cp import CPChunkedContextMetadata
|
||||
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata
|
||||
from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata
|
||||
|
||||
|
||||
def get_pcp_split_info(pcp_rank, pcp_size, seq_lens):
|
||||
@@ -127,7 +128,7 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
|
||||
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
chunked_context_metadata = CPChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
||||
starts=local_chunk_starts.to(non_blocking=True),
|
||||
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
|
||||
@@ -144,16 +145,15 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
|
||||
chunk_size=padded_local_max_context_chunk_across_ranks,
|
||||
)
|
||||
else:
|
||||
chunked_context_metadata = (
|
||||
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
||||
starts=chunk_starts.to(non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens,
|
||||
workspace=None,
|
||||
))
|
||||
chunked_context_metadata = (ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
||||
starts=chunk_starts.to(non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens,
|
||||
workspace=None,
|
||||
))
|
||||
return chunked_context_metadata
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,8 @@ from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
|
||||
AscendMLADecodeMetadata,
|
||||
AscendMLAImpl, AscendMLAMetadata,
|
||||
AscendMLAMetadataBuilder,
|
||||
AscendMLAPrefillMetadata)
|
||||
AscendMLAPrefillMetadata,
|
||||
ChunkedContextMetadata)
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
|
||||
|
||||
@@ -76,27 +77,15 @@ class TestAscendMLAPrefillMetadata(TestBase):
|
||||
max_seq_lens = [2, 2]
|
||||
workspace = torch.randn(2, 4)
|
||||
chunk_seq_lens = torch.tensor([2, 2])
|
||||
padded_chunk_seq_lens_npu = torch.tensor([2, 2])
|
||||
padded_local_chunk_seq_lens = [[2], [2]]
|
||||
local_context_lens_allranks = [[1, 1], [1, 1]]
|
||||
padded_local_cu_seq_lens = torch.tensor([0, 2, 4])
|
||||
cu_seq_lens_lst = [[0, 2], [2, 4]]
|
||||
chunk_size = 2
|
||||
|
||||
chunked_context = AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
chunked_context = ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
starts=starts,
|
||||
seq_tot=seq_tot,
|
||||
max_seq_lens=max_seq_lens,
|
||||
workspace=workspace,
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens,
|
||||
padded_chunk_seq_lens_npu=padded_chunk_seq_lens_npu,
|
||||
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens,
|
||||
local_context_lens_allranks=local_context_lens_allranks,
|
||||
padded_local_cu_seq_lens=padded_local_cu_seq_lens,
|
||||
cu_seq_lens_lst=cu_seq_lens_lst,
|
||||
chunk_size=chunk_size)
|
||||
chunk_seq_lens_npu=chunk_seq_lens)
|
||||
|
||||
metadata = AscendMLAPrefillMetadata(
|
||||
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
|
||||
@@ -119,17 +108,6 @@ class TestAscendMLAPrefillMetadata(TestBase):
|
||||
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
|
||||
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
|
||||
chunk_seq_lens)
|
||||
self.assertIs(metadata.chunked_context.padded_chunk_seq_lens_npu,
|
||||
padded_chunk_seq_lens_npu)
|
||||
self.assertEqual(metadata.chunked_context.padded_local_chunk_seq_lens,
|
||||
padded_local_chunk_seq_lens)
|
||||
self.assertEqual(metadata.chunked_context.local_context_lens_allranks,
|
||||
local_context_lens_allranks)
|
||||
self.assertIs(metadata.chunked_context.padded_local_cu_seq_lens,
|
||||
padded_local_cu_seq_lens)
|
||||
self.assertEqual(metadata.chunked_context.cu_seq_lens_lst,
|
||||
cu_seq_lens_lst)
|
||||
self.assertEqual(metadata.chunked_context.chunk_size, chunk_size)
|
||||
|
||||
|
||||
class TestAscendMLADecodeMetadata(TestBase):
|
||||
@@ -218,11 +196,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
|
||||
mock_dcp, mock_get_dcp_group,
|
||||
mock_pcp, mock_get_pcp_group):
|
||||
def test_ascend_mla_metadata_builder_default(self, mock_dcp,
|
||||
mock_get_dcp_group, mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
@@ -262,8 +238,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
self.assertEqual(
|
||||
builder.chunked_prefill_enabled,
|
||||
mock_vllm_config.scheduler_config.enable_chunked_prefill)
|
||||
self.assertEqual(builder.dcp_size, mock_dcp.world_size)
|
||||
self.assertEqual(builder.pcp_size, mock_pcp.world_size)
|
||||
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
@@ -271,10 +245,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
|
||||
mock_dcp,
|
||||
def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp,
|
||||
mock_get_dcp_group,
|
||||
mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
@@ -324,11 +295,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_ascend_mla_metadata_builder_build_full_graph(
|
||||
self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
@@ -387,10 +355,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
|
||||
mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
|
||||
def test_reorder_batch(self, mock_dcp, mock_get_dcp_group, mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
ascend_config = MagicMock()
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
@@ -448,10 +414,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_get_dcp_size,
|
||||
mock_dcp,
|
||||
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_dcp,
|
||||
mock_get_dcp_group,
|
||||
mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
@@ -496,11 +459,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_get_dcp_size,
|
||||
mock_dcp, mock_get_dcp_group,
|
||||
mock_pcp,
|
||||
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_dcp,
|
||||
mock_get_dcp_group, mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
@@ -566,17 +526,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
self.kv_cache_spec.head_size = 128
|
||||
self.kv_cache_spec.num_heads = 32
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
|
||||
@patch(
|
||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
||||
)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
||||
@patch("torch.Tensor.npu", new=lambda self: self)
|
||||
@patch("torch.npu.is_available")
|
||||
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
|
||||
mock_zeros, mock_get_ascend_config,
|
||||
mock_dcp_world_size,
|
||||
mock_zeros, mock_dcp_world_size,
|
||||
mock_get_pcp_group):
|
||||
mock_npu_available.return_value = False
|
||||
mock_dcp_world_size.return_value = 1
|
||||
@@ -633,17 +590,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
|
||||
@patch(
|
||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
||||
)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
||||
@patch("torch.Tensor.npu", new=lambda self: self)
|
||||
@patch("torch.npu.is_available")
|
||||
def test_build_chunked_prefix_metadata(self, mock_npu_available,
|
||||
mock_zeros, mock_get_ascend_config,
|
||||
mock_dcp_world_size,
|
||||
mock_zeros, mock_dcp_world_size,
|
||||
mock_get_pcp_group):
|
||||
mock_npu_available.return_value = False
|
||||
mock_dcp_world_size.return_value = 1
|
||||
@@ -701,13 +655,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
|
||||
@patch(
|
||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
||||
)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def test_build_decode_only_metadata(self, mock_get_ascend_config,
|
||||
mock_dcp_world_size,
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_build_decode_only_metadata(self, mock_dcp_world_size,
|
||||
mock_get_pcp_group):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
@@ -757,13 +708,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
|
||||
@patch(
|
||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
||||
)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config,
|
||||
mock_dcp_world_size,
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
|
||||
mock_get_pcp_group):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
@@ -814,13 +762,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
|
||||
@patch(
|
||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
||||
)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def test_build_for_graph_capture_prefill(self, mock_get_ascend_config,
|
||||
mock_dcp_world_size,
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
|
||||
mock_get_pcp_group):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
@@ -868,16 +813,10 @@ class TestAscendMLAImpl(TestBase):
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch('vllm.distributed.parallel_state._TP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
|
||||
return_value=2)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
|
||||
mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp):
|
||||
def setUp(self, get_current_vllm_config, mock_tp, mock_dcp, mock_pcp):
|
||||
mock_tp.world_size = 2
|
||||
mock_tp.rank_in_group = MagicMock()
|
||||
mock_tp.device_group = MagicMock()
|
||||
|
||||
Reference in New Issue
Block a user