Add Medusa speculative decoding support for vllm_ascend (#5668)

### What this PR does / why we need it?
`vllm_ascend` already supports several speculative decoding strategies
such as MTP, EAGLE, N-gram, and suffix decoding. However, Medusa is not
yet supported. Medusa is an efficient speculative decoding framework
that leverages a lightweight draft model to propose multiple tokens in a
single step, which can significantly improve decoding throughput and
reduce latency.

To enable Medusa-based speculative decoding on Ascend hardware and
provide more decoding options for users, this PR adds Medusa support
into the `vllm_ascend` speculative decoding pipeline.

### Does this PR introduce _any_ user-facing change?

This PR introduces Medusa speculative decoding as an additional
speculative decoding method:

✔ Adds `MedusaProposer` and integrates it into the speculative decoding
registry
✔ Extends `SpecDcodeType` with a `MEDUSA` enum entry
✔ Updates `NPUModelRunner` to recognize and invoke Medusa during
decoding
✔ Adds Medusa-specific handling in the draft token generation logic
✔ Ensures backward compatibility — Medusa is only used when explicitly
enabled

Key code changes include:

* New file: `vllm_ascend/spec_decode/medusa_proposer.py`
* Register Medusa in `get_spec_decode_method`
* Extend proposer type hints to include `MedusaProposer`
* Add a Medusa-specific branch in `generate_draft_token_ids`
* Pass `sample_hidden_states` required by Medusa

### How was this patch tested?

Medusa is implemented as a new proposer class (`MedusaProposer`)
following the existing speculative decoding interface. The integration
works as follows:

1. Users enable Medusa via the speculative decoding configuration.
2. `get_spec_decode_method()` returns a `MedusaProposer` instance when
`method="medusa"`.
3. During decoding, `NPUModelRunner` detects that the active drafter is
a `MedusaProposer`.
4. Instead of the generic speculative decoding path, the Medusa-specific
`generate_token_ids()` method is invoked, which consumes:

   * `valid_sampled_token_ids`
   * `sampling_metadata`
   * `spec_decode_metadata`
   * `sample_hidden_states`
5. The proposed tokens are validated by the target model as usual.

When Medusa is not enabled, the decoding pipeline behaves exactly as
before, ensuring full backward compatibility.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

Signed-off-by: simplzyu <191163281@qq.com>
Signed-off-by: simplzyu <zhenyuguo@cmbchina.com>
This commit is contained in:
simplzyu
2026-01-23 14:14:23 +08:00
committed by GitHub
parent a69ef10c3a
commit f8d03d21f1
4 changed files with 111 additions and 2 deletions

View File

@@ -100,6 +100,7 @@ from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
from vllm_ascend.sample.sampler import AscendSampler
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
enable_sp, get_ascend_device_type, is_moe_model,
@@ -363,7 +364,8 @@ class NPUModelRunner(GPUModelRunner):
def _set_up_drafter(self):
# Set up speculative decoding.
self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
SuffixDecodingProposer]] = None
SuffixDecodingProposer,
MedusaProposer]] = None
self.actual_seq_lengths_q: list[int] = []
self.decode_token_per_req = 1
if self.speculative_config:
@@ -1288,6 +1290,7 @@ class NPUModelRunner(GPUModelRunner):
hidden_states: torch.Tensor,
attn_metadata: dict[str, Any],
aux_hidden_states: torch.Tensor = None,
sample_hidden_states: torch.Tensor = None
) -> Optional[list[list[int]]]:
if not self.drafter:
# Speculative decoding is not enabled.
@@ -1298,7 +1301,10 @@ class NPUModelRunner(GPUModelRunner):
valid_sampled_token_ids, sampling_metadata,
scheduler_output, spec_decode_metadata, positions,
num_scheduled_tokens, hidden_states, aux_hidden_states)
elif isinstance(self.drafter, MedusaProposer):
draft_token_ids = self.drafter.generate_token_ids(
valid_sampled_token_ids, sampling_metadata,
spec_decode_metadata, sample_hidden_states)
elif self.speculative_config.use_eagle():
common_attn_metadata = self.spec_decode_common_attn_metadata
sampled_token_ids = valid_sampled_token_ids
@@ -1660,6 +1666,7 @@ class NPUModelRunner(GPUModelRunner):
hidden_states,
attn_metadata,
aux_hidden_states,
sample_hidden_states
)
self._copy_draft_token_ids_to_cpu(scheduler_output)