[Refactor][EAGLE] 2/N: load model and generate token (#5437)

### What this PR does / why we need it?
1. Refactor the EAGLE and MTP `load_model` and `generate_token_ids` functions.
2. Remove redundant code from the MTP and EAGLE files.
3. Refactor the corresponding unit tests.

This is part 2/N of the effort to refactor and merge MTP and EAGLE.
Related RFC: https://github.com/vllm-project/vllm-ascend/issues/5467
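
Below is a minimal sketch, not the merged implementation, of the shared behavior this refactor converges on: the draft attention layer names are derived by diffing the layers registered before and after the draft model is built, DeepSeek V3.2 indexer layers are excluded, and `attn_layer_name` becomes a list that `_propose` expands into a per-layer attention-metadata dict. The imports (`get_layers_from_vllm_config`, `AttentionLayerBase`, `DeepseekV32IndexerCache`, `get_model`) mirror those used in the diff; the helper name and surrounding scaffolding are illustrative only.

```python
# Hedged sketch of the refactored draft-layer-name derivation; not the exact code.
from vllm.config import get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache


def load_draft_model_sketch(vllm_config):
    # Layers registered so far belong to the target model.
    target_attn = set(get_layers_from_vllm_config(vllm_config, AttentionLayerBase))
    target_indexer = set(get_layers_from_vllm_config(vllm_config, DeepseekV32IndexerCache))

    # Building the draft model registers its layers on top of the target's.
    draft_model = get_model(
        vllm_config=vllm_config,
        model_config=vllm_config.speculative_config.draft_model_config)

    all_attn = set(get_layers_from_vllm_config(vllm_config, AttentionLayerBase))
    all_indexer = set(get_layers_from_vllm_config(vllm_config, DeepseekV32IndexerCache))

    # Draft attention layers are the newly registered attention layers,
    # excluding DeepSeek V3.2 indexer layers (no dedicated backend/metadata yet).
    draft_attn = (all_attn - target_attn) - (all_indexer - target_indexer)
    assert len(draft_attn) == 1
    attn_layer_name = list(draft_attn)  # now a list, previously a single str

    # _propose later expands the list into a per-layer metadata dict, e.g.:
    #   per_layer_attn_metadata = {name: attn_metadata for name in attn_layer_name}
    return draft_model, attn_layer_name
```

With `attn_layer_name` as a list, the unit tests now assert `["layer3"]`/`["layer2"]` instead of bare strings, which is what the test changes below reflect.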

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit tests and existing tests.

- vLLM version: release/v0.13.0
- vLLM main: 81786c8774

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Author: lilinsiman
Date: 2026-01-05 14:07:54 +08:00 (committed by GitHub)
Parent: 50e7934415
Commit: 52863c4165
8 changed files with 229 additions and 609 deletions

View File

@@ -9,7 +9,6 @@ patch
ModelRunner_prepare_inputs
disaggregated_prefill
eplb_swift_balancer.md
Multi_Token_Prediction
ACL_Graph
KV_Cache_Pool_Guide
add_custom_aclnn_op

View File

@@ -12,6 +12,7 @@ structured_output
lora
eplb_swift_balancer
netloader
Multi_Token_Prediction
dynamic_batch
kv_pool
external_dp

View File

@@ -144,9 +144,17 @@ class TestEagleProposerLoadModel(TestBase):
def test_load_model_pp1(self, mock_pp_group, mock_get_model,
mock_get_layers):
mock_pp_group.return_value.world_size = 1
mock_target_layers = {"layer1": MagicMock(), "layer2": MagicMock()}
mock_draft_layers = {"layer1": MagicMock(), "layer3": MagicMock()}
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
mock_target_layer1 = MagicMock()
mock_target_layer2 = MagicMock()
mock_draft_layer1 = MagicMock()
mock_draft_layer3 = MagicMock()
mock_get_layers.side_effect = [{
"layer1": mock_target_layer1,
"layer2": mock_target_layer2
}, {}, {}, {
"layer1": mock_draft_layer1,
"layer3": mock_draft_layer3
}]
mock_model = MagicMock()
mock_model.model.embed_tokens = MagicMock()
@@ -158,7 +166,7 @@ class TestEagleProposerLoadModel(TestBase):
self.proposer.load_model(mock_model)
mock_get_model.assert_called_once()
self.assertEqual(self.proposer.attn_layer_name, "layer3")
self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
self.assertIs(self.proposer.model.model.embed_tokens,
mock_model.model.embed_tokens)
@@ -169,9 +177,14 @@ class TestEagleProposerLoadModel(TestBase):
def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
mock_get_layers):
mock_pp_group.return_value.world_size = 2
mock_target_layers = {"layer1": MagicMock()}
mock_draft_layers = {"layer2": MagicMock()}
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
mock_target_layer1 = MagicMock()
mock_draft_layer2 = MagicMock()
mock_get_layers.side_effect = [{
"layer1": mock_target_layer1
}, {}, {}, {
"layer2": mock_draft_layer2
}]
mock_model = MagicMock()
original_embed = MagicMock()
@@ -184,7 +197,7 @@ class TestEagleProposerLoadModel(TestBase):
self.assertIsNot(self.proposer.model.model.embed_tokens,
mock_model.model.embed_tokens)
self.assertEqual(self.proposer.attn_layer_name, "layer2")
self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
@patch(
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
@@ -200,9 +213,14 @@ class TestEagleProposerLoadModel(TestBase):
mock_get_model.return_value = MagicMock(model=MagicMock(
embed_tokens=original_embed))
mock_target_layers = {"layer1": MagicMock()}
mock_draft_layers = {"layer2": MagicMock()}
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
mock_target_layer1 = MagicMock()
mock_draft_layer2 = MagicMock()
mock_get_layers.side_effect = [{
"layer1": mock_target_layer1
}, {}, {}, {
"layer2": mock_draft_layer2
}]
mock_pp_group.return_value.world_size = 2
self.proposer.model = MagicMock()
@@ -307,83 +325,6 @@ class TestEagleProposerDummyRun(TestBase):
self.proposer.use_cuda_graph = last_use_cuda_graph
class TestEagleProposerGenerateTokenIds(TestBase):
def setUp(self):
self.vllm_config = MagicMock(spec=VllmConfig)
self.vllm_config.speculative_config = MagicMock()
self.vllm_config.speculative_config.method = "eagle"
self.device = torch.device("cpu")
self.runner = MagicMock()
self.runner.input_batch = MagicMock()
self.runner.input_batch.req_ids = [0, 1, 2]
self.runner.requests = {
0: MagicMock(get_token_id=lambda x: 100),
1: MagicMock(get_token_id=lambda x: 101),
2: MagicMock(get_token_id=lambda x: 102),
}
self.runner.pcp_size = 1
self.vllm_config.cache_config.block_size = 16
self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
self.vllm_config.scheduler_config.max_num_seqs = 32
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
self.vllm_config.speculative_config.num_speculative_tokens = 2
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(2)
])
self.vllm_config.additional_config = None
init_ascend_config(self.vllm_config)
self.mock_cpugpubuffer = patch(
"vllm.v1.spec_decode.eagle.CpuGpuBuffer")
self.mock_cpugpubuffer.start()
self.mock_supports_multimodal_inputs = patch(
"vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
)
self.mock_supports_multimodal_inputs.start()
self.proposer = EagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
self.proposer.attn_layer_name = "layer_0"
self.proposer._propose = MagicMock(
return_value=torch.tensor([[1, 2], [3, 4], [5, 6]]))
def tearDown(self):
self.mock_cpugpubuffer.stop()
self.mock_supports_multimodal_inputs.stop()
# TODO: This is equivalent to disable_padded_drafter_batch=True.
# We need to add some cases about disable_padded_drafter_batch=False in future.
def test_generate_token_ids(self):
valid_sampled = [[20, 30, 40]]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = [2, 1, 3]
positions = torch.tensor([0, 1, 2, 3, 4, 5])
hidden_states = torch.randn(6, 4096)
num_scheduled = 6
mock_attn_metadata = MagicMock()
mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
mock_attn_metadata.query_start_loc = torch.tensor([0, 2, 3, 6])
mock_attn_metadata.block_tables = MagicMock()
self.proposer._get_eagle_atten_dict = MagicMock(
return_value={"layer_0": mock_attn_metadata})
result = self.proposer.generate_token_ids(
sampled_token_ids=valid_sampled,
scheduler_output=scheduler_output,
positions=positions,
num_scheduled_tokens=num_scheduled,
hidden_states=hidden_states,
)
self.proposer._propose.assert_called_once()
self.assertEqual(result.numpy().tolist(), [[1, 2], [3, 4], [5, 6]])
class TestEagleProposerHelperMethods(TestBase):
# TODO: Add tests for prepare_next_token_ids in the future.

View File

@@ -6,12 +6,8 @@ import torch
from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
ModelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_config import init_ascend_config
@@ -107,53 +103,6 @@ class TestMtpProposer:
assert proposer.use_aclgraph is True
@patch("vllm.config.get_layers_from_vllm_config")
@patch("vllm_ascend.spec_decode.mtp_proposer.get_model_loader")
@patch(
"vllm_ascend.spec_decode.mtp_proposer.process_weights_after_loading")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_default_torch_dtype")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_current_vllm_config")
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_load_model(self, mock_cpu_gpu_buffer, mock_set_config,
mock_set_dtype, mock_process_weights, mock_get_loader,
mock_get_layers, vllm_config, runner):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
attn_layers_all = {
"target_attn_layer": "val0",
"draft_attn_layer": "val1",
"draft_attn_exclude_by_indexer": "val2",
}
indexer_layers_all = {
"target_indexer_0": "val3",
"draft_attn_exclude_by_indexer": "val4"
}
def get_layers_side_effect(vllm_config, cache_cls):
if cache_cls == AttentionLayerBase:
return attn_layers_all
elif cache_cls == DeepseekV32IndexerCache:
return indexer_layers_all
else:
return {}
# Setup
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
proposer._init_mtp_model = MagicMock()
mock_model = MagicMock()
proposer.model = mock_model
mock_loader = MagicMock()
mock_get_loader.return_value = mock_loader
mock_loader.get_all_weights.return_value = {
"dummy_weight": torch.tensor([1.0])
}
mock_get_layers.side_effect = get_layers_side_effect
with pytest.raises(AssertionError):
proposer.load_model(mock_model)
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
@@ -209,78 +158,6 @@ class TestMtpProposer:
# Check that model was called correct number of times
assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_generate_token_ids(self, mock_cpu_gpu_buffer):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
mock_deps = MagicMock()
mock_deps.scheduler_output = MagicMock(spec=SchedulerOutput)
mock_deps.scheduler_output.num_scheduled_tokens = 16
mock_deps.spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
mock_deps.spec_decode_metadata.num_draft_tokens = 2
mock_deps.runner = MagicMock()
mock_deps.runner.input_batch = MagicMock(num_reqs=4)
mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
mock_deps.runner.pcp_size = 2
mock_deps.runner.dcp_size = 1
mock_deps.runner.pcp_manager = MagicMock()
mock_deps.runner.pcp_manager.input_ids_pcp_full = CpuGpuBuffer(
32,
dtype=torch.int32,
pin_memory=False,
device='cpu',
)
mock_deps.runner.pcp_manager.input_ids_pcp_full.cpu = \
torch.arange(32, dtype=torch.int32)
mock_deps.runner.pcp_manager.query_start_loc_pcp_full = CpuGpuBuffer(
5,
dtype=torch.int32,
pin_memory=False,
device='cpu',
)
mock_deps.runner.pcp_manager.query_start_loc_pcp_full.cpu = \
torch.tensor([0, 8, 16, 24, 32])
mock_deps.positions = torch.arange(16, dtype=torch.int32)
mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],
[200, -1, -1],
[300, 301, 302]])
proposer = MagicMock(spec=MtpProposer)
proposer.enable_shared_expert_dp = False
proposer.runner = mock_deps.runner
proposer.decode_threshold = 1
proposer.speculative_config = MagicMock(
disable_padded_drafter_batch=False)
proposer.pcp_size = mock_deps.runner.pcp_size
proposer.dcp_size = mock_deps.runner.dcp_size
proposer.prepare_next_token_ids_padded = MagicMock(
return_value=(torch.tensor([101, 200, 302]), 3))
proposer.prepare_inputs_padded = MagicMock(
return_value=(MagicMock(), torch.tensor([0, 2, 4]),
torch.tensor([7, 15, 23])))
proposer._propose = MagicMock(
return_value=torch.tensor([400, 401, 402]))
proposer.generate_token_ids = MtpProposer.generate_token_ids.__get__(
proposer)
draft_token_ids = proposer.generate_token_ids(
sampled_token_ids=mock_deps.sampled_token_ids,
scheduler_output=mock_deps.scheduler_output,
spec_decode_metadata=mock_deps.spec_decode_metadata,
positions=mock_deps.positions,
num_scheduled_tokens=mock_deps.scheduler_output.
num_scheduled_tokens,
hidden_states=mock_deps.hidden_states,
)
proposer.prepare_next_token_ids_padded.assert_called_once()
proposer.prepare_inputs_padded.assert_called_once()
proposer._propose.assert_called_once()
assert torch.equal(draft_token_ids, proposer._propose.return_value)
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
mock_buffer_instance = MagicMock()

View File

@@ -4,7 +4,6 @@ from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
@@ -13,6 +12,7 @@ from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -109,25 +109,54 @@ class EagleProposer(VllmEagleProposer):
def load_model(self, model: nn.Module) -> None:
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase).keys())
target_indexer_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache).keys())
self.model = get_model(vllm_config=self.vllm_config,
model_config=self.vllm_config.
speculative_config.draft_model_config)
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() -
target_attn_layer_names)
self.attn_layer_name = next(iter(draft_attn_layer_names))
indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache).keys()
draft_attn_layer = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys()
draft_attn_layer_names = draft_attn_layer - target_attn_layer_names
draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = list(draft_attn_layer_names)
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
logger.info(
"The EAGLE head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
if self.method == "mtp":
if self.vllm_config.model_config.is_deepseek_mla and \
torch.equal(self.model.model.embed_tokens.weight,
model.model.embed_tokens.weight):
# If PP > 1, the MTP and the main model's embedding weights are not on the same device.
# Check whether the MTP model reuses the main model's embedding and LM head.
logger.info(
"The MTP head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
else:
logger.info(
" The MTP head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."
)
else:
logger.info(
"The EAGLE head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
else:
logger.info(
"Since PP > 1, the EAGLE head loaded its own vocab embedding" \
"Since PP > 1 or other reasons the model head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."
)
@@ -141,6 +170,13 @@ class EagleProposer(VllmEagleProposer):
else:
self.model.lm_head = model.lm_head
if self.method == "mtp" and \
self.vllm_config.model_config.is_deepseek_mla:
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
model.lm_head.weight):
layer_module.shared_head.head = model.lm_head
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
) and self.use_cuda_graph:
self.update_stream = torch.npu.Stream()
@@ -205,7 +241,7 @@ class EagleProposer(VllmEagleProposer):
attn_metadata_eagle = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.ChunkedPrefill)
attn_metadata = {}
for layer_name in [self.attn_layer_name]:
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_eagle
for i in range(self.num_speculative_tokens):
if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
@@ -235,135 +271,6 @@ class EagleProposer(VllmEagleProposer):
self.vllm_config,
)
def generate_token_ids(self,
sampled_token_ids: torch.Tensor | list[list[int]],
sampling_metadata: SamplingMetadata = None,
scheduler_output: SchedulerOutput = None,
spec_decode_metadata: SpecDecodeMetadata = None,
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
assert isinstance(sampled_token_ids, list), \
"sampled_token_ids should be a python list when" \
"padded-batch is disabled."
next_token_ids = self.prepare_next_token_ids_cpu(
sampled_token_ids, self.runner.requests,
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
else:
# When using padded-batch, the sampled_token_ids should be
# the gpu tensor of sampled tokens for each request, of shape
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
# value -1.
assert isinstance(sampled_token_ids, torch.Tensor), \
"sampled_token_ids should be a torch.Tensor when" \
"padded-batch is enabled."
next_token_ids, valid_sampled_tokens_count = \
self.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.runner.requests,
self.runner.input_batch,
self.runner.discard_request_indices.gpu,
self.runner.num_discarded_requests
)
self._copy_valid_sampled_token_count(next_token_ids,
valid_sampled_tokens_count)
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size > 1:
long_seq_metadata = self.runner.long_seq_metadata
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.runner.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
else:
long_seq_metadata = None
num_prefill_reqs = 0
num_decode_reqs = 0
if spec_decode_metadata is None:
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = \
query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids.gpu[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.method == "eagle3":
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
if self.pcp_size > 1:
common_attn_metadata.query_start_loc_cpu = \
query_start_loc_pcp_full_cpu[:num_reqs + 1]
common_attn_metadata.query_start_loc = \
query_start_loc_pcp_full[:num_reqs + 1]
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# NOTE: Currently, MTP fullgraph is incompatible with PCP
token_indices_to_sample = None
common_attn_metadata, token_indices =\
self.prepare_inputs(
common_attn_metadata,
sampled_token_ids,
spec_decode_metadata.num_draft_tokens)
else:
common_attn_metadata, token_indices, \
token_indices_to_sample =\
self.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)
if self.pcp_size > 1:
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
else:
target_token_ids = self.runner.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
if self.method == "eagle3":
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
draft_token_ids = self._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
)
return draft_token_ids
def _propose(
self,
# [num_tokens]
@@ -430,9 +337,11 @@ class EagleProposer(VllmEagleProposer):
self.runner.get_model())
# update global cos, sin
update_cos_sin(self.positions[:num_input_tokens])
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_name:
per_layer_attn_metadata[layer_name] = attn_metadata
with set_ascend_forward_context(
{self.attn_layer_name: attn_metadata},
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_actual_tokens=num_tokens,
@@ -558,7 +467,7 @@ class EagleProposer(VllmEagleProposer):
# Run the model.
with set_ascend_forward_context(
{self.attn_layer_name: attn_metadata},
per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size,
num_actual_tokens=batch_size,
@@ -696,28 +605,6 @@ class EagleProposer(VllmEagleProposer):
return next_token_ids, valid_sampled_tokens_count
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor,
valid_sampled_tokens_count: torch.Tensor) -> None:
if self.runner.valid_sampled_token_count_event is not None:
default_stream = torch.npu.current_stream()
# initialize a new stream to overlap the copy operation with
# prepare_input of draft model.
with torch.npu.stream(
self.runner.valid_sampled_token_count_copy_stream):
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
default_stream) # type: ignore
self.runner.valid_sampled_token_count_cpu[:
valid_sampled_tokens_count
.shape[0]].copy_(
valid_sampled_tokens_count,
non_blocking=True
)
self.runner.valid_sampled_token_count_event.record()
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
1)
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,

View File

@@ -1,25 +1,16 @@
import importlib
from typing import Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.config import (CUDAGraphMode, get_layers_from_vllm_config,
set_current_vllm_config)
from vllm.config import CUDAGraphMode
from vllm.distributed import get_pcp_group
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \
process_weights_after_loading
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import set_default_torch_dtype
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
@@ -54,15 +45,6 @@ _MTP_MODELS = {
}
def _load_model(architecture):
if architecture not in _MTP_MODELS:
raise ValueError("Invalid architecture for mtp.")
module_name, model_name = _MTP_MODELS[architecture]
module = importlib.import_module(module_name)
model = getattr(module, model_name)
return model
class MtpProposer(EagleProposer):
# TODO: Find out why ModelRunner does not need this explicit typing.
@@ -86,64 +68,6 @@ class MtpProposer(EagleProposer):
update_attn_params(self.update_stream, forward_context,
num_tokens, self.vllm_config)
def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase).keys())
target_indexer_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache).keys())
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
target_device = self.vllm_config.device_config.device
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self._init_mtp_model()
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() -
target_attn_layer_names)
indexer_layers = get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache)
draft_indexer_layer_names = indexer_layers.keys(
) - target_indexer_layer_names
# NOTE: Currently we don't have specific attention backend and attention metadata
# for deepseek v3.2 indexer, so we just exclude the indexer layers here.
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = list(draft_attn_layer_names)
self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
process_weights_after_loading(self.model, draft_model_config,
target_device)
if self.vllm_config.model_config.is_deepseek_mla:
# check if mtp model use main model's embedding and LMhead
main_model = model
if get_pp_group().world_size == 1:
# If pp>1, the weights of mtp and the main model's embedding are not on the same device.
if torch.equal(self.model.model.embed_tokens.weight,
main_model.model.embed_tokens.weight):
self.model.model.embed_tokens = main_model.model.embed_tokens
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
main_model.lm_head.weight):
layer_module.shared_head.head = main_model.lm_head
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
):
self.update_stream: torch.npu.Stream = torch.npu.Stream()
self.model = ACLGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
@torch.inference_mode()
def dummy_run(self,
num_tokens: int,
@@ -256,153 +180,6 @@ class MtpProposer(EagleProposer):
if with_prefill:
break
def generate_token_ids(self,
sampled_token_ids: torch.Tensor | list[list[int]],
sampling_metadata: SamplingMetadata = None,
scheduler_output: SchedulerOutput = None,
spec_decode_metadata: SpecDecodeMetadata = None,
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
if self.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
assert isinstance(sampled_token_ids, list), \
"sampled_token_ids should be a python list when" \
"padded-batch is disabled."
next_token_ids = self.prepare_next_token_ids_cpu(
sampled_token_ids, self.runner.requests,
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
else:
# When using padded-batch, the sampled_token_ids should be
# the gpu tensor of sampled tokens for each request, of shape
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
# value -1.
assert isinstance(sampled_token_ids, torch.Tensor), \
"sampled_token_ids should be a torch.Tensor when" \
"padded-batch is enabled."
next_token_ids, valid_sampled_tokens_count = \
self.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.runner.requests,
self.runner.input_batch,
self.runner.discard_request_indices.gpu,
self.runner.num_discarded_requests
)
self._copy_valid_sampled_token_count(next_token_ids,
valid_sampled_tokens_count)
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size * self.dcp_size > 1:
long_seq_metadata = self.runner.long_seq_metadata
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.runner.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
else:
long_seq_metadata = None
num_prefill_reqs = 0
num_decode_reqs = 0
if spec_decode_metadata is None:
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = \
query_start_loc_pcp_full[1:num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids.gpu[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
if self.pcp_size > 1:
common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
query_start_loc_pcp_full_cpu[:num_reqs + 1]
common_attn_metadata.query_start_loc[:num_reqs + 1] = \
query_start_loc_pcp_full[:num_reqs + 1]
if self.speculative_config.disable_padded_drafter_batch:
token_indices_to_sample = None
common_attn_metadata, token_indices =\
self._prepare_inputs(
common_attn_metadata,
sampled_token_ids,
spec_decode_metadata.num_draft_tokens)
else:
common_attn_metadata, token_indices, \
token_indices_to_sample =\
self.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)
if self.pcp_size > 1:
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
else:
target_token_ids = self.runner.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
draft_token_ids = self._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
)
return draft_token_ids
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor,
valid_sampled_tokens_count: torch.Tensor) -> None:
if self.runner.valid_sampled_token_count_event is not None:
default_stream = torch.npu.current_stream()
# initialize a new stream to overlap the copy operation with
# prepare_input of draft model.
with torch.npu.stream(
self.runner.valid_sampled_token_count_copy_stream):
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
default_stream) # type: ignore
self.runner.valid_sampled_token_count_cpu[:
valid_sampled_tokens_count
.shape[0]].copy_(
valid_sampled_tokens_count,
non_blocking=True
)
self.runner.valid_sampled_token_count_event.record()
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
1)
def _init_mtp_model(self):
architecture = self.vllm_config.model_config.architecture
target_device = self.vllm_config.device_config.device
model = _load_model(architecture)
self.model = model(vllm_config=self.vllm_config).to(target_device)
def _prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,

View File

@@ -1,6 +1,6 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2025 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -54,6 +54,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig,
@@ -113,7 +114,6 @@ from vllm_ascend.worker.pcp_utils import PCPManager
from vllm_ascend.ascend_forward_context import ( # isort: skip
MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity)
if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -1257,6 +1257,7 @@ class NPUModelRunner(GPUModelRunner):
logits_indices=logits_indices,
)
# TODO: Once the PCP features are complete, it will fully inherit the classes from the VLLM community.
def propose_draft_token_ids(
self,
valid_sampled_token_ids: torch.Tensor | list[list[int]],
@@ -1273,10 +1274,147 @@ class NPUModelRunner(GPUModelRunner):
# Speculative decoding is not enabled.
draft_token_ids = None
else:
draft_token_ids = self.drafter.generate_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
hidden_states, aux_hidden_states)
if self.speculative_config.method in ("suffix", "ngram"):
draft_token_ids = self.drafter.generate_token_ids(
valid_sampled_token_ids, sampling_metadata,
scheduler_output, spec_decode_metadata, positions,
num_scheduled_tokens, hidden_states, aux_hidden_states)
elif self.speculative_config.use_eagle():
common_attn_metadata = self.spec_decode_common_attn_metadata
sampled_token_ids = valid_sampled_token_ids
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
assert isinstance(sampled_token_ids, list), \
"sampled_token_ids should be a python list when" \
"padded-batch is disabled."
assert self.drafter is not None
next_token_ids = self.drafter.prepare_next_token_ids_cpu(
sampled_token_ids, self.requests, self.input_batch,
scheduler_output.num_scheduled_tokens)
else:
# When using padded-batch, the sampled_token_ids should be
# the gpu tensor of sampled tokens for each request, of shape
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
# value -1.
assert isinstance(sampled_token_ids, torch.Tensor), \
"sampled_token_ids should be a torch.Tensor when" \
"padded-batch is enabled."
assert self.drafter is not None
next_token_ids, valid_sampled_tokens_count = \
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_indices.gpu,
self.num_discarded_requests
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count)
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size * self.dcp_size > 1:
long_seq_metadata = self.long_seq_metadata # type: ignore
input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
else:
long_seq_metadata = None # type: ignore
num_prefill_reqs = 0
num_decode_reqs = 0
if spec_decode_metadata is None:
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = \
query_start_loc_pcp_full[1:num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids.gpu[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([
h[:num_scheduled_tokens]
for h in aux_hidden_states
],
dim=-1)
else:
target_hidden_states = hidden_states[:
num_scheduled_tokens]
else:
if self.pcp_size > 1:
assert common_attn_metadata is not None
common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
query_start_loc_pcp_full_cpu[:num_reqs + 1]
assert common_attn_metadata is not None
common_attn_metadata.query_start_loc[:num_reqs + 1] = \
query_start_loc_pcp_full[:num_reqs + 1]
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# NOTE: Currently, MTP fullgraph is incompatible with PCP
token_indices_to_sample = None
assert self.drafter is not None
common_attn_metadata, token_indices =\
self.drafter.prepare_inputs(
common_attn_metadata,
sampled_token_ids,
spec_decode_metadata.num_draft_tokens)
else:
assert self.drafter is not None
common_attn_metadata, token_indices, \
token_indices_to_sample =\
self.drafter.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)
if self.pcp_size > 1:
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
else:
target_token_ids = self.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
assert self.drafter is not None
draft_token_ids = self.drafter._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
)
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
return draft_token_ids
@staticmethod