[Refactor] Cache cos/sin in MLA & remove the model parameter from the builder. (#5277)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
1. Cache cos/sin in MLA (see the sketch below).
2. AttentionBuilder now inherits from the original vLLM class.
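
As a rough illustration of item 1, here is a minimal sketch of what caching the rotary cos/sin tables for MLA can look like. The function name mirrors the `get_cos_and_sin_mla` helper patched in the updated tests, but the signature, caching strategy, and everything else below are assumptions made for illustration only, not the actual vllm-ascend implementation:

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=4)
def get_cos_and_sin_mla(max_seq_len: int, rotary_dim: int,
                        base: float = 10000.0):
    """Precompute the rotary cos/sin tables once and reuse them on later calls."""
    # Standard RoPE frequency table; cached so repeated build() calls do not
    # recompute it for every batch. (Illustrative sketch, hypothetical signature.)
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    positions = torch.arange(max_seq_len).float()
    freqs = torch.outer(positions, inv_freq)  # [max_seq_len, rotary_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)   # [max_seq_len, rotary_dim]
    return emb.cos(), emb.sin()


# Example: slice the cached tables for the positions of the current batch.
cos, sin = get_cos_and_sin_mla(max_seq_len=1024, rotary_dim=64)
positions = torch.tensor([0, 1, 2, 3])
cos_batch, sin_batch = cos[positions], sin[positions]
```

Because the cos/sin tables no longer have to be fetched from the model at build time, the metadata builders can drop their model argument, which is what the test changes below reflect.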
version: release/v0.13.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
@@ -289,6 +289,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.enable_chunked_prefill)

+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch('vllm.distributed.parallel_state._PCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
@@ -296,7 +297,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
     @patch('vllm.distributed.parallel_state._DCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
     def test_ascend_mla_metadata_builder_build_full_graph(
-            self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
+            self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group,
+            mock_get_cos_and_sin_mla):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -330,7 +332,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
         builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                            mock_device)
         common_metadata = MagicMock()
-        model = MagicMock()
         common_metadata.graph_pad_size = 8
         common_metadata.num_reqs = 4
         common_metadata.num_actual_tokens = 5
@@ -343,7 +344,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
         block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int()
         common_metadata.block_table_tensor = block_table
         common_metadata.prefill_context_parallel_metadata = None
-        metadata = builder.build(0, common_metadata, model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([6, 6]),
+                                                 torch.Tensor([6, 6]))
+        metadata = builder.build(0, common_metadata)

         self.assertEqual(metadata.decode.actual_seq_lengths_q,
                          [1, 2, 4, 5, 6, 6, 7, 8])
@@ -526,6 +529,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
         self.kv_cache_spec.head_size = 128
         self.kv_cache_spec.num_heads = 32

+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
@@ -534,7 +538,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
     @patch("torch.npu.is_available")
     def test_build_prefix_no_cache_metadata(self, mock_npu_available,
                                             mock_zeros, mock_dcp_world_size,
-                                            mock_get_pcp_group):
+                                            mock_get_pcp_group,
+                                            mock_get_cos_and_sin_mla):
         mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x # noqa
@@ -579,9 +584,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
                                            layer_names=["layer_0", "layer_1"],
                                            vllm_config=self.mock_vllm_config,
                                            device=self.mock_device)
-
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
+        metadata = builder.build(1, common_attn_metadata)

         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -590,6 +595,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
@@ -598,7 +604,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
     @patch("torch.npu.is_available")
     def test_build_chunked_prefix_metadata(self, mock_npu_available,
                                            mock_zeros, mock_dcp_world_size,
-                                           mock_get_pcp_group):
+                                           mock_get_pcp_group,
+                                           mock_get_cos_and_sin_mla):
         mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x # noqa
@@ -644,9 +651,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
                                            layer_names=["layer_0", "layer_1"],
                                            vllm_config=self.mock_vllm_config,
                                            device=self.mock_device)
-
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
+        metadata = builder.build(1, common_attn_metadata)

         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -655,11 +662,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_decode_only_metadata(self, mock_dcp_world_size,
-                                        mock_get_pcp_group):
+                                        mock_get_pcp_group,
+                                        mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x # noqa

@@ -697,9 +706,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
                                            layer_names=["layer_0", "layer_1"],
                                            vllm_config=self.mock_vllm_config,
                                            device=self.mock_device)
-
-        mock_model = MagicMock()
-        metadata = builder.build(1, common_attn_metadata, mock_model)
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
+                                                 torch.Tensor([10, 10]))
+        metadata = builder.build(1, common_attn_metadata)

         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -708,11 +717,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
-                                                 mock_get_pcp_group):
+                                                 mock_get_pcp_group,
+                                                 mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x # noqa

@@ -750,10 +761,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
                                            layer_names=["layer_0", "layer_1"],
                                            vllm_config=self.mock_vllm_config,
                                            device=self.mock_device)
-
-        mock_model = MagicMock()
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
+                                                 torch.Tensor([10, 10]))
         metadata = builder.build_for_graph_capture(
-            common_attn_metadata, AscendAttentionState.DecodeOnly, mock_model)
+            common_attn_metadata, AscendAttentionState.DecodeOnly)

         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_actual_tokens,
@@ -762,11 +773,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

+    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm.distributed.parallel_state.get_pcp_group')
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
     def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
-                                             mock_get_pcp_group):
+                                             mock_get_pcp_group,
+                                             mock_get_cos_and_sin_mla):
         mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x # noqa
         pcp_group = MagicMock(spec=GroupCoordinator)
@@ -795,13 +808,11 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
                                            layer_names=["layer_0", "layer_1"],
                                            vllm_config=self.mock_vllm_config,
                                            device=self.mock_device)
-
-        mock_model = MagicMock()
-
+        mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
+                                                 torch.Tensor(10))
         with self.assertRaises(NotImplementedError) as ctx:
             builder.build_for_graph_capture(
-                common_attn_metadata, AscendAttentionState.PrefillNoCache,
-                mock_model)
+                common_attn_metadata, AscendAttentionState.PrefillNoCache)
         self.assertIn(
             "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state",
             str(ctx.exception))
@@ -1,5 +1,5 @@
 import sys
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch

 import torch
 from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -102,7 +102,8 @@ class TestAscendSFAMetadataBuilder(TestBase):
         assert builder.device == device
         assert builder.vllm_config == vllm_config

-    def test_ascend_sfa_metadata_builder_build(self):
+    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
+    def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
@@ -133,21 +134,21 @@ class TestAscendSFAMetadataBuilder(TestBase):
         common_attn_metadata.sin = None
         common_attn_metadata.num_input_tokens = 100

-        model = MagicMock()
-        model.model.layers = [MagicMock() for _ in range(10)]
-        model.model.start_layer = 0
+        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
+                                                 torch.randn(100))

         metadata = builder.build(
             common_prefix_len=10,
             common_attn_metadata=common_attn_metadata,
-            model=model,
         )

         assert isinstance(metadata, AscendSFAMetadata)
         assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
         assert metadata.slot_mapping.shape == (100, 4, 1024)

-    def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
+    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
+    def test_ascend_sfa_metadata_builder_build_for_graph_capture(
+            self, mock_get_cos_and_sin_mla):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
@@ -178,14 +179,12 @@ class TestAscendSFAMetadataBuilder(TestBase):
         common_attn_metadata.sin = None
         common_attn_metadata.num_input_tokens = 100

-        model = MagicMock()
-        model.model.layers = [MagicMock() for _ in range(10)]
-        model.model.start_layer = 0
+        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
+                                                 torch.randn(100))

         attn_metadata = builder.build_for_graph_capture(
             common_attn_metadata=common_attn_metadata,
             attn_state=AscendAttentionState.DecodeOnly,
-            model=model,
         )

         assert isinstance(attn_metadata, AscendSFAMetadata)