#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer
from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
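from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer


def get_spec_decode_method(method, vllm_config, device, runner):
    """Build the Ascend speculative-decoding proposer for ``method``.

    Supported methods are "ngram", "suffix", "medusa", "eagle"/"eagle3"
    and "mtp"; any other value raises ``ValueError``.
    """
    if method == "ngram":
        return AscendNgramProposer(vllm_config, runner)
    elif method == "suffix":
        return AscendSuffixDecodingProposer(vllm_config, runner)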
    elif method == "medusa":
        return AscendMedusaProposer(vllm_config, device)
    elif method in ("eagle", "eagle3"):
        return AscendEagleProposer(vllm_config, device, runner)
    elif method == "mtp":
        return AscendMtpProposer(vllm_config, device, runner)
    else:
        raise ValueError(f"Unknown speculative decoding method: {method}")
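

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the vllm-ascend API. The if/elif chain
# above can be read as a lookup table that records which constructor
# arguments each proposer receives: "ngram" and "suffix" get
# (vllm_config, runner), "medusa" gets (vllm_config, device), and
# "eagle"/"eagle3"/"mtp" get (vllm_config, device, runner).
# The names _StubProposer, _SKETCH_DISPATCH and _sketch_get_spec_decode_method
# are hypothetical; the stub stands in for the real Ascend*Proposer classes so
# the sketch runs without an NPU or a real VllmConfig.
# ---------------------------------------------------------------------------
class _StubProposer:
    """Hypothetical stand-in that only records how it was constructed."""

    def __init__(self, kind, *args):
        self.kind = kind
        self.args = args


# Each builder takes (vllm_config, device, runner) and forwards only the
# arguments that the matching branch of get_spec_decode_method passes along.
# "eagle" and "eagle3" share one proposer kind, mirroring the real dispatch.
_SKETCH_DISPATCH = {
    "ngram": lambda cfg, dev, run: _StubProposer("ngram", cfg, run),
    "suffix": lambda cfg, dev, run: _StubProposer("suffix", cfg, run),
    "medusa": lambda cfg, dev, run: _StubProposer("medusa", cfg, dev),
    "eagle": lambda cfg, dev, run: _StubProposer("eagle", cfg, dev, run),
    "eagle3": lambda cfg, dev, run: _StubProposer("eagle", cfg, dev, run),
    "mtp": lambda cfg, dev, run: _StubProposer("mtp", cfg, dev, run),
}


def _sketch_get_spec_decode_method(method, vllm_config, device, runner):
    """Table-driven equivalent of get_spec_decode_method using stub proposers."""
    try:
        build = _SKETCH_DISPATCH[method]
    except KeyError:
        # Same error message as the real dispatcher for unsupported methods.
        raise ValueError(
            f"Unknown speculative decoding method: {method}") from None
    return build(vllm_config, device, runner)

# A table keeps the per-method argument contract in one place; the explicit
# if/elif chain used above makes the same contract visible at each call site.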