[Spec Decode] Clean up spec decode interface (#6947)

This pull request refactors the speculative decoding proposer interface
to align with upstream vLLM, removing the local `Proposer` interface and
renaming the `generate_token_ids` methods to `propose`.

This is the first step. In the future we should remove the class
registration and just add the few Ascend-specific methods once the
architecture in vLLM is ready.
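
Concretely, each Ascend proposer now inherits directly from its upstream vLLM counterpart and overrides `propose`, instead of implementing the local `Proposer` interface. A minimal sketch of the new pattern, with the method body elided (the real signatures are in the per-file diffs below):

    from vllm.v1.spec_decode.ngram_proposer import NgramProposer

    class AscendNgramProposer(NgramProposer):
        def __init__(self, vllm_config, runner):
            # Keep a handle to the Ascend model runner, then let the upstream
            # vLLM class do its own setup.
            self.runner = runner
            super().__init__(vllm_config)

        def propose(self, sampled_token_ids, **kwargs):
            # Ascend-specific drafting goes here; helpers such as
            # batch_propose() are inherited from the upstream NgramProposer.
            ...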

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2026-03-05 14:30:10 +08:00
Committed by: GitHub
Parent: 2bd9c35788
Commit: 13777bf3f0
11 changed files with 194 additions and 315 deletions

View File

@@ -16,23 +16,23 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
 #
-from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
-from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
-from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
-from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
-from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
+from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
+from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
+from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer
+from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
+from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer
 def get_spec_decode_method(method, vllm_config, device, runner):
     if method == "ngram":
-        return NgramProposer(vllm_config, device, runner)
-    elif method in ("eagle", "eagle3"):
-        return EagleProposer(vllm_config, device, runner)
-    elif method == "mtp":
-        return MtpProposer(vllm_config, device, runner)
+        return AscendNgramProposer(vllm_config, runner)
     elif method == "suffix":
-        return SuffixDecodingProposer(vllm_config, device, runner)
+        return AscendSuffixDecodingProposer(vllm_config, runner)
     elif method == "medusa":
-        return MedusaProposer(vllm_config, device, runner)
+        return AscendMedusaProposer(vllm_config, device)
+    elif method in ("eagle", "eagle3"):
+        return AscendEagleProposer(vllm_config, device, runner)
+    elif method == "mtp":
+        return AscendMtpProposer(vllm_config, device, runner)
     else:
         raise ValueError(f"Unknown speculative decoding method: {method}")
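
For context, a hedged sketch of how this factory is expected to be called from the model runner; the caller-side names (`vllm_config`, `device`, `runner`, `model`) are assumptions and not part of this diff:

    # Assumed caller-side wiring: pick a drafter once at runner init time.
    method = vllm_config.speculative_config.method   # e.g. "ngram", "mtp", "eagle3"
    drafter = get_spec_decode_method(method, vllm_config, device, runner)
    drafter.load_model(model)   # no-op for the ngram and suffix proposers
    # The drafter is then driven through propose(); its arguments differ per
    # method, as the diffs below show.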

View File

@@ -30,8 +30,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID
-from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer
+from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -81,7 +80,7 @@ def split_inputs_tp_to_sp(hidden_states, out):
     return out[:padded_num_tokens_per_rank]
-class EagleProposer(VllmEagleProposer):
+class AscendEagleProposer(EagleProposer):
     _runnable: ACLGraphWrapper | Callable
     def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):

View File

@@ -1,53 +0,0 @@
-import enum
-import torch
-from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-class SpecDcodeType(enum.Enum):
-    NGRAM = 0
-    EAGLE = 1
-    EAGLE3 = 2
-    MTP = 4
-    SUFFIX = 5
-    MEDUSA = 6
-class Proposer:
-    def __init__(self, vllm_config: VllmConfig, device: torch.device = None, runner=None):
-        pass
-    def load_model(self, model):
-        """Called by load_model in model_runner"""
-        raise NotImplementedError
-    @torch.inference_mode()
-    def dummy_run(
-        self,
-        num_tokens: int,
-        with_prefill: bool = False,
-        in_graph_capturing: bool = False,
-        num_reqs: int = 0,
-        num_tokens_across_dp: torch.Tensor | None = None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor=None,
-    ):
-        """Called by dummy_run in model_runner"""
-        raise NotImplementedError
-    def generate_token_ids(
-        self,
-        valid_sampled_token_ids: list[list[int]],
-        sampling_metadata: SamplingMetadata = None,
-        scheduler_output: SchedulerOutput = None,
-        spec_decode_metadata: SpecDecodeMetadata = None,
-        positions: torch.Tensor = None,
-        num_scheduled_tokens: int = 0,
-        hidden_states: torch.Tensor = None,
-        aux_hidden_states: torch.Tensor = None,
-    ):
-        """Called by execute_model in model_runner"""
-        raise NotImplementedError

View File

@@ -1,36 +1,20 @@
 import torch
-from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config import CUDAGraphMode
 from vllm.logger import init_logger
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.medusa import MedusaProposer as VllmMedusaProposer
+from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-from vllm_ascend.spec_decode.interface import SpecDcodeType
 logger = init_logger(__name__)
-class MedusaProposer(VllmMedusaProposer):
+class AscendMedusaProposer(MedusaProposer):
     """
     Medusa proposer class for generating token sequences
     """
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        runner,
-    ):
-        # Save config parameters
-        self.name = SpecDcodeType.MEDUSA
-        self.vllm_config = vllm_config
-        self.device = device
-        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-        self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size()
-        self.dtype = vllm_config.model_config.dtype
-        self.runner = runner
     @torch.inference_mode()
     def dummy_run(
         self,
@@ -62,14 +46,12 @@ class MedusaProposer(VllmMedusaProposer):
         self.model(hidden_states)
         dummy_compute_logits(hidden_states)
-    def generate_token_ids(
+    def propose(
         self,
         valid_sampled_token_ids: list[list[int]],
         sampling_metadata: SamplingMetadata,
         spec_decode_metadata: SpecDecodeMetadata,
         sample_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
     ):
         if sample_hidden_states.shape[0] == len(valid_sampled_token_ids):
             # The input to the target model does not include draft tokens.
@@ -84,7 +66,7 @@ class MedusaProposer(VllmMedusaProposer):
         indices = offsets + num_accepted_tokens - 1
         hidden_states = sample_hidden_states[indices]
-        spec_token_ids = self.propose(
+        spec_token_ids = super().propose(
             target_hidden_states=hidden_states,
             sampling_metadata=sampling_metadata,
         )
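
A small worked example of the index selection above, assuming `offsets` marks the first sample row of each request in `sample_hidden_states` (the numbers are illustrative only):

    import torch

    # Two requests whose sample rows start at 0 and 3; they accepted 2 and 1
    # tokens respectively.
    offsets = torch.tensor([0, 3])
    num_accepted_tokens = torch.tensor([2, 1])
    indices = offsets + num_accepted_tokens - 1   # -> tensor([1, 3])
    # sample_hidden_states[indices] then picks the hidden state of the last
    # accepted token for each request, which Medusa extends with draft tokens.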

View File

@@ -15,11 +15,11 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla, update_cos_sin
-from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
+from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
 from vllm_ascend.utils import lmhead_tp_enable
-class MtpProposer(EagleProposer):
+class AscendMtpProposer(AscendEagleProposer):
     # TODO: Find out why ModelRunner does not this explicit typing?
     model: nn.Module | ACLGraphWrapper

View File

@@ -1,16 +1,11 @@
 import torch
 from vllm.config import CUDAGraphMode
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer as VllmNgramProposer
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-class NgramProposer(VllmNgramProposer, Proposer):
-    def __init__(self, vllm_config, device, runner):
-        super().__init__(vllm_config)
-        self.name = SpecDcodeType.NGRAM
-        self.device = device
+class AscendNgramProposer(NgramProposer):
+    def __init__(self, vllm_config, runner):
         self.runner = runner
+        super().__init__(vllm_config)
     def load_model(self, *args, **kwargs):
         # No model to load.
@@ -24,26 +19,22 @@ class NgramProposer(VllmNgramProposer, Proposer):
         in_graph_capturing=None,
         num_reqs=None,
         num_tokens_across_dp=None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        aclgraph_runtime_mode=None,
         batch_descriptor=None,
         dummy_compute_logits=lambda hidden_states: None,
         is_profile=False,
     ):
         pass
-    def generate_token_ids(
+    def propose(
         self,
-        valid_sampled_token_ids,
-        sampling_metadata=None,
-        scheduler_output=None,
-        spec_decode_metadata=None,
-        positions=None,
-        num_scheduled_tokens=None,
-        hidden_states=None,
-        aux_hidden_states=None,
+        sampled_token_ids: list[list[int]],
+        num_tokens_no_spec=None,
+        token_ids_cpu=None,
+        slot_masks: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
     ) -> list[list[int]]:
         valid_ngram_requests = []
-        for i, sampled_ids in enumerate(valid_sampled_token_ids):
+        for i, sampled_ids in enumerate(sampled_token_ids):
             num_sampled_ids = len(sampled_ids)
             if not num_sampled_ids:
                 continue
@@ -64,7 +55,7 @@ class NgramProposer(VllmNgramProposer, Proposer):
             valid_ngram_requests.append(i)
         draft_token_ids = self.batch_propose(
-            len(valid_sampled_token_ids),
+            len(sampled_token_ids),
             valid_ngram_requests,
             self.runner.input_batch.num_tokens_no_spec,
             self.runner.input_batch.token_ids_cpu,
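
An illustrative call of the reworked method (the `drafter` name and the token ids are made up; batch state such as `num_tokens_no_spec` and `token_ids_cpu` defaults to what `self.runner.input_batch` provides):

    # One inner list of newly sampled token ids per request; requests with an
    # empty list are skipped when building valid_ngram_requests.
    draft_token_ids = drafter.propose(
        sampled_token_ids=[[101, 7], [], [42]],
    )
    # Returns list[list[int]]: proposed draft tokens per request (possibly empty).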

View File

@@ -1,22 +1,11 @@
 import torch
 from vllm.config import CUDAGraphMode
-from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer as VllmSuffixDecodingProposer
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
-class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
-    def __init__(self, vllm_config, device, runner):
+class AscendSuffixDecodingProposer(SuffixDecodingProposer):
+    def __init__(self, vllm_config, runner):
         super().__init__(vllm_config)
-        self.name = SpecDcodeType.SUFFIX
-        self.device = device
         self.runner = runner
     def load_model(self, *args, **kwargs):
         # No model to load.
         pass
     @torch.inference_mode()
     def dummy_run(
         self,
         num_tokens,
@@ -24,23 +13,12 @@ class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
         in_graph_capturing=None,
         num_reqs=None,
         num_tokens_across_dp=None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        aclgraph_runtime_mode=None,
         batch_descriptor=None,
         dummy_compute_logits=lambda hidden_states: None,
         is_profile=False,
     ):
         pass
-    def generate_token_ids(
-        self,
-        valid_sampled_token_ids,
-        sampling_metadata=None,
-        scheduler_output=None,
-        spec_decode_metadata=None,
-        positions=None,
-        num_scheduled_tokens=None,
-        hidden_states=None,
-        aux_hidden_states=None,
-    ) -> list[list[int]]:
-        draft_token_ids = self.propose(self.runner.input_batch, valid_sampled_token_ids)
-        return draft_token_ids
+    def propose(self, valid_sampled_token_ids):
+        return super().propose(self.runner.input_batch, valid_sampled_token_ids)