[Refactor][EAGLE] 4/N extract common methods from eagle and mtp (#5870)
### What this PR does / why we need it?
This PR aims to extract common methods from eagle_proposer and
mtp_proposer. This is a small step towards merging eagle and mtp.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
By CI.
- vLLM version: v0.13.0
- vLLM main: bde38c11df
---------
Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -165,7 +165,7 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
|
||||
self.proposer.load_model(mock_model)
|
||||
mock_get_model.assert_called_once()
|
||||
self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
|
||||
self.assertEqual(self.proposer.attn_layer_names, ["layer3"])
|
||||
self.assertIs(self.proposer.model.model.embed_tokens,
|
||||
mock_model.model.embed_tokens)
|
||||
|
||||
@@ -196,7 +196,7 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
|
||||
self.assertIsNot(self.proposer.model.model.embed_tokens,
|
||||
mock_model.model.embed_tokens)
|
||||
self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
|
||||
self.assertEqual(self.proposer.attn_layer_names, ["layer2"])
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
|
||||
@@ -239,6 +239,8 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.vllm_config.speculative_config.num_speculative_tokens = 4
|
||||
self.device = torch.device("cpu")
|
||||
self.runner = MagicMock()
|
||||
self.runner.pcp_size = 1
|
||||
self.runner.dcp_size = 1
|
||||
|
||||
self.vllm_config.cache_config.block_size = 16
|
||||
self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
|
||||
@@ -246,6 +248,7 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.vllm_config.model_config.dtype = torch.float16
|
||||
self.vllm_config.model_config.max_model_len = 2048
|
||||
self.vllm_config.model_config.uses_mrope = False
|
||||
self.vllm_config.model_config.use_mla = False
|
||||
self.vllm_config.speculative_config.speculative_token_tree = str([
|
||||
(i + 1) * (0, ) for i in range(4)
|
||||
])
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestMtpProposer:
|
||||
config.additional_config = None
|
||||
config.speculative_config = MagicMock(spec=SpeculativeConfig)
|
||||
config.speculative_config.num_speculative_tokens = 2
|
||||
config.speculative_config.method = "deepseek_mtp"
|
||||
config.speculative_config.method = "mtp"
|
||||
config.speculative_config.draft_model_config = MagicMock()
|
||||
config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
|
||||
config.speculative_config.speculative_token_tree = str([
|
||||
@@ -98,9 +98,11 @@ class TestMtpProposer:
|
||||
mock_buffer_instance = MagicMock()
|
||||
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
|
||||
runner._use_aclgraph.return_value = True
|
||||
vllm_config.scheduler_config.async_scheduling = False
|
||||
vllm_config.speculative_config.enforce_eager = False
|
||||
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
|
||||
|
||||
assert proposer.use_aclgraph is True
|
||||
assert proposer.use_cuda_graph is True
|
||||
|
||||
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
|
||||
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
|
||||
|
||||
Reference in New Issue
Block a user