[Refactor] Cache cos/sin in MLA & remove the `model` parameter from the builder. (#5277)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

1. Cache the cos/sin values in MLA.
2. The AttentionBuilder now inherits from the original vLLM base class.



version: release/v0.13.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-12-28 10:35:07 +08:00
committed by GitHub
parent 24328aaf00
commit dbe4c338f2
10 changed files with 167 additions and 224 deletions

View File

@@ -1,5 +1,5 @@
import sys
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import torch
from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -102,7 +102,8 @@ class TestAscendSFAMetadataBuilder(TestBase):
assert builder.device == device
assert builder.vllm_config == vllm_config
def test_ascend_sfa_metadata_builder_build(self):
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
@@ -133,21 +134,21 @@ class TestAscendSFAMetadataBuilder(TestBase):
common_attn_metadata.sin = None
common_attn_metadata.num_input_tokens = 100
model = MagicMock()
model.model.layers = [MagicMock() for _ in range(10)]
model.model.start_layer = 0
mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
torch.randn(100))
metadata = builder.build(
common_prefix_len=10,
common_attn_metadata=common_attn_metadata,
model=model,
)
assert isinstance(metadata, AscendSFAMetadata)
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
assert metadata.slot_mapping.shape == (100, 4, 1024)
def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
def test_ascend_sfa_metadata_builder_build_for_graph_capture(
self, mock_get_cos_and_sin_mla):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
@@ -178,14 +179,12 @@ class TestAscendSFAMetadataBuilder(TestBase):
common_attn_metadata.sin = None
common_attn_metadata.num_input_tokens = 100
model = MagicMock()
model.model.layers = [MagicMock() for _ in range(10)]
model.model.start_layer = 0
mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
torch.randn(100))
attn_metadata = builder.build_for_graph_capture(
common_attn_metadata=common_attn_metadata,
attn_state=AscendAttentionState.DecodeOnly,
model=model,
)
assert isinstance(attn_metadata, AscendSFAMetadata)