[Refactor][EAGLE] 2/N: load model and generate token (#5437)

### What this PR does / why we need it?
1. Refactor the eagle and mtp functions `load_model` and `generate_token_ids`
2. Remove redundant code in the mtp and eagle files
3. Refactor the unit tests (UT) for these files

Part 2/N of refactoring and merging mtp and eagle.
Related RFC: https://github.com/vllm-project/vllm-ascend/issues/5467

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
ut and tests

- vLLM version: release/v0.13.0
- vLLM main:
81786c8774

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2026-01-05 14:07:54 +08:00
committed by GitHub
parent 50e7934415
commit 52863c4165
8 changed files with 229 additions and 609 deletions

View File

@@ -9,7 +9,6 @@ patch
ModelRunner_prepare_inputs ModelRunner_prepare_inputs
disaggregated_prefill disaggregated_prefill
eplb_swift_balancer.md eplb_swift_balancer.md
Multi_Token_Prediction
ACL_Graph ACL_Graph
KV_Cache_Pool_Guide KV_Cache_Pool_Guide
add_custom_aclnn_op add_custom_aclnn_op

View File

@@ -12,6 +12,7 @@ structured_output
lora lora
eplb_swift_balancer eplb_swift_balancer
netloader netloader
Multi_Token_Prediction
dynamic_batch dynamic_batch
kv_pool kv_pool
external_dp external_dp

View File

@@ -144,9 +144,17 @@ class TestEagleProposerLoadModel(TestBase):
def test_load_model_pp1(self, mock_pp_group, mock_get_model, def test_load_model_pp1(self, mock_pp_group, mock_get_model,
mock_get_layers): mock_get_layers):
mock_pp_group.return_value.world_size = 1 mock_pp_group.return_value.world_size = 1
mock_target_layers = {"layer1": MagicMock(), "layer2": MagicMock()} mock_target_layer1 = MagicMock()
mock_draft_layers = {"layer1": MagicMock(), "layer3": MagicMock()} mock_target_layer2 = MagicMock()
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers] mock_draft_layer1 = MagicMock()
mock_draft_layer3 = MagicMock()
mock_get_layers.side_effect = [{
"layer1": mock_target_layer1,
"layer2": mock_target_layer2
}, {}, {}, {
"layer1": mock_draft_layer1,
"layer3": mock_draft_layer3
}]
mock_model = MagicMock() mock_model = MagicMock()
mock_model.model.embed_tokens = MagicMock() mock_model.model.embed_tokens = MagicMock()
@@ -158,7 +166,7 @@ class TestEagleProposerLoadModel(TestBase):
self.proposer.load_model(mock_model) self.proposer.load_model(mock_model)
mock_get_model.assert_called_once() mock_get_model.assert_called_once()
self.assertEqual(self.proposer.attn_layer_name, "layer3") self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
self.assertIs(self.proposer.model.model.embed_tokens, self.assertIs(self.proposer.model.model.embed_tokens,
mock_model.model.embed_tokens) mock_model.model.embed_tokens)
@@ -169,9 +177,14 @@ class TestEagleProposerLoadModel(TestBase):
def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model, def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
mock_get_layers): mock_get_layers):
mock_pp_group.return_value.world_size = 2 mock_pp_group.return_value.world_size = 2
mock_target_layers = {"layer1": MagicMock()} mock_target_layer1 = MagicMock()
mock_draft_layers = {"layer2": MagicMock()} mock_draft_layer2 = MagicMock()
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
mock_get_layers.side_effect = [{
"layer1": mock_target_layer1
}, {}, {}, {
"layer2": mock_draft_layer2
}]
mock_model = MagicMock() mock_model = MagicMock()
original_embed = MagicMock() original_embed = MagicMock()
@@ -184,7 +197,7 @@ class TestEagleProposerLoadModel(TestBase):
self.assertIsNot(self.proposer.model.model.embed_tokens, self.assertIsNot(self.proposer.model.model.embed_tokens,
mock_model.model.embed_tokens) mock_model.model.embed_tokens)
self.assertEqual(self.proposer.attn_layer_name, "layer2") self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
@patch( @patch(
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
@@ -200,9 +213,14 @@ class TestEagleProposerLoadModel(TestBase):
mock_get_model.return_value = MagicMock(model=MagicMock( mock_get_model.return_value = MagicMock(model=MagicMock(
embed_tokens=original_embed)) embed_tokens=original_embed))
mock_target_layers = {"layer1": MagicMock()} mock_target_layer1 = MagicMock()
mock_draft_layers = {"layer2": MagicMock()} mock_draft_layer2 = MagicMock()
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
mock_get_layers.side_effect = [{
"layer1": mock_target_layer1
}, {}, {}, {
"layer2": mock_draft_layer2
}]
mock_pp_group.return_value.world_size = 2 mock_pp_group.return_value.world_size = 2
self.proposer.model = MagicMock() self.proposer.model = MagicMock()
@@ -307,83 +325,6 @@ class TestEagleProposerDummyRun(TestBase):
self.proposer.use_cuda_graph = last_use_cuda_graph self.proposer.use_cuda_graph = last_use_cuda_graph
class TestEagleProposerGenerateTokenIds(TestBase):
def setUp(self):
self.vllm_config = MagicMock(spec=VllmConfig)
self.vllm_config.speculative_config = MagicMock()
self.vllm_config.speculative_config.method = "eagle"
self.device = torch.device("cpu")
self.runner = MagicMock()
self.runner.input_batch = MagicMock()
self.runner.input_batch.req_ids = [0, 1, 2]
self.runner.requests = {
0: MagicMock(get_token_id=lambda x: 100),
1: MagicMock(get_token_id=lambda x: 101),
2: MagicMock(get_token_id=lambda x: 102),
}
self.runner.pcp_size = 1
self.vllm_config.cache_config.block_size = 16
self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
self.vllm_config.scheduler_config.max_num_seqs = 32
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
self.vllm_config.speculative_config.num_speculative_tokens = 2
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(2)
])
self.vllm_config.additional_config = None
init_ascend_config(self.vllm_config)
self.mock_cpugpubuffer = patch(
"vllm.v1.spec_decode.eagle.CpuGpuBuffer")
self.mock_cpugpubuffer.start()
self.mock_supports_multimodal_inputs = patch(
"vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
)
self.mock_supports_multimodal_inputs.start()
self.proposer = EagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
self.proposer.attn_layer_name = "layer_0"
self.proposer._propose = MagicMock(
return_value=torch.tensor([[1, 2], [3, 4], [5, 6]]))
def tearDown(self):
self.mock_cpugpubuffer.stop()
self.mock_supports_multimodal_inputs.stop()
# TODO: This is equivalent to disable_padded_drafter_batch=True.
# We need to add some cases about disable_padded_drafter_batch=False in future.
def test_generate_token_ids(self):
valid_sampled = [[20, 30, 40]]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = [2, 1, 3]
positions = torch.tensor([0, 1, 2, 3, 4, 5])
hidden_states = torch.randn(6, 4096)
num_scheduled = 6
mock_attn_metadata = MagicMock()
mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
mock_attn_metadata.query_start_loc = torch.tensor([0, 2, 3, 6])
mock_attn_metadata.block_tables = MagicMock()
self.proposer._get_eagle_atten_dict = MagicMock(
return_value={"layer_0": mock_attn_metadata})
result = self.proposer.generate_token_ids(
sampled_token_ids=valid_sampled,
scheduler_output=scheduler_output,
positions=positions,
num_scheduled_tokens=num_scheduled,
hidden_states=hidden_states,
)
self.proposer._propose.assert_called_once()
self.assertEqual(result.numpy().tolist(), [[1, 2], [3, 4], [5, 6]])
class TestEagleProposerHelperMethods(TestBase): class TestEagleProposerHelperMethods(TestBase):
# TODO: Can add some tests about prepare_next_token_ids in future. # TODO: Can add some tests about prepare_next_token_ids in future.

View File

@@ -6,12 +6,8 @@ import torch
from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode, from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
ModelConfig, SchedulerConfig, SpeculativeConfig, ModelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig) VllmConfig)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.ascend_config import init_ascend_config
@@ -107,53 +103,6 @@ class TestMtpProposer:
assert proposer.use_aclgraph is True assert proposer.use_aclgraph is True
@patch("vllm.config.get_layers_from_vllm_config")
@patch("vllm_ascend.spec_decode.mtp_proposer.get_model_loader")
@patch(
"vllm_ascend.spec_decode.mtp_proposer.process_weights_after_loading")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_default_torch_dtype")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_current_vllm_config")
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_load_model(self, mock_cpu_gpu_buffer, mock_set_config,
mock_set_dtype, mock_process_weights, mock_get_loader,
mock_get_layers, vllm_config, runner):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
attn_layers_all = {
"target_attn_layer": "val0",
"draft_attn_layer": "val1",
"draft_attn_exclude_by_indexer": "val2",
}
indexer_layers_all = {
"target_indexer_0": "val3",
"draft_attn_exclude_by_indexer": "val4"
}
def get_layers_side_effect(vllm_config, cache_cls):
if cache_cls == AttentionLayerBase:
return attn_layers_all
elif cache_cls == DeepseekV32IndexerCache:
return indexer_layers_all
else:
return {}
# Setup
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
proposer._init_mtp_model = MagicMock()
mock_model = MagicMock()
proposer.model = mock_model
mock_loader = MagicMock()
mock_get_loader.return_value = mock_loader
mock_loader.get_all_weights.return_value = {
"dummy_weight": torch.tensor([1.0])
}
mock_get_layers.side_effect = get_layers_side_effect
with pytest.raises(AssertionError):
proposer.load_model(mock_model)
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context") @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context") @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
@@ -209,78 +158,6 @@ class TestMtpProposer:
# Check that model was called correct number of times # Check that model was called correct number of times
assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_generate_token_ids(self, mock_cpu_gpu_buffer):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
mock_deps = MagicMock()
mock_deps.scheduler_output = MagicMock(spec=SchedulerOutput)
mock_deps.scheduler_output.num_scheduled_tokens = 16
mock_deps.spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
mock_deps.spec_decode_metadata.num_draft_tokens = 2
mock_deps.runner = MagicMock()
mock_deps.runner.input_batch = MagicMock(num_reqs=4)
mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
mock_deps.runner.pcp_size = 2
mock_deps.runner.dcp_size = 1
mock_deps.runner.pcp_manager = MagicMock()
mock_deps.runner.pcp_manager.input_ids_pcp_full = CpuGpuBuffer(
32,
dtype=torch.int32,
pin_memory=False,
device='cpu',
)
mock_deps.runner.pcp_manager.input_ids_pcp_full.cpu = \
torch.arange(32, dtype=torch.int32)
mock_deps.runner.pcp_manager.query_start_loc_pcp_full = CpuGpuBuffer(
5,
dtype=torch.int32,
pin_memory=False,
device='cpu',
)
mock_deps.runner.pcp_manager.query_start_loc_pcp_full.cpu = \
torch.tensor([0, 8, 16, 24, 32])
mock_deps.positions = torch.arange(16, dtype=torch.int32)
mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],
[200, -1, -1],
[300, 301, 302]])
proposer = MagicMock(spec=MtpProposer)
proposer.enable_shared_expert_dp = False
proposer.runner = mock_deps.runner
proposer.decode_threshold = 1
proposer.speculative_config = MagicMock(
disable_padded_drafter_batch=False)
proposer.pcp_size = mock_deps.runner.pcp_size
proposer.dcp_size = mock_deps.runner.dcp_size
proposer.prepare_next_token_ids_padded = MagicMock(
return_value=(torch.tensor([101, 200, 302]), 3))
proposer.prepare_inputs_padded = MagicMock(
return_value=(MagicMock(), torch.tensor([0, 2, 4]),
torch.tensor([7, 15, 23])))
proposer._propose = MagicMock(
return_value=torch.tensor([400, 401, 402]))
proposer.generate_token_ids = MtpProposer.generate_token_ids.__get__(
proposer)
draft_token_ids = proposer.generate_token_ids(
sampled_token_ids=mock_deps.sampled_token_ids,
scheduler_output=mock_deps.scheduler_output,
spec_decode_metadata=mock_deps.spec_decode_metadata,
positions=mock_deps.positions,
num_scheduled_tokens=mock_deps.scheduler_output.
num_scheduled_tokens,
hidden_states=mock_deps.hidden_states,
)
proposer.prepare_next_token_ids_padded.assert_called_once()
proposer.prepare_inputs_padded.assert_called_once()
proposer._propose.assert_called_once()
assert torch.equal(draft_token_ids, proposer._propose.return_value)
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer): def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
mock_buffer_instance = MagicMock() mock_buffer_instance = MagicMock()

View File

@@ -4,7 +4,6 @@ from typing import Optional
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
@@ -13,6 +12,7 @@ from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -109,25 +109,54 @@ class EagleProposer(VllmEagleProposer):
def load_model(self, model: nn.Module) -> None: def load_model(self, model: nn.Module) -> None:
target_attn_layer_names = set( target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys()) get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase).keys())
target_indexer_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache).keys())
self.model = get_model(vllm_config=self.vllm_config, self.model = get_model(vllm_config=self.vllm_config,
model_config=self.vllm_config. model_config=self.vllm_config.
speculative_config.draft_model_config) speculative_config.draft_model_config)
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() - indexer_layers = get_layers_from_vllm_config(
target_attn_layer_names) self.vllm_config, DeepseekV32IndexerCache).keys()
self.attn_layer_name = next(iter(draft_attn_layer_names)) draft_attn_layer = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys()
draft_attn_layer_names = draft_attn_layer - target_attn_layer_names
draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = list(draft_attn_layer_names)
# share embed_tokens with the target model if needed # share embed_tokens with the target model if needed
if get_pp_group().world_size == 1: if get_pp_group().world_size == 1:
logger.info( if self.method == "mtp":
"The EAGLE head shares the same vocab embedding" \ if self.vllm_config.model_config.is_deepseek_mla and \
" with the target model." torch.equal(self.model.model.embed_tokens.weight,
) model.model.embed_tokens.weight):
self.model.model.embed_tokens = model.model.embed_tokens # If pp>1, the weights of mtp and the main model's embedding are not on the same device.
# check if mtp model use main model's embedding and LMhead
logger.info(
"The MTP head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
else:
logger.info(
" The MTP head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."
)
else:
logger.info(
"The EAGLE head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
else: else:
logger.info( logger.info(
"Since PP > 1, the EAGLE head loaded its own vocab embedding" \ "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \
" weights instead of sharing them with the target model." " weights instead of sharing them with the target model."
) )
@@ -141,6 +170,13 @@ class EagleProposer(VllmEagleProposer):
else: else:
self.model.lm_head = model.lm_head self.model.lm_head = model.lm_head
if self.method == "mtp" and \
self.vllm_config.model_config.is_deepseek_mla:
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
model.lm_head.weight):
layer_module.shared_head.head = model.lm_head
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
) and self.use_cuda_graph: ) and self.use_cuda_graph:
self.update_stream = torch.npu.Stream() self.update_stream = torch.npu.Stream()
@@ -205,7 +241,7 @@ class EagleProposer(VllmEagleProposer):
attn_metadata_eagle = builder.build_for_graph_capture( attn_metadata_eagle = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.ChunkedPrefill) common_attn_metadata, AscendAttentionState.ChunkedPrefill)
attn_metadata = {} attn_metadata = {}
for layer_name in [self.attn_layer_name]: for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_eagle attn_metadata[layer_name] = attn_metadata_eagle
for i in range(self.num_speculative_tokens): for i in range(self.num_speculative_tokens):
if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL: if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
@@ -235,135 +271,6 @@ class EagleProposer(VllmEagleProposer):
self.vllm_config, self.vllm_config,
) )
def generate_token_ids(self,
sampled_token_ids: torch.Tensor | list[list[int]],
sampling_metadata: SamplingMetadata = None,
scheduler_output: SchedulerOutput = None,
spec_decode_metadata: SpecDecodeMetadata = None,
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
assert isinstance(sampled_token_ids, list), \
"sampled_token_ids should be a python list when" \
"padded-batch is disabled."
next_token_ids = self.prepare_next_token_ids_cpu(
sampled_token_ids, self.runner.requests,
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
else:
# When using padded-batch, the sampled_token_ids should be
# the gpu tensor of sampled tokens for each request, of shape
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
# value -1.
assert isinstance(sampled_token_ids, torch.Tensor), \
"sampled_token_ids should be a torch.Tensor when" \
"padded-batch is enabled."
next_token_ids, valid_sampled_tokens_count = \
self.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.runner.requests,
self.runner.input_batch,
self.runner.discard_request_indices.gpu,
self.runner.num_discarded_requests
)
self._copy_valid_sampled_token_count(next_token_ids,
valid_sampled_tokens_count)
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size > 1:
long_seq_metadata = self.runner.long_seq_metadata
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.runner.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
else:
long_seq_metadata = None
num_prefill_reqs = 0
num_decode_reqs = 0
if spec_decode_metadata is None:
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = \
query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids.gpu[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.method == "eagle3":
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
if self.pcp_size > 1:
common_attn_metadata.query_start_loc_cpu = \
query_start_loc_pcp_full_cpu[:num_reqs + 1]
common_attn_metadata.query_start_loc = \
query_start_loc_pcp_full[:num_reqs + 1]
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# NOTE: Currently, MTP-fullgraph is incompatibility with pcp
token_indices_to_sample = None
common_attn_metadata, token_indices =\
self.prepare_inputs(
common_attn_metadata,
sampled_token_ids,
spec_decode_metadata.num_draft_tokens)
else:
common_attn_metadata, token_indices, \
token_indices_to_sample =\
self.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)
if self.pcp_size > 1:
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
else:
target_token_ids = self.runner.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
if self.method == "eagle3":
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
draft_token_ids = self._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
)
return draft_token_ids
def _propose( def _propose(
self, self,
# [num_tokens] # [num_tokens]
@@ -430,9 +337,11 @@ class EagleProposer(VllmEagleProposer):
self.runner.get_model()) self.runner.get_model())
# update global cos, sin # update global cos, sin
update_cos_sin(self.positions[:num_input_tokens]) update_cos_sin(self.positions[:num_input_tokens])
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_name:
per_layer_attn_metadata[layer_name] = attn_metadata
with set_ascend_forward_context( with set_ascend_forward_context(
{self.attn_layer_name: attn_metadata}, per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
@@ -558,7 +467,7 @@ class EagleProposer(VllmEagleProposer):
# Run the model. # Run the model.
with set_ascend_forward_context( with set_ascend_forward_context(
{self.attn_layer_name: attn_metadata}, per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=input_batch_size, num_tokens=input_batch_size,
num_actual_tokens=batch_size, num_actual_tokens=batch_size,
@@ -696,28 +605,6 @@ class EagleProposer(VllmEagleProposer):
return next_token_ids, valid_sampled_tokens_count return next_token_ids, valid_sampled_tokens_count
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor,
valid_sampled_tokens_count: torch.Tensor) -> None:
if self.runner.valid_sampled_token_count_event is not None:
default_stream = torch.npu.current_stream()
# initialize a new stream to overlap the copy operation with
# prepare_input of draft model.
with torch.npu.stream(
self.runner.valid_sampled_token_count_copy_stream):
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
default_stream) # type: ignore
self.runner.valid_sampled_token_count_cpu[:
valid_sampled_tokens_count
.shape[0]].copy_(
valid_sampled_tokens_count,
non_blocking=True
)
self.runner.valid_sampled_token_count_event.record()
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
1)
def prepare_inputs( def prepare_inputs(
self, self,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,

View File

@@ -1,25 +1,16 @@
import importlib
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm.config import (CUDAGraphMode, get_layers_from_vllm_config, from vllm.config import CUDAGraphMode
set_current_vllm_config)
from vllm.distributed import get_pcp_group from vllm.distributed import get_pcp_group
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \
process_weights_after_loading
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import set_default_torch_dtype
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
@@ -54,15 +45,6 @@ _MTP_MODELS = {
} }
def _load_model(architecture):
if architecture not in _MTP_MODELS:
raise ValueError("Invalid architecture for mtp.")
module_name, model_name = _MTP_MODELS[architecture]
module = importlib.import_module(module_name)
model = getattr(module, model_name)
return model
class MtpProposer(EagleProposer): class MtpProposer(EagleProposer):
# TODO: Find out why ModelRunner does not this explicit typing? # TODO: Find out why ModelRunner does not this explicit typing?
@@ -86,64 +68,6 @@ class MtpProposer(EagleProposer):
update_attn_params(self.update_stream, forward_context, update_attn_params(self.update_stream, forward_context,
num_tokens, self.vllm_config) num_tokens, self.vllm_config)
def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase).keys())
target_indexer_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache).keys())
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
target_device = self.vllm_config.device_config.device
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self._init_mtp_model()
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() -
target_attn_layer_names)
indexer_layers = get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache)
draft_indexer_layer_names = indexer_layers.keys(
) - target_indexer_layer_names
# NOTE: Currently we don't have specific attention backend and attention metadata
# for deepseek v3.2 indexer, so we just exclude the indexer layers here.
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = list(draft_attn_layer_names)
self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
process_weights_after_loading(self.model, draft_model_config,
target_device)
if self.vllm_config.model_config.is_deepseek_mla:
# check if mtp model use main model's embedding and LMhead
main_model = model
if get_pp_group().world_size == 1:
# If pp>1, the weights of mtp and the main model's embedding are not on the same device.
if torch.equal(self.model.model.embed_tokens.weight,
main_model.model.embed_tokens.weight):
self.model.model.embed_tokens = main_model.model.embed_tokens
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
main_model.lm_head.weight):
layer_module.shared_head.head = main_model.lm_head
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
):
self.update_stream: torch.npu.Stream = torch.npu.Stream()
self.model = ACLGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
@torch.inference_mode() @torch.inference_mode()
def dummy_run(self, def dummy_run(self,
num_tokens: int, num_tokens: int,
@@ -256,153 +180,6 @@ class MtpProposer(EagleProposer):
if with_prefill: if with_prefill:
break break
def generate_token_ids(self,
                       sampled_token_ids: torch.Tensor | list[list[int]],
                       sampling_metadata: SamplingMetadata = None,
                       scheduler_output: SchedulerOutput = None,
                       spec_decode_metadata: SpecDecodeMetadata = None,
                       positions: torch.Tensor = None,
                       num_scheduled_tokens: int = 0,
                       hidden_states: torch.Tensor = None,
                       aux_hidden_states: torch.Tensor = None):
    """Prepare the drafter inputs for this step and propose draft tokens.

    The method (1) derives the next token ids from the sampled tokens,
    (2) collects the PCP/DCP bookkeeping when either is enabled,
    (3) selects the target token ids / positions / hidden states, and
    (4) forwards everything to ``self._propose``.

    Args:
        sampled_token_ids: A python ``list[list[int]]`` when
            ``disable_padded_drafter_batch`` is set, otherwise a
            ``torch.Tensor``.
        sampling_metadata: Forwarded unchanged to ``self._propose``.
        scheduler_output: Its ``num_scheduled_tokens`` is read here; the
            object itself is also forwarded to ``self._propose``.
        spec_decode_metadata: When ``None`` the first
            ``num_scheduled_tokens`` entries are used as drafter inputs;
            otherwise the inputs are gathered through the token indices
            returned by ``_prepare_inputs`` / ``prepare_inputs_padded``.
        positions: Positions to slice or gather the target positions from.
        num_scheduled_tokens: Number of leading entries to keep when
            ``spec_decode_metadata`` is ``None``.
        hidden_states: Hidden states to slice or gather the target hidden
            states from.
        aux_hidden_states: Accepted for interface compatibility; it is not
            read by this method.

    Returns:
        Whatever ``self._propose`` returns for the prepared inputs.

    Raises:
        AssertionError: If ``sampled_token_ids`` does not have the type
            required by the current padded-batch setting.
    """
    common_attn_metadata = self.runner.spec_decode_common_attn_metadata

    if self.speculative_config.disable_padded_drafter_batch:
        # When padded-batch is disabled, the sampled_token_ids should be
        # the cpu-side list[list[int]] of valid sampled tokens for each
        # request, with invalid requests having empty lists.
        # NOTE: the message parts are joined by implicit literal
        # concatenation, so the separating space must be written out.
        assert isinstance(sampled_token_ids, list), \
            "sampled_token_ids should be a python list when " \
            "padded-batch is disabled."
        next_token_ids = self.prepare_next_token_ids_cpu(
            sampled_token_ids, self.runner.requests,
            self.runner.input_batch, scheduler_output.num_scheduled_tokens)
    else:
        # When using padded-batch, the sampled_token_ids should be
        # the gpu tensor of sampled tokens for each request, of shape
        # (num_reqs, num_spec_tokens + 1) with rejected tokens having
        # value -1.
        assert isinstance(sampled_token_ids, torch.Tensor), \
            "sampled_token_ids should be a torch.Tensor when " \
            "padded-batch is enabled."
        next_token_ids, valid_sampled_tokens_count = \
            self.prepare_next_token_ids_padded(
                common_attn_metadata,
                sampled_token_ids,
                self.runner.requests,
                self.runner.input_batch,
                self.runner.discard_request_indices.gpu,
                self.runner.num_discarded_requests
            )
        self._copy_valid_sampled_token_count(next_token_ids,
                                             valid_sampled_tokens_count)

    req_scheduled_tokens = scheduler_output.num_scheduled_tokens

    if self.pcp_size * self.dcp_size > 1:
        # The *_pcp_full buffers and num_reqs are only bound in this
        # branch; every later use is guarded by ``self.pcp_size > 1``.
        long_seq_metadata = self.runner.long_seq_metadata
        input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
        query_start_loc_pcp_full = \
            self.runner.pcp_manager.query_start_loc_pcp_full.gpu
        query_start_loc_pcp_full_cpu = \
            self.runner.pcp_manager.query_start_loc_pcp_full.cpu
        num_reqs = self.runner.input_batch.num_reqs
        # Per-request query length = difference of consecutive offsets.
        ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs + 1] - \
            query_start_loc_pcp_full_cpu[:num_reqs]
        # Requests longer than the decode threshold count as prefill.
        num_prefill_reqs = (ori_query_lens
                            > self.decode_threshold).sum().item()
        num_decode_reqs = num_reqs - num_prefill_reqs
    else:
        long_seq_metadata = None
        num_prefill_reqs = 0
        num_decode_reqs = 0

    if spec_decode_metadata is None:
        # update pcp related params
        if self.pcp_size > 1:
            # Index of the last token of every request.
            token_indices_to_sample = \
                query_start_loc_pcp_full[1:num_reqs + 1] - 1
            target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
            target_positions = positions[:num_scheduled_tokens]
            target_hidden_states = hidden_states
        else:
            token_indices_to_sample = None
            # input_ids can be None for multimodal models.
            target_token_ids = \
                self.runner.input_ids.gpu[:num_scheduled_tokens]
            target_positions = positions[:num_scheduled_tokens]
            target_hidden_states = hidden_states[:num_scheduled_tokens]
    else:
        if self.pcp_size > 1:
            # Overwrite the query start offsets with the PCP-full ones
            # before the drafter inputs are prepared.
            common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
                query_start_loc_pcp_full_cpu[:num_reqs + 1]
            common_attn_metadata.query_start_loc[:num_reqs + 1] = \
                query_start_loc_pcp_full[:num_reqs + 1]
        if self.speculative_config.disable_padded_drafter_batch:
            token_indices_to_sample = None
            common_attn_metadata, token_indices = \
                self._prepare_inputs(
                    common_attn_metadata,
                    sampled_token_ids,
                    spec_decode_metadata.num_draft_tokens)
        else:
            common_attn_metadata, token_indices, \
                token_indices_to_sample = \
                self.prepare_inputs_padded(
                    common_attn_metadata,
                    spec_decode_metadata,
                    valid_sampled_tokens_count)
        if self.pcp_size > 1:
            target_token_ids = input_ids_pcp_full[token_indices]
            target_positions = positions
            target_hidden_states = hidden_states
        else:
            target_token_ids = self.runner.input_ids.gpu[token_indices]
            target_positions = positions[token_indices]
            target_hidden_states = hidden_states[token_indices]

    draft_token_ids = self._propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=token_indices_to_sample,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
        req_scheduled_tokens=req_scheduled_tokens,
        long_seq_metadata=long_seq_metadata,
        num_prefill_reqs=num_prefill_reqs,
        num_decode_reqs=num_decode_reqs,
        scheduler_output=scheduler_output,
        num_scheduled_tokens=num_scheduled_tokens,
    )
    return draft_token_ids
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor,
valid_sampled_tokens_count: torch.Tensor) -> None:
"""Copy the valid sampled-token counts into the runner's CPU buffer.

When the runner has a ``valid_sampled_token_count_event``, the copy is
issued with ``non_blocking=True`` under the runner's
``valid_sampled_token_count_copy_stream`` and the event is recorded.
The method also stores ``next_token_ids.unsqueeze(1)`` as the input
batch's ``prev_sampled_token_ids``.

Args:
    next_token_ids: Tensor stored (with an extra dimension at index 1)
        as ``prev_sampled_token_ids``.
    valid_sampled_tokens_count: Tensor whose contents are copied into
        the first ``valid_sampled_tokens_count.shape[0]`` entries of
        ``self.runner.valid_sampled_token_count_cpu``.
"""
if self.runner.valid_sampled_token_count_event is not None:
# Remember the stream that was current on entry so the copy stream can
# wait on it below.
default_stream = torch.npu.current_stream()
# Switch to the runner's dedicated copy stream so that the copy
# operation can overlap with prepare_input of the draft model.
with torch.npu.stream(
self.runner.valid_sampled_token_count_copy_stream):
# Order the copy stream after the stream captured above.
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
default_stream) # type: ignore
# Fill only the leading shape[0] entries of the CPU buffer.
self.runner.valid_sampled_token_count_cpu[:
valid_sampled_tokens_count
.shape[0]].copy_(
valid_sampled_tokens_count,
non_blocking=True
)
# Record the event after the copy has been issued.
# NOTE(review): presumably consumers wait on this event before reading
# the CPU buffer - confirm against the runner.
self.runner.valid_sampled_token_count_event.record()
# unsqueeze(1) inserts a new dimension at index 1.
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
1)
def _init_mtp_model(self):
    """Instantiate the MTP model and move it to the configured device.

    The architecture name from the model config is passed to
    ``_load_model``; the callable it returns is invoked with the current
    vLLM config, and the result is moved to the device from the device
    config and stored on ``self.model``.
    """
    arch_name = self.vllm_config.model_config.architecture
    device = self.vllm_config.device_config.device
    model_factory = _load_model(arch_name)
    instance = model_factory(vllm_config=self.vllm_config)
    self.model = instance.to(device)
def _prepare_inputs( def _prepare_inputs(
self, self,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,

View File

@@ -1,6 +1,6 @@
# #
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team. # Copyright 2025 The vLLM team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -54,6 +54,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec,
EncoderOnlyAttentionSpec, EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig, FullAttentionSpec, KVCacheConfig,
@@ -113,7 +114,6 @@ from vllm_ascend.worker.pcp_utils import PCPManager
from vllm_ascend.ascend_forward_context import ( # isort: skip from vllm_ascend.ascend_forward_context import ( # isort: skip
MoECommType, get_mc2_tokens_capacity, select_moe_comm_method, MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity) set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity)
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped] import xgrammar as xgr # type: ignore[import-untyped]
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -1257,6 +1257,7 @@ class NPUModelRunner(GPUModelRunner):
logits_indices=logits_indices, logits_indices=logits_indices,
) )
# TODO: Once the PCP features are complete, it will fully inherit the classes from the VLLM community.
def propose_draft_token_ids( def propose_draft_token_ids(
self, self,
valid_sampled_token_ids: torch.Tensor | list[list[int]], valid_sampled_token_ids: torch.Tensor | list[list[int]],
@@ -1273,10 +1274,147 @@ class NPUModelRunner(GPUModelRunner):
# Speculative decoding is not enabled. # Speculative decoding is not enabled.
draft_token_ids = None draft_token_ids = None
else: else:
draft_token_ids = self.drafter.generate_token_ids( if self.speculative_config.method in ("suffix", "ngram"):
valid_sampled_token_ids, sampling_metadata, scheduler_output, draft_token_ids = self.drafter.generate_token_ids(
spec_decode_metadata, positions, num_scheduled_tokens, valid_sampled_token_ids, sampling_metadata,
hidden_states, aux_hidden_states) scheduler_output, spec_decode_metadata, positions,
num_scheduled_tokens, hidden_states, aux_hidden_states)
elif self.speculative_config.use_eagle():
common_attn_metadata = self.spec_decode_common_attn_metadata
sampled_token_ids = valid_sampled_token_ids
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
assert isinstance(sampled_token_ids, list), \
"sampled_token_ids should be a python list when" \
"padded-batch is disabled."
assert self.drafter is not None
next_token_ids = self.drafter.prepare_next_token_ids_cpu(
sampled_token_ids, self.requests, self.input_batch,
scheduler_output.num_scheduled_tokens)
else:
# When using padded-batch, the sampled_token_ids should be
# the gpu tensor of sampled tokens for each request, of shape
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
# value -1.
assert isinstance(sampled_token_ids, torch.Tensor), \
"sampled_token_ids should be a torch.Tensor when" \
"padded-batch is enabled."
assert self.drafter is not None
next_token_ids, valid_sampled_tokens_count = \
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_indices.gpu,
self.num_discarded_requests
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count)
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size * self.dcp_size > 1:
long_seq_metadata = self.long_seq_metadata # type: ignore
input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
else:
long_seq_metadata = None # type: ignore
num_prefill_reqs = 0
num_decode_reqs = 0
if spec_decode_metadata is None:
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = \
query_start_loc_pcp_full[1:num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids.gpu[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([
h[:num_scheduled_tokens]
for h in aux_hidden_states
],
dim=-1)
else:
target_hidden_states = hidden_states[:
num_scheduled_tokens]
else:
if self.pcp_size > 1:
assert common_attn_metadata is not None
common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
query_start_loc_pcp_full_cpu[:num_reqs + 1]
assert common_attn_metadata is not None
common_attn_metadata.query_start_loc[:num_reqs + 1] = \
query_start_loc_pcp_full[:num_reqs + 1]
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# NOTE: Currently, MTP-fullgraph is incompatible with PCP.
token_indices_to_sample = None
assert self.drafter is not None
common_attn_metadata, token_indices =\
self.drafter.prepare_inputs(
common_attn_metadata,
sampled_token_ids,
spec_decode_metadata.num_draft_tokens)
else:
assert self.drafter is not None
common_attn_metadata, token_indices, \
token_indices_to_sample =\
self.drafter.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)
if self.pcp_size > 1:
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
else:
target_token_ids = self.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
assert self.drafter is not None
draft_token_ids = self.drafter._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
)
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
return draft_token_ids return draft_token_ids
@staticmethod @staticmethod