[Spec Decode]clean up spec decode interface (#6947)
This pull request refactors the speculative decoding proposer interface
to align with upstream vLLM, removing the local `Proposer` interface and
renaming methods to `propose`.
This is the first step. In the future we should remove the class
register and just add few Ascend specified method once the arch in vLLM
is ready.
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -1,22 +1,11 @@
|
||||
import torch
|
||||
from vllm.config import CUDAGraphMode
|
||||
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer as VllmSuffixDecodingProposer
|
||||
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
|
||||
|
||||
|
||||
class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
|
||||
def __init__(self, vllm_config, device, runner):
|
||||
class AscendSuffixDecodingProposer(SuffixDecodingProposer):
|
||||
def __init__(self, vllm_config, runner):
|
||||
super().__init__(vllm_config)
|
||||
self.name = SpecDcodeType.SUFFIX
|
||||
self.device = device
|
||||
self.runner = runner
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
# No model to load.
|
||||
pass
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(
|
||||
self,
|
||||
num_tokens,
|
||||
@@ -24,23 +13,12 @@ class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
|
||||
in_graph_capturing=None,
|
||||
num_reqs=None,
|
||||
num_tokens_across_dp=None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
aclgraph_runtime_mode=None,
|
||||
batch_descriptor=None,
|
||||
dummy_compute_logits=lambda hidden_states: None,
|
||||
is_profile=False,
|
||||
):
|
||||
pass
|
||||
|
||||
def generate_token_ids(
|
||||
self,
|
||||
valid_sampled_token_ids,
|
||||
sampling_metadata=None,
|
||||
scheduler_output=None,
|
||||
spec_decode_metadata=None,
|
||||
positions=None,
|
||||
num_scheduled_tokens=None,
|
||||
hidden_states=None,
|
||||
aux_hidden_states=None,
|
||||
) -> list[list[int]]:
|
||||
draft_token_ids = self.propose(self.runner.input_batch, valid_sampled_token_ids)
|
||||
return draft_token_ids
|
||||
def propose(self, valid_sampled_token_ids):
|
||||
return super().propose(self.runner.input_batch, valid_sampled_token_ids)
|
||||
|
||||
Reference in New Issue
Block a user