[Refactor][EAGLE] 4/N extract common methods from eagle and mtp (#5870)
### What this PR does / why we need it?
This PR aims to extract common methods from eagle_proposer and
mtp_proposer. This is a small step towards merging eagle and mtp.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
by ci
- vLLM version: v0.13.0
- vLLM main:
bde38c11df
---------
Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -165,7 +165,7 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
|
||||
self.proposer.load_model(mock_model)
|
||||
mock_get_model.assert_called_once()
|
||||
self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
|
||||
self.assertEqual(self.proposer.attn_layer_names, ["layer3"])
|
||||
self.assertIs(self.proposer.model.model.embed_tokens,
|
||||
mock_model.model.embed_tokens)
|
||||
|
||||
@@ -196,7 +196,7 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
|
||||
self.assertIsNot(self.proposer.model.model.embed_tokens,
|
||||
mock_model.model.embed_tokens)
|
||||
self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
|
||||
self.assertEqual(self.proposer.attn_layer_names, ["layer2"])
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
|
||||
@@ -239,6 +239,8 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.vllm_config.speculative_config.num_speculative_tokens = 4
|
||||
self.device = torch.device("cpu")
|
||||
self.runner = MagicMock()
|
||||
self.runner.pcp_size = 1
|
||||
self.runner.dcp_size = 1
|
||||
|
||||
self.vllm_config.cache_config.block_size = 16
|
||||
self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
|
||||
@@ -246,6 +248,7 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.vllm_config.model_config.dtype = torch.float16
|
||||
self.vllm_config.model_config.max_model_len = 2048
|
||||
self.vllm_config.model_config.uses_mrope = False
|
||||
self.vllm_config.model_config.use_mla = False
|
||||
self.vllm_config.speculative_config.speculative_token_tree = str([
|
||||
(i + 1) * (0, ) for i in range(4)
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user