Add Medusa speculative decoding support for vllm_ascend (#5668)
### What this PR does / why we need it?
`vllm_ascend` already supports several speculative decoding strategies
such as MTP, EAGLE, N-gram, and suffix decoding. However, Medusa is not
yet supported. Medusa is an efficient speculative decoding framework
that leverages a lightweight draft model to propose multiple tokens in a
single step, which can significantly improve decoding throughput and
reduce latency.
To enable Medusa-based speculative decoding on Ascend hardware and
provide more decoding options for users, this PR adds Medusa support
into the `vllm_ascend` speculative decoding pipeline.
### Does this PR introduce _any_ user-facing change?
This PR introduces Medusa speculative decoding as an additional
speculative decoding method:
✔ Adds `MedusaProposer` and integrates it into the speculative decoding
registry
✔ Extends `SpecDcodeType` with a `MEDUSA` enum entry
✔ Updates `NPUModelRunner` to recognize and invoke Medusa during
decoding
✔ Adds Medusa-specific handling in the draft token generation logic
✔ Ensures backward compatibility — Medusa is only used when explicitly
enabled
Key code changes include:
* New file: `vllm_ascend/spec_decode/medusa_proposer.py`
* Register Medusa in `get_spec_decode_method`
* Extend proposer type hints to include `MedusaProposer`
* Add a Medusa-specific branch in `generate_draft_token_ids`
* Pass `sample_hidden_states` required by Medusa
### How was this patch tested?
Medusa is implemented as a new proposer class (`MedusaProposer`)
following the existing speculative decoding interface. The integration
works as follows:
1. Users enable Medusa via the speculative decoding configuration.
2. `get_spec_decode_method()` returns a `MedusaProposer` instance when
`method="medusa"`.
3. During decoding, `NPUModelRunner` detects that the active drafter is
a `MedusaProposer`.
4. Instead of the generic speculative decoding path, the Medusa-specific
`generate_token_ids()` method is invoked, which consumes:
* `valid_sampled_token_ids`
* `sampling_metadata`
* `spec_decode_metadata`
* `sample_hidden_states`
5. The proposed tokens are validated by the target model as usual.
When Medusa is not enabled, the decoding pipeline behaves exactly as
before, ensuring full backward compatibility.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
Signed-off-by: simplzyu <191163281@qq.com>
Signed-off-by: simplzyu <zhenyuguo@cmbchina.com>
This commit is contained in:
98
vllm_ascend/spec_decode/medusa_proposer.py
Normal file
98
vllm_ascend/spec_decode/medusa_proposer.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models.interfaces import is_mixture_of_experts
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.medusa import MedusaProposer as VllmMedusaProposer
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.spec_decode.interface import SpecDcodeType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MedusaProposer(VllmMedusaProposer):
|
||||
"""
|
||||
Medusa proposer class for generating token sequences
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
runner,
|
||||
):
|
||||
# Save config parameters
|
||||
self.name = SpecDcodeType.MEDUSA
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
self.hidden_size = (vllm_config.speculative_config.draft_model_config.
|
||||
get_hidden_size())
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.runner = runner
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(self,
|
||||
num_tokens: int,
|
||||
with_prefill: bool = False,
|
||||
in_graph_capturing: bool = False,
|
||||
num_reqs: int = 0,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor=None,
|
||||
dummy_compute_logits=lambda hidden_states: None,
|
||||
is_profile=False):
|
||||
hidden_states = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
with set_ascend_forward_context(
|
||||
None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_actual_tokens=0,
|
||||
in_profile_run=is_profile,
|
||||
batch_descriptor=batch_descriptor,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
is_draft_model=True):
|
||||
self.model(hidden_states)
|
||||
dummy_compute_logits(hidden_states)
|
||||
|
||||
def generate_token_ids(self, valid_sampled_token_ids: list[list[int]],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_decode_metadata: SpecDecodeMetadata,
|
||||
sample_hidden_states: torch.Tensor,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if sample_hidden_states.shape[0] == len(valid_sampled_token_ids):
|
||||
# The input to the target model does not include draft tokens.
|
||||
hidden_states = sample_hidden_states
|
||||
else:
|
||||
num_accepted_tokens = torch.tensor(
|
||||
[len(t) for t in valid_sampled_token_ids],
|
||||
device=self.device,
|
||||
dtype=torch.long)
|
||||
num_draft_tokens = torch.tensor(
|
||||
spec_decode_metadata.num_draft_tokens,
|
||||
device=self.device,
|
||||
dtype=torch.long)
|
||||
|
||||
offsets = torch.cumsum(num_draft_tokens + 1,
|
||||
dim=0) - (num_draft_tokens + 1)
|
||||
indices = offsets + num_accepted_tokens - 1
|
||||
hidden_states = sample_hidden_states[indices]
|
||||
|
||||
spec_token_ids = self.propose(
|
||||
target_hidden_states=hidden_states,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
return spec_token_ids
|
||||
Reference in New Issue
Block a user