[Refactor][EAGLE] 2/N: load model and generate token (#5437)

### What this PR does / why we need it?
1. Refactor the eagle and MTP functions: load_model and generate_token_ids
2. Remove redundant code in the MTP and eagle files
3. Refactor the corresponding unit tests

This is part 2/N of the refactoring that merges the MTP and eagle implementations.
Related RFC: https://github.com/vllm-project/vllm-ascend/issues/5467

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
Unit tests and the existing test suite

- vLLM version: release/v0.13.0
- vLLM main:
81786c8774

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2026-01-05 14:07:54 +08:00
committed by GitHub
parent 50e7934415
commit 52863c4165
8 changed files with 229 additions and 609 deletions

View File

@@ -144,9 +144,17 @@ class TestEagleProposerLoadModel(TestBase):
def test_load_model_pp1(self, mock_pp_group, mock_get_model,
mock_get_layers):
mock_pp_group.return_value.world_size = 1
mock_target_layers = {"layer1": MagicMock(), "layer2": MagicMock()}
mock_draft_layers = {"layer1": MagicMock(), "layer3": MagicMock()}
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
mock_target_layer1 = MagicMock()
mock_target_layer2 = MagicMock()
mock_draft_layer1 = MagicMock()
mock_draft_layer3 = MagicMock()
mock_get_layers.side_effect = [{
"layer1": mock_target_layer1,
"layer2": mock_target_layer2
}, {}, {}, {
"layer1": mock_draft_layer1,
"layer3": mock_draft_layer3
}]
mock_model = MagicMock()
mock_model.model.embed_tokens = MagicMock()
@@ -158,7 +166,7 @@ class TestEagleProposerLoadModel(TestBase):
self.proposer.load_model(mock_model)
mock_get_model.assert_called_once()
self.assertEqual(self.proposer.attn_layer_name, "layer3")
self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
self.assertIs(self.proposer.model.model.embed_tokens,
mock_model.model.embed_tokens)
@@ -169,9 +177,14 @@ class TestEagleProposerLoadModel(TestBase):
def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
mock_get_layers):
mock_pp_group.return_value.world_size = 2
mock_target_layers = {"layer1": MagicMock()}
mock_draft_layers = {"layer2": MagicMock()}
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
mock_target_layer1 = MagicMock()
mock_draft_layer2 = MagicMock()
mock_get_layers.side_effect = [{
"layer1": mock_target_layer1
}, {}, {}, {
"layer2": mock_draft_layer2
}]
mock_model = MagicMock()
original_embed = MagicMock()
@@ -184,7 +197,7 @@ class TestEagleProposerLoadModel(TestBase):
self.assertIsNot(self.proposer.model.model.embed_tokens,
mock_model.model.embed_tokens)
self.assertEqual(self.proposer.attn_layer_name, "layer2")
self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
@patch(
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
@@ -200,9 +213,14 @@ class TestEagleProposerLoadModel(TestBase):
mock_get_model.return_value = MagicMock(model=MagicMock(
embed_tokens=original_embed))
mock_target_layers = {"layer1": MagicMock()}
mock_draft_layers = {"layer2": MagicMock()}
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
mock_target_layer1 = MagicMock()
mock_draft_layer2 = MagicMock()
mock_get_layers.side_effect = [{
"layer1": mock_target_layer1
}, {}, {}, {
"layer2": mock_draft_layer2
}]
mock_pp_group.return_value.world_size = 2
self.proposer.model = MagicMock()
@@ -307,83 +325,6 @@ class TestEagleProposerDummyRun(TestBase):
self.proposer.use_cuda_graph = last_use_cuda_graph
class TestEagleProposerGenerateTokenIds(TestBase):
def setUp(self):
self.vllm_config = MagicMock(spec=VllmConfig)
self.vllm_config.speculative_config = MagicMock()
self.vllm_config.speculative_config.method = "eagle"
self.device = torch.device("cpu")
self.runner = MagicMock()
self.runner.input_batch = MagicMock()
self.runner.input_batch.req_ids = [0, 1, 2]
self.runner.requests = {
0: MagicMock(get_token_id=lambda x: 100),
1: MagicMock(get_token_id=lambda x: 101),
2: MagicMock(get_token_id=lambda x: 102),
}
self.runner.pcp_size = 1
self.vllm_config.cache_config.block_size = 16
self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
self.vllm_config.scheduler_config.max_num_seqs = 32
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
self.vllm_config.speculative_config.num_speculative_tokens = 2
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(2)
])
self.vllm_config.additional_config = None
init_ascend_config(self.vllm_config)
self.mock_cpugpubuffer = patch(
"vllm.v1.spec_decode.eagle.CpuGpuBuffer")
self.mock_cpugpubuffer.start()
self.mock_supports_multimodal_inputs = patch(
"vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
)
self.mock_supports_multimodal_inputs.start()
self.proposer = EagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
self.proposer.attn_layer_name = "layer_0"
self.proposer._propose = MagicMock(
return_value=torch.tensor([[1, 2], [3, 4], [5, 6]]))
def tearDown(self):
self.mock_cpugpubuffer.stop()
self.mock_supports_multimodal_inputs.stop()
# TODO: This is equivalent to disable_padded_drafter_batch=True.
# We need to add some cases about disable_padded_drafter_batch=False in future.
def test_generate_token_ids(self):
valid_sampled = [[20, 30, 40]]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = [2, 1, 3]
positions = torch.tensor([0, 1, 2, 3, 4, 5])
hidden_states = torch.randn(6, 4096)
num_scheduled = 6
mock_attn_metadata = MagicMock()
mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
mock_attn_metadata.query_start_loc = torch.tensor([0, 2, 3, 6])
mock_attn_metadata.block_tables = MagicMock()
self.proposer._get_eagle_atten_dict = MagicMock(
return_value={"layer_0": mock_attn_metadata})
result = self.proposer.generate_token_ids(
sampled_token_ids=valid_sampled,
scheduler_output=scheduler_output,
positions=positions,
num_scheduled_tokens=num_scheduled,
hidden_states=hidden_states,
)
self.proposer._propose.assert_called_once()
self.assertEqual(result.numpy().tolist(), [[1, 2], [3, 4], [5, 6]])
class TestEagleProposerHelperMethods(TestBase):
# TODO: Can add some tests about prepare_next_token_ids in future.

View File

@@ -6,12 +6,8 @@ import torch
from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
ModelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_config import init_ascend_config
@@ -107,53 +103,6 @@ class TestMtpProposer:
assert proposer.use_aclgraph is True
@patch("vllm.config.get_layers_from_vllm_config")
@patch("vllm_ascend.spec_decode.mtp_proposer.get_model_loader")
@patch(
"vllm_ascend.spec_decode.mtp_proposer.process_weights_after_loading")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_default_torch_dtype")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_current_vllm_config")
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_load_model(self, mock_cpu_gpu_buffer, mock_set_config,
mock_set_dtype, mock_process_weights, mock_get_loader,
mock_get_layers, vllm_config, runner):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
attn_layers_all = {
"target_attn_layer": "val0",
"draft_attn_layer": "val1",
"draft_attn_exclude_by_indexer": "val2",
}
indexer_layers_all = {
"target_indexer_0": "val3",
"draft_attn_exclude_by_indexer": "val4"
}
def get_layers_side_effect(vllm_config, cache_cls):
if cache_cls == AttentionLayerBase:
return attn_layers_all
elif cache_cls == DeepseekV32IndexerCache:
return indexer_layers_all
else:
return {}
# Setup
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
proposer._init_mtp_model = MagicMock()
mock_model = MagicMock()
proposer.model = mock_model
mock_loader = MagicMock()
mock_get_loader.return_value = mock_loader
mock_loader.get_all_weights.return_value = {
"dummy_weight": torch.tensor([1.0])
}
mock_get_layers.side_effect = get_layers_side_effect
with pytest.raises(AssertionError):
proposer.load_model(mock_model)
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
@@ -209,78 +158,6 @@ class TestMtpProposer:
# Check that model was called correct number of times
assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_generate_token_ids(self, mock_cpu_gpu_buffer):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
mock_deps = MagicMock()
mock_deps.scheduler_output = MagicMock(spec=SchedulerOutput)
mock_deps.scheduler_output.num_scheduled_tokens = 16
mock_deps.spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
mock_deps.spec_decode_metadata.num_draft_tokens = 2
mock_deps.runner = MagicMock()
mock_deps.runner.input_batch = MagicMock(num_reqs=4)
mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
mock_deps.runner.pcp_size = 2
mock_deps.runner.dcp_size = 1
mock_deps.runner.pcp_manager = MagicMock()
mock_deps.runner.pcp_manager.input_ids_pcp_full = CpuGpuBuffer(
32,
dtype=torch.int32,
pin_memory=False,
device='cpu',
)
mock_deps.runner.pcp_manager.input_ids_pcp_full.cpu = \
torch.arange(32, dtype=torch.int32)
mock_deps.runner.pcp_manager.query_start_loc_pcp_full = CpuGpuBuffer(
5,
dtype=torch.int32,
pin_memory=False,
device='cpu',
)
mock_deps.runner.pcp_manager.query_start_loc_pcp_full.cpu = \
torch.tensor([0, 8, 16, 24, 32])
mock_deps.positions = torch.arange(16, dtype=torch.int32)
mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],
[200, -1, -1],
[300, 301, 302]])
proposer = MagicMock(spec=MtpProposer)
proposer.enable_shared_expert_dp = False
proposer.runner = mock_deps.runner
proposer.decode_threshold = 1
proposer.speculative_config = MagicMock(
disable_padded_drafter_batch=False)
proposer.pcp_size = mock_deps.runner.pcp_size
proposer.dcp_size = mock_deps.runner.dcp_size
proposer.prepare_next_token_ids_padded = MagicMock(
return_value=(torch.tensor([101, 200, 302]), 3))
proposer.prepare_inputs_padded = MagicMock(
return_value=(MagicMock(), torch.tensor([0, 2, 4]),
torch.tensor([7, 15, 23])))
proposer._propose = MagicMock(
return_value=torch.tensor([400, 401, 402]))
proposer.generate_token_ids = MtpProposer.generate_token_ids.__get__(
proposer)
draft_token_ids = proposer.generate_token_ids(
sampled_token_ids=mock_deps.sampled_token_ids,
scheduler_output=mock_deps.scheduler_output,
spec_decode_metadata=mock_deps.spec_decode_metadata,
positions=mock_deps.positions,
num_scheduled_tokens=mock_deps.scheduler_output.
num_scheduled_tokens,
hidden_states=mock_deps.hidden_states,
)
proposer.prepare_next_token_ids_padded.assert_called_once()
proposer.prepare_inputs_padded.assert_called_once()
proposer._propose.assert_called_once()
assert torch.equal(draft_token_ids, proposer._propose.return_value)
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
mock_buffer_instance = MagicMock()