[Spec Decode] Clean up spec decode interface (#6947)
This pull request refactors the speculative decoding proposer interface to align with upstream vLLM: the local `Proposer` interface is removed, the Ascend proposer classes are renamed with an `Ascend` prefix, and their `generate_token_ids` methods are renamed to `propose`.
This is only the first step. In the future, we should remove the class registry and just add a few Ascend-specific methods once the architecture in vLLM is ready.
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
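
For context, here is a minimal sketch of how the refactored interface is intended to be used from the model runner. The `vllm_config`, `runner`, and `device` objects below are placeholders (in practice they come from the NPU model runner), and the import path for `get_spec_decode_method` is assumed from the first hunk of this diff; the call shapes use only names shown in the change.

```python
import torch

# Assumed import path; get_spec_decode_method is defined in the
# vllm_ascend.spec_decode package per the first hunk below.
from vllm_ascend.spec_decode import get_spec_decode_method

# Placeholders: a real VllmConfig and the model runner that owns
# input_batch are required in practice.
vllm_config = ...
runner = ...
device = torch.device("npu")

# The factory keeps its old signature but now returns the renamed
# Ascend* proposer classes.
proposer = get_spec_decode_method("ngram", vllm_config, device, runner)
proposer.load_model()  # the n-gram proposer has no draft model to load

# Drafting is now done through `propose` (previously `generate_token_ids`).
draft_token_ids = proposer.propose(sampled_token_ids=[[1, 2, 3]])
```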
@@ -16,23 +16,23 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
 #
-from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
-from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
-from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
-from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
-from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
+from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
+from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
+from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer
+from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
+from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer


 def get_spec_decode_method(method, vllm_config, device, runner):
     if method == "ngram":
-        return NgramProposer(vllm_config, device, runner)
-    elif method in ("eagle", "eagle3"):
-        return EagleProposer(vllm_config, device, runner)
-    elif method == "mtp":
-        return MtpProposer(vllm_config, device, runner)
+        return AscendNgramProposer(vllm_config, runner)
     elif method == "suffix":
-        return SuffixDecodingProposer(vllm_config, device, runner)
+        return AscendSuffixDecodingProposer(vllm_config, runner)
     elif method == "medusa":
-        return MedusaProposer(vllm_config, device, runner)
+        return AscendMedusaProposer(vllm_config, device)
+    elif method in ("eagle", "eagle3"):
+        return AscendEagleProposer(vllm_config, device, runner)
+    elif method == "mtp":
+        return AscendMtpProposer(vllm_config, device, runner)
     else:
         raise ValueError(f"Unknown speculative decoding method: {method}")

@@ -30,8 +30,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID
-from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer
+from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

@@ -81,7 +80,7 @@ def split_inputs_tp_to_sp(hidden_states, out):
     return out[:padded_num_tokens_per_rank]


-class EagleProposer(VllmEagleProposer):
+class AscendEagleProposer(EagleProposer):
     _runnable: ACLGraphWrapper | Callable

     def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):

@@ -1,53 +0,0 @@
-import enum
-
-import torch
-from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-
-
-class SpecDcodeType(enum.Enum):
-    NGRAM = 0
-    EAGLE = 1
-    EAGLE3 = 2
-    MTP = 4
-    SUFFIX = 5
-    MEDUSA = 6
-
-
-class Proposer:
-    def __init__(self, vllm_config: VllmConfig, device: torch.device = None, runner=None):
-        pass
-
-    def load_model(self, model):
-        """Called by load_model in model_runner"""
-        raise NotImplementedError
-
-    @torch.inference_mode()
-    def dummy_run(
-        self,
-        num_tokens: int,
-        with_prefill: bool = False,
-        in_graph_capturing: bool = False,
-        num_reqs: int = 0,
-        num_tokens_across_dp: torch.Tensor | None = None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor=None,
-    ):
-        """Called by dummy_run in model_runner"""
-        raise NotImplementedError
-
-    def generate_token_ids(
-        self,
-        valid_sampled_token_ids: list[list[int]],
-        sampling_metadata: SamplingMetadata = None,
-        scheduler_output: SchedulerOutput = None,
-        spec_decode_metadata: SpecDecodeMetadata = None,
-        positions: torch.Tensor = None,
-        num_scheduled_tokens: int = 0,
-        hidden_states: torch.Tensor = None,
-        aux_hidden_states: torch.Tensor = None,
-    ):
-        """Called by execute_model in model_runner"""
-        raise NotImplementedError

@@ -1,36 +1,20 @@
 import torch
-from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config import CUDAGraphMode
 from vllm.logger import init_logger
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.medusa import MedusaProposer as VllmMedusaProposer
+from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-from vllm_ascend.spec_decode.interface import SpecDcodeType

 logger = init_logger(__name__)


-class MedusaProposer(VllmMedusaProposer):
+class AscendMedusaProposer(MedusaProposer):
     """
     Medusa proposer class for generating token sequences
     """

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        runner,
-    ):
-        # Save config parameters
-        self.name = SpecDcodeType.MEDUSA
-        self.vllm_config = vllm_config
-        self.device = device
-        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-        self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size()
-        self.dtype = vllm_config.model_config.dtype
-        self.runner = runner
-
     @torch.inference_mode()
     def dummy_run(
         self,

@@ -62,14 +46,12 @@ class MedusaProposer(VllmMedusaProposer):
         self.model(hidden_states)
         dummy_compute_logits(hidden_states)

-    def generate_token_ids(
+    def propose(
         self,
         valid_sampled_token_ids: list[list[int]],
         sampling_metadata: SamplingMetadata,
         spec_decode_metadata: SpecDecodeMetadata,
         sample_hidden_states: torch.Tensor,
         *args,
         **kwargs,
     ):
         if sample_hidden_states.shape[0] == len(valid_sampled_token_ids):
             # The input to the target model does not include draft tokens.

@@ -84,7 +66,7 @@ class MedusaProposer(VllmMedusaProposer):
         indices = offsets + num_accepted_tokens - 1
         hidden_states = sample_hidden_states[indices]

-        spec_token_ids = self.propose(
+        spec_token_ids = super().propose(
             target_hidden_states=hidden_states,
             sampling_metadata=sampling_metadata,
         )

@@ -15,11 +15,11 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla, update_cos_sin
-from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
+from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
 from vllm_ascend.utils import lmhead_tp_enable


-class MtpProposer(EagleProposer):
+class AscendMtpProposer(AscendEagleProposer):
     # TODO: Find out why ModelRunner does not this explicit typing?
     model: nn.Module | ACLGraphWrapper

@@ -1,16 +1,11 @@
 import torch
-from vllm.config import CUDAGraphMode
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer as VllmNgramProposer
-
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer


-class NgramProposer(VllmNgramProposer, Proposer):
-    def __init__(self, vllm_config, device, runner):
-        super().__init__(vllm_config)
-        self.name = SpecDcodeType.NGRAM
-        self.device = device
+class AscendNgramProposer(NgramProposer):
+    def __init__(self, vllm_config, runner):
         self.runner = runner
+        super().__init__(vllm_config)

     def load_model(self, *args, **kwargs):
         # No model to load.

@@ -24,26 +19,22 @@ class NgramProposer(VllmNgramProposer, Proposer):
         in_graph_capturing=None,
         num_reqs=None,
         num_tokens_across_dp=None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        aclgraph_runtime_mode=None,
         batch_descriptor=None,
         dummy_compute_logits=lambda hidden_states: None,
         is_profile=False,
     ):
         pass

-    def generate_token_ids(
+    def propose(
         self,
-        valid_sampled_token_ids,
-        sampling_metadata=None,
-        scheduler_output=None,
-        spec_decode_metadata=None,
-        positions=None,
-        num_scheduled_tokens=None,
-        hidden_states=None,
-        aux_hidden_states=None,
+        sampled_token_ids: list[list[int]],
+        num_tokens_no_spec=None,
+        token_ids_cpu=None,
+        slot_masks: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
     ) -> list[list[int]]:
         valid_ngram_requests = []
-        for i, sampled_ids in enumerate(valid_sampled_token_ids):
+        for i, sampled_ids in enumerate(sampled_token_ids):
             num_sampled_ids = len(sampled_ids)
             if not num_sampled_ids:
                 continue

@@ -64,7 +55,7 @@ class NgramProposer(VllmNgramProposer, Proposer):
             valid_ngram_requests.append(i)

         draft_token_ids = self.batch_propose(
-            len(valid_sampled_token_ids),
+            len(sampled_token_ids),
             valid_ngram_requests,
             self.runner.input_batch.num_tokens_no_spec,
             self.runner.input_batch.token_ids_cpu,

@@ -1,22 +1,11 @@
 import torch
-from vllm.config import CUDAGraphMode
-from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer as VllmSuffixDecodingProposer
-
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer


-class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
-    def __init__(self, vllm_config, device, runner):
+class AscendSuffixDecodingProposer(SuffixDecodingProposer):
+    def __init__(self, vllm_config, runner):
         super().__init__(vllm_config)
-        self.name = SpecDcodeType.SUFFIX
-        self.device = device
         self.runner = runner

     def load_model(self, *args, **kwargs):
         # No model to load.
         pass

     @torch.inference_mode()
     def dummy_run(
         self,
         num_tokens,
@@ -24,23 +13,12 @@ class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
         in_graph_capturing=None,
         num_reqs=None,
         num_tokens_across_dp=None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        aclgraph_runtime_mode=None,
         batch_descriptor=None,
         dummy_compute_logits=lambda hidden_states: None,
         is_profile=False,
     ):
         pass

-    def generate_token_ids(
-        self,
-        valid_sampled_token_ids,
-        sampling_metadata=None,
-        scheduler_output=None,
-        spec_decode_metadata=None,
-        positions=None,
-        num_scheduled_tokens=None,
-        hidden_states=None,
-        aux_hidden_states=None,
-    ) -> list[list[int]]:
-        draft_token_ids = self.propose(self.runner.input_batch, valid_sampled_token_ids)
-        return draft_token_ids
+    def propose(self, valid_sampled_token_ids):
+        return super().propose(self.runner.input_batch, valid_sampled_token_ids)