[Spec Decode]clean up spec decode interface (#6947)

This pull request refactors the speculative decoding proposer interface
to align with upstream vLLM, removing the local `Proposer` interface and
renaming methods to `propose`.

This is the first step. In the future we should remove the class
register and just add few Ascend specified method once the arch in vLLM
is ready.

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2026-03-05 14:30:10 +08:00
committed by GitHub
parent 2bd9c35788
commit 13777bf3f0
11 changed files with 194 additions and 315 deletions

View File

@@ -1,16 +1,11 @@
import torch
from vllm.config import CUDAGraphMode
from vllm.v1.spec_decode.ngram_proposer import NgramProposer as VllmNgramProposer
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
class NgramProposer(VllmNgramProposer, Proposer):
def __init__(self, vllm_config, device, runner):
super().__init__(vllm_config)
self.name = SpecDcodeType.NGRAM
self.device = device
class AscendNgramProposer(NgramProposer):
def __init__(self, vllm_config, runner):
self.runner = runner
super().__init__(vllm_config)
def load_model(self, *args, **kwargs):
# No model to load.
@@ -24,26 +19,22 @@ class NgramProposer(VllmNgramProposer, Proposer):
in_graph_capturing=None,
num_reqs=None,
num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
aclgraph_runtime_mode=None,
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None,
is_profile=False,
):
pass
def generate_token_ids(
def propose(
self,
valid_sampled_token_ids,
sampling_metadata=None,
scheduler_output=None,
spec_decode_metadata=None,
positions=None,
num_scheduled_tokens=None,
hidden_states=None,
aux_hidden_states=None,
sampled_token_ids: list[list[int]],
num_tokens_no_spec=None,
token_ids_cpu=None,
slot_masks: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
) -> list[list[int]]:
valid_ngram_requests = []
for i, sampled_ids in enumerate(valid_sampled_token_ids):
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
continue
@@ -64,7 +55,7 @@ class NgramProposer(VllmNgramProposer, Proposer):
valid_ngram_requests.append(i)
draft_token_ids = self.batch_propose(
len(valid_sampled_token_ids),
len(sampled_token_ids),
valid_ngram_requests,
self.runner.input_batch.num_tokens_no_spec,
self.runner.input_batch.token_ids_cpu,