xc-llm-ascend/vllm_ascend/spec_decode/__init__.py
simplzyu f8d03d21f1 Add Medusa speculative decoding support for vllm_ascend (#5668)
### What this PR does / why we need it?
`vllm_ascend` already supports several speculative decoding strategies
such as MTP, EAGLE, N-gram, and suffix decoding. However, Medusa is not
yet supported. Medusa is an efficient speculative decoding framework
that leverages a lightweight draft model to propose multiple tokens in a
single step, which can significantly improve decoding throughput and
reduce latency.
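
As a rough illustration of "multiple tokens in a single step", the sketch below treats each Medusa head as a plain linear projection over the same last hidden state, with head `i` guessing the token `i + 1` positions ahead. This is a simplification and not the code added by this PR: real Medusa heads are small learned networks, and the names `matvec`, `propose_tokens`, and the toy weights are hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]  # shape: [vocab_size][hidden_size]


def matvec(weights: Matrix, hidden: Vector) -> Vector:
    """Return one logit per vocabulary entry for a single hidden state."""
    return [sum(w * h for w, h in zip(row, hidden)) for row in weights]


def propose_tokens(heads: List[Matrix], hidden: Vector) -> List[int]:
    """Let every head read the same hidden state and pick its top token."""
    draft = []
    for head in heads:
        logits = matvec(head, hidden)
        # Greedy choice: index of the largest logit for this head.
        draft.append(max(range(len(logits)), key=logits.__getitem__))
    return draft


# Toy setup (assumed values): hidden size 2, vocabulary of 3 tokens, 2 heads.
hidden_state = [1.0, 0.5]
toy_heads = [
    [[0.1, 0.0], [0.9, 0.2], [0.0, 0.3]],  # head 0: guesses position t+1
    [[0.2, 0.1], [0.0, 0.0], [0.5, 1.0]],  # head 1: guesses position t+2
]
draft_tokens = propose_tokens(toy_heads, hidden_state)
```

All heads run on one hidden state, so the draft for several future positions costs a single pass rather than one forward pass per token.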

To enable Medusa-based speculative decoding on Ascend hardware and
provide more decoding options for users, this PR adds Medusa support
into the `vllm_ascend` speculative decoding pipeline.

### Does this PR introduce _any_ user-facing change?

This PR introduces Medusa speculative decoding as an additional
speculative decoding method:

✔ Adds `MedusaProposer` and integrates it into the speculative decoding
registry
✔ Extends `SpecDcodeType` with a `MEDUSA` enum entry
✔ Updates `NPUModelRunner` to recognize and invoke Medusa during
decoding
✔ Adds Medusa-specific handling in the draft token generation logic
✔ Ensures backward compatibility — Medusa is only used when explicitly
enabled

Key code changes include:

* New file: `vllm_ascend/spec_decode/medusa_proposer.py`
* Register Medusa in `get_spec_decode_method`
* Extend proposer type hints to include `MedusaProposer`
* Add a Medusa-specific branch in `generate_draft_token_ids`
* Pass `sample_hidden_states` required by Medusa
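
The last two items can be pictured with the minimal sketch below. It assumes the branch is an `isinstance` check and that the arguments are passed by the names listed in this description; the classes are stand-ins with toy logic, not the real `vllm_ascend` implementations, and `OtherProposer` is hypothetical.

```python
class MedusaProposer:
    """Stand-in for the real proposer; only the call shape matters here."""

    def generate_token_ids(self, valid_sampled_token_ids, sampling_metadata=None,
                           spec_decode_metadata=None, sample_hidden_states=None):
        if sample_hidden_states is None:
            raise ValueError("Medusa needs sample_hidden_states")
        # Toy rule: propose "last token + 1" and "last token + 2" per request.
        return [[seq[-1] + 1, seq[-1] + 2] for seq in valid_sampled_token_ids]


class OtherProposer:
    """Hypothetical non-Medusa drafter that ignores hidden states."""

    def generate_token_ids(self, valid_sampled_token_ids, **unused):
        return [[seq[-1]] for seq in valid_sampled_token_ids]


def generate_draft_token_ids(drafter, valid_sampled_token_ids, sampling_metadata,
                             spec_decode_metadata, sample_hidden_states):
    """Route to the Medusa-specific call only when the drafter is Medusa."""
    if isinstance(drafter, MedusaProposer):
        return drafter.generate_token_ids(
            valid_sampled_token_ids,
            sampling_metadata=sampling_metadata,
            spec_decode_metadata=spec_decode_metadata,
            sample_hidden_states=sample_hidden_states,
        )
    # Existing drafters keep their previous call path.
    return drafter.generate_token_ids(valid_sampled_token_ids)


medusa_out = generate_draft_token_ids(MedusaProposer(), [[5], [7, 8]], None, None, [[0.0]])
other_out = generate_draft_token_ids(OtherProposer(), [[5], [7, 8]], None, None, None)
```

Keeping the Medusa call in its own branch means drafters that never needed `sample_hidden_states` are unaffected by the new argument.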

### How was this patch tested?

Medusa is implemented as a new proposer class (`MedusaProposer`)
following the existing speculative decoding interface. The integration
works as follows:

1. Users enable Medusa via the speculative decoding configuration.
2. `get_spec_decode_method()` returns a `MedusaProposer` instance when
`method="medusa"`.
3. During decoding, `NPUModelRunner` detects that the active drafter is
a `MedusaProposer`.
4. Instead of the generic speculative decoding path, the Medusa-specific
`generate_token_ids()` method is invoked, which consumes:

   * `valid_sampled_token_ids`
   * `sampling_metadata`
   * `spec_decode_metadata`
   * `sample_hidden_states`
5. The proposed tokens are validated by the target model as usual.
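
Step 5 is not detailed in this description. For intuition only, here is a minimal sketch of verification under greedy decoding: the target model's own choice at each drafted position is compared with the draft, the longest matching prefix is kept, and the target's token at the first mismatch (or the extra token after a full match) is appended. The function name `accept_draft` is hypothetical, and sampling-based acceptance in the real pipeline is more involved than this.

```python
from typing import List


def accept_draft(draft: List[int], target: List[int]) -> List[int]:
    """Greedy verification of one request.

    draft:  k tokens proposed by the drafter.
    target: k + 1 tokens the target model would pick at the same positions.
    """
    if len(target) != len(draft) + 1:
        raise ValueError("target must hold exactly one more token than draft")
    accepted = []
    for position, token in enumerate(draft):
        if token != target[position]:
            break  # first disagreement ends the accepted prefix
        accepted.append(token)
    # The target's token at the first unverified position is always kept,
    # so every step emits at least one token.
    accepted.append(target[len(accepted)])
    return accepted


partial = accept_draft([4, 9, 2], [4, 9, 7, 1])  # third draft token rejected
full = accept_draft([4, 9], [4, 9, 3])           # whole draft accepted
```

Because the appended token always comes from the target model, a bad draft costs throughput but does not change the greedy output.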

When Medusa is not enabled, the decoding pipeline behaves exactly as
before, ensuring full backward compatibility.

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

Signed-off-by: simplzyu <191163281@qq.com>
Signed-off-by: simplzyu <zhenyuguo@cmbchina.com>
2026-01-23 14:14:23 +08:00

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer


def get_spec_decode_method(method, vllm_config, device, runner):
    if method == "ngram":
        return NgramProposer(vllm_config, device, runner)
    elif method in ("eagle", "eagle3"):
        return EagleProposer(vllm_config, device, runner)
    elif method == "mtp":
        return MtpProposer(vllm_config, device, runner)
    elif method == "suffix":
        return SuffixDecodingProposer(vllm_config, device, runner)
    elif method == "medusa":
        return MedusaProposer(vllm_config, device, runner)
    else:
        raise ValueError("Unknown speculative decoding method: "
                         f"{method}")