[Refactor] Cache cos/sin in MLA and remove the `model` parameter from the builder. (#5277)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
1. Cache cos/sin in MLA.
2. AttentionBuilder now inherits from the original vLLM class.
version: release/v0.13.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
@@ -102,7 +102,8 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
||||
assert builder.device == device
|
||||
assert builder.vllm_config == vllm_config
|
||||
|
||||
def test_ascend_sfa_metadata_builder_build(self):
|
||||
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
|
||||
def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
|
||||
kv_cache_spec = MagicMock()
|
||||
layer_names = ["layer1", "layer2"]
|
||||
vllm_config = MagicMock()
|
||||
@@ -133,21 +134,21 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
||||
common_attn_metadata.sin = None
|
||||
common_attn_metadata.num_input_tokens = 100
|
||||
|
||||
model = MagicMock()
|
||||
model.model.layers = [MagicMock() for _ in range(10)]
|
||||
model.model.start_layer = 0
|
||||
mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
|
||||
torch.randn(100))
|
||||
|
||||
metadata = builder.build(
|
||||
common_prefix_len=10,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
model=model,
|
||||
)
|
||||
|
||||
assert isinstance(metadata, AscendSFAMetadata)
|
||||
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
|
||||
assert metadata.slot_mapping.shape == (100, 4, 1024)
|
||||
|
||||
def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
|
||||
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
|
||||
def test_ascend_sfa_metadata_builder_build_for_graph_capture(
|
||||
self, mock_get_cos_and_sin_mla):
|
||||
kv_cache_spec = MagicMock()
|
||||
layer_names = ["layer1", "layer2"]
|
||||
vllm_config = MagicMock()
|
||||
@@ -178,14 +179,12 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
||||
common_attn_metadata.sin = None
|
||||
common_attn_metadata.num_input_tokens = 100
|
||||
|
||||
model = MagicMock()
|
||||
model.model.layers = [MagicMock() for _ in range(10)]
|
||||
model.model.start_layer = 0
|
||||
mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
|
||||
torch.randn(100))
|
||||
|
||||
attn_metadata = builder.build_for_graph_capture(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
model=model,
|
||||
)
|
||||
|
||||
assert isinstance(attn_metadata, AscendSFAMetadata)
|
||||
|
||||
Reference in New Issue
Block a user