[Feat][Graph] Support MTP for ACL Graph (#2932)
### What this PR does / why we need it?
This PR depends on the merge of #2707 and has adapted the aclgraph
functionality to support MTP.
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main:
2b85697031
---------
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
@@ -186,6 +186,34 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
mock_device = 'cpu'
|
||||
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
ascend_config = MagicMock()
|
||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||
return_value=ascend_config):
|
||||
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||
mock_device)
|
||||
|
||||
self.assertEqual(builder.block_size,
|
||||
mock_vllm_config.cache_config.block_size)
|
||||
self.assertEqual(
|
||||
builder.chunked_prefill_enabled,
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
|
||||
|
||||
def test_ascend_mla_metadata_builder_spec_decode(self):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
mock_vllm_config.model_config.dtype = torch.float16
|
||||
mock_vllm_config.cache_config.block_size = 16
|
||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
mock_device = 'cpu'
|
||||
|
||||
mock_spec_config = MagicMock()
|
||||
mock_spec_config.num_speculative_tokens = 3
|
||||
mock_vllm_config.speculative_config = mock_spec_config
|
||||
|
||||
ascend_config = MagicMock()
|
||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||
return_value=ascend_config):
|
||||
@@ -208,6 +236,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
mock_device = 'cpu'
|
||||
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||
return_value=ascend_config):
|
||||
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||
|
||||
Reference in New Issue
Block a user