[Refactor][EAGLE] 4/N extract common methods from eagle and mtp (#5870)

### What this PR does / why we need it?
This PR aims to extract common methods from eagle_proposer and
mtp_proposer. This is a small step towards merging eagle and mtp.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Verified by CI.

- vLLM version: v0.13.0
- vLLM main:
bde38c11df

---------

Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
Zetong Li
2026-01-15 10:24:35 +08:00
committed by GitHub
parent c11a05c4e1
commit ea01aeaab7
4 changed files with 109 additions and 123 deletions

View File

@@ -165,7 +165,7 @@ class TestEagleProposerLoadModel(TestBase):
self.proposer.load_model(mock_model)
mock_get_model.assert_called_once()
self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
self.assertEqual(self.proposer.attn_layer_names, ["layer3"])
self.assertIs(self.proposer.model.model.embed_tokens,
mock_model.model.embed_tokens)
@@ -196,7 +196,7 @@ class TestEagleProposerLoadModel(TestBase):
self.assertIsNot(self.proposer.model.model.embed_tokens,
mock_model.model.embed_tokens)
self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
self.assertEqual(self.proposer.attn_layer_names, ["layer2"])
@patch(
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
@@ -239,6 +239,8 @@ class TestEagleProposerDummyRun(TestBase):
self.vllm_config.speculative_config.num_speculative_tokens = 4
self.device = torch.device("cpu")
self.runner = MagicMock()
self.runner.pcp_size = 1
self.runner.dcp_size = 1
self.vllm_config.cache_config.block_size = 16
self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
@@ -246,6 +248,7 @@ class TestEagleProposerDummyRun(TestBase):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
self.vllm_config.model_config.use_mla = False
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(4)
])

View File

@@ -30,7 +30,7 @@ class TestMtpProposer:
config.additional_config = None
config.speculative_config = MagicMock(spec=SpeculativeConfig)
config.speculative_config.num_speculative_tokens = 2
config.speculative_config.method = "deepseek_mtp"
config.speculative_config.method = "mtp"
config.speculative_config.draft_model_config = MagicMock()
config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
config.speculative_config.speculative_token_tree = str([
@@ -98,9 +98,11 @@ class TestMtpProposer:
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
runner._use_aclgraph.return_value = True
vllm_config.scheduler_config.async_scheduling = False
vllm_config.speculative_config.enforce_eager = False
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
assert proposer.use_aclgraph is True
assert proposer.use_cuda_graph is True
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")