[Refactor] cache cos/sin in mla & remove the `model` parameter from the builder. (#5277)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

1. Cache cos/sin in mla
2. Make AttentionBuilder inherit from the original vllm class.



version: release/v0.13.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-12-28 10:35:07 +08:00
committed by GitHub
parent 24328aaf00
commit dbe4c338f2
10 changed files with 167 additions and 224 deletions

View File

@@ -289,6 +289,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@@ -296,7 +297,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_ascend_mla_metadata_builder_build_full_graph(
self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -330,7 +332,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
mock_device)
common_metadata = MagicMock()
model = MagicMock()
common_metadata.graph_pad_size = 8
common_metadata.num_reqs = 4
common_metadata.num_actual_tokens = 5
@@ -343,7 +344,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int()
common_metadata.block_table_tensor = block_table
common_metadata.prefill_context_parallel_metadata = None
metadata = builder.build(0, common_metadata, model)
mock_get_cos_and_sin_mla.return_value = (torch.tensor([6, 6]),
torch.Tensor([6, 6]))
metadata = builder.build(0, common_metadata)
self.assertEqual(metadata.decode.actual_seq_lengths_q,
[1, 2, 4, 5, 6, 6, 7, 8])
@@ -526,6 +529,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.kv_cache_spec.head_size = 128
self.kv_cache_spec.num_heads = 32
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@@ -534,7 +538,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
@patch("torch.npu.is_available")
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
mock_zeros, mock_dcp_world_size,
mock_get_pcp_group):
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
@@ -579,9 +584,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)
mock_model = MagicMock()
metadata = builder.build(1, common_attn_metadata, mock_model)
mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
torch.Tensor(10))
metadata = builder.build(1, common_attn_metadata)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
@@ -590,6 +595,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@@ -598,7 +604,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
@patch("torch.npu.is_available")
def test_build_chunked_prefix_metadata(self, mock_npu_available,
mock_zeros, mock_dcp_world_size,
mock_get_pcp_group):
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
@@ -644,9 +651,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)
mock_model = MagicMock()
metadata = builder.build(1, common_attn_metadata, mock_model)
mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
torch.Tensor(10))
metadata = builder.build(1, common_attn_metadata)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
@@ -655,11 +662,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_decode_only_metadata(self, mock_dcp_world_size,
mock_get_pcp_group):
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
@@ -697,9 +706,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)
mock_model = MagicMock()
metadata = builder.build(1, common_attn_metadata, mock_model)
mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
torch.Tensor([10, 10]))
metadata = builder.build(1, common_attn_metadata)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
@@ -708,11 +717,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
mock_get_pcp_group):
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
@@ -750,10 +761,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)
mock_model = MagicMock()
mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
torch.Tensor([10, 10]))
metadata = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.DecodeOnly, mock_model)
common_attn_metadata, AscendAttentionState.DecodeOnly)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
@@ -762,11 +773,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
mock_get_pcp_group):
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
@@ -795,13 +808,11 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)
mock_model = MagicMock()
mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
torch.Tensor(10))
with self.assertRaises(NotImplementedError) as ctx:
builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.PrefillNoCache,
mock_model)
common_attn_metadata, AscendAttentionState.PrefillNoCache)
self.assertIn(
"Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state",
str(ctx.exception))

View File

@@ -1,5 +1,5 @@
import sys
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import torch
from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -102,7 +102,8 @@ class TestAscendSFAMetadataBuilder(TestBase):
assert builder.device == device
assert builder.vllm_config == vllm_config
def test_ascend_sfa_metadata_builder_build(self):
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
@@ -133,21 +134,21 @@ class TestAscendSFAMetadataBuilder(TestBase):
common_attn_metadata.sin = None
common_attn_metadata.num_input_tokens = 100
model = MagicMock()
model.model.layers = [MagicMock() for _ in range(10)]
model.model.start_layer = 0
mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
torch.randn(100))
metadata = builder.build(
common_prefix_len=10,
common_attn_metadata=common_attn_metadata,
model=model,
)
assert isinstance(metadata, AscendSFAMetadata)
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
assert metadata.slot_mapping.shape == (100, 4, 1024)
def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
def test_ascend_sfa_metadata_builder_build_for_graph_capture(
self, mock_get_cos_and_sin_mla):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
@@ -178,14 +179,12 @@ class TestAscendSFAMetadataBuilder(TestBase):
common_attn_metadata.sin = None
common_attn_metadata.num_input_tokens = 100
model = MagicMock()
model.model.layers = [MagicMock() for _ in range(10)]
model.model.start_layer = 0
mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
torch.randn(100))
attn_metadata = builder.build_for_graph_capture(
common_attn_metadata=common_attn_metadata,
attn_state=AscendAttentionState.DecodeOnly,
model=model,
)
assert isinstance(attn_metadata, AscendSFAMetadata)