[SpecDecode] Add spec decode support (#500)

### What this PR does / why we need it?
Backport: https://github.com/vllm-project/vllm-ascend/pull/252
This adds speculative decoding support on Ascend, including speculating with
a draft model, matching n-grams in the prompt, using MLP speculators, and
using EAGLE-based draft models.
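
As a quick reference, enabling it from the offline API looks like this (a
minimal sketch: the model names and config values are placeholders, and the
exact `speculative_config` keys depend on the vLLM version in use):
```
from vllm import LLM, SamplingParams

# Hypothetical example: draft-model based speculative decoding on Ascend.
# The model names and num_speculative_tokens below are illustrative only.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    speculative_config={
        "model": "Qwen/Qwen2.5-0.5B-Instruct",  # small draft model
        "num_speculative_tokens": 5,            # propose 5 tokens per step
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```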

Backport: https://github.com/vllm-project/vllm-ascend/pull/423
Spec decode `MultiStepWorker` now fully supports `TP1DraftModelRunner`: the
draft model runner can run multi-step prepare directly on the NPU, and
`TP1DraftModelRunner` now supports MLA.

1. Before this PR, `MultiStepWorker` never stepped into the branch using NPU
prepare, only into the branch using CPU prepare (line 52 of
`vllm_ascend/patch/patch_multi_step_worker.py`). Although this has no effect
on the correctness of speculative decoding, and the performance of the two
branches is basically the same in the current version, this PR enables the
NPU-prepare branch. There are two main changes in
`patch_multi_step_worker.py`: first, the `is_cuda_like()` check is removed
and the `TP1DraftModelRunner` rewritten in vllm-ascend is used instead;
second, `supports_gpu_multi_step()` is made to return `True` on NPU devices
whenever the outer `MultiStepWorker` can work correctly.
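
The resulting branch selection in the patched `sampler_output` looks roughly
like this (a simplified sketch, not the verbatim patch):
  ```
  # Simplified sketch: the vllm-ascend TP1DraftModelRunner replaces the
  # is_cuda_like() gate that upstream vllm uses here.
  from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner

  if isinstance(self.model_runner, TP1DraftModelRunner) and \
          self.model_runner.supports_gpu_multi_step(expanded_request):
      # multi-step prepare runs directly on the NPU in one call
      expanded_request.num_steps = sample_len
      model_outputs = self.execute_model(execute_model_req=expanded_request)
  else:
      # fall back to preparing every step on the CPU, one step at a time
      for _ in range(sample_len):
          model_output = self.worker.execute_model(
              execute_model_req=expanded_request)
          model_outputs.append(model_output[0])
  ```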

2. Before this PR, `TP1DraftModelRunner` only supported standard attention on
NPU, not MLA. The relevant adaptation is in
`vllm_ascend/worker/draft_model_runner.py`. Although I don't know why the
`input_positions` of `model_input.attn_metadata` in vllm-ascend needs to be
set in `execute_model`, `model_runner.py` does so, so I made the
corresponding change here as well. Otherwise, when the attention backend is
MLA, it reports that `input_positions` cannot be found.
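
The adaptation essentially mirrors what `model_runner.py` does before each
step (an illustrative sketch under that assumption, not the verbatim patch):
  ```
  # Illustrative sketch: carry input_positions on the attention metadata so
  # the MLA backend can find it during the draft model's forward pass.
  if model_input.attn_metadata is not None:
      model_input.attn_metadata.input_positions = model_input.input_positions
  ```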

3. To support the K>1 scenario, I commented out two lines at line 118 of
`draft_model_runner.py`:
  ```
  # lora_mapping=model_input.lora_mapping,
  # lora_requests=model_input.lora_requests,
  ```
I added comments there; once vllm-ascend supports the LoRA feature, these
changes can be reverted.

TODO:
- [ ] revert the patch when the related issues are addressed in vllm

### How was this patch tested?
CI passed with newly added tests.
- e2e test for medusa proposer:
tests/singlecard/spec_decode/e2e/test_medusa_correctness.py
- e2e test for mlp proposer:
tests/singlecard/spec_decode/e2e/test_mlp_correctness.py
- e2e test for n-gram proposer:
tests/singlecard/spec_decode/e2e/test_ngram_correctness.py

Tests for patched files:
- tests/singlecard/spec_decode/test_dynamic_spec_decode.py
- tests/singlecard/spec_decode/test_multi_step_worker.py
- tests/singlecard/spec_decode/test_ngram_worker.py
- tests/singlecard/spec_decode/test_spec_decode_worker.py
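
Assuming a local checkout with the test dependencies installed, these can be
run directly through pytest's Python entry point, for example:
```
# Run the newly added n-gram e2e correctness test (path as added in this PR).
import pytest

pytest.main(
    ["tests/singlecard/spec_decode/e2e/test_ngram_correctness.py", "-q"])
```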

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
This commit is contained in: Mengqing Cao, 2025-04-17 20:16:32 +08:00, committed by GitHub
parent b71f193cb0, commit 6ee7f5cf71
27 changed files with 5813 additions and 11 deletions

vllm_ascend/patch/worker/patch_common/__init__.py

@@ -13,4 +13,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# What's Patched and how it works:
# ** File: worker/patch_common/patch_metrics.py **
# 1. `vllm.spec_decode.metrics.AsyncMetricsCollector.init_tensors` and
# `vllm.spec_decode.metrics.AsyncMetricsCollector._copy_rejsample_metrics_async`
# Why:
# There is CUDA hard code (torch.cuda.Stream) in `AsyncMetricsCollector.init_tensors` and
# `AsyncMetricsCollector._copy_rejsample_metrics_async`
# How:
# Replace it with the corresponding NPU method
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/14411
# Future Plan:
# Revert it when the related pr is merged in vllm.
#
# 2. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics`
# Why:
# There is CUDA hard code (current_platform.is_cuda_alike()) in
# `AsyncMetricsCollector.maybe_collect_rejsample_metrics`
# How:
# Change it to use `current_platform.Event` to determine whether to return None
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/14411
# Future Plan:
# Revert it when the related pr is merged in vllm.
#
# ** File: worker/patch_common/patch_multi_step_worker.py **
# 1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
# Why:
# There is CUDA hard code (current_platform.is_cuda_alike()) in
# `MultiStepWorker.sampler_output`, and we need to use the patched `TP1DraftModelRunner` in it.
# How:
# Make speculative decoding extensible to different backends:
# - support registering attention metadata to the set of supported spec decode backends
# - offer an API in the platform to determine whether spec decode is supported,
#   and deprecate is_cuda_alike in it.
# Related PR (if no, explain why):
# - https://github.com/vllm-project/vllm/pull/15195
# - https://github.com/vllm-project/vllm-ascend/pull/395
# Future Plan:
# Revert it when the related pr is merged in vllm and vllm-ascend.
#
# ** File: worker/patch_common/patch_spec_decode_worker.py **
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
# Why:
# We need to use the patched `TP1DraftModelRunner` in `SpecDecodeWorker.create_worker`.
# The main reason to rewrite `TP1DraftModelRunner` is the hard code of
# `FlashAttentionMetadata`
# How:
# ditto
# Related PR (if no, explain why):
# - https://github.com/vllm-project/vllm/pull/15195
# - https://github.com/vllm-project/vllm-ascend/pull/395
# Future Plan:
# Revert it when the related pr is merged in vllm and vllm-ascend.
# 0.8.4 patch doc:
# platform-0.8.4 + platform-common + worker-0.8.4 + worker-common
# ...
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa

vllm_ascend/patch/worker/patch_common/patch_metrics.py

@@ -0,0 +1,88 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Optional, Union

import torch
import torch_npu
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                      SpecDecodeWorkerMetrics)

Timer = Callable[[], float]


# TODO: revert this patch when the cuda hard code is removed in vllm
# init_tensors: Modified the hard-coded cuda judgment logic to npu;
# maybe_collect_rejsample_metrics: Removed the check for current_platform.is_cuda_alike()
def init_tensors(self,
                 rank: int,
                 device_type: Union[torch.device, str] = 'npu') -> None:
    self._rank = rank
    if isinstance(device_type, torch.device):
        device_type = device_type.type
    if device_type == 'npu':
        self._copy_stream = torch_npu.npu.Stream()


def maybe_collect_rejsample_metrics(
        self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
    # If a copy was initiated in the previous call, collect and return.
    if self._in_flight_copy is not None:
        ready_event = self._in_flight_copy
        self._in_flight_copy = None
        return self._collect_rejsample_metrics(k, ready_event)

    # Otherwise, check if we should start a new copy.
    if self._should_collect_rejsample_metrics(self._timer()):
        assert self._in_flight_copy is None
        self._in_flight_copy = self._copy_rejsample_metrics_async()

    return None


def _copy_rejsample_metrics_async(self) -> torch.npu.Event:
    """
    TODO: torch.cuda.xxx --> torch.npu.xxx
    Copy rejection/typical-acceptance sampling metrics
    (number of accepted tokens, etc) to CPU asynchronously.
    Returns an NPU event recording when the copy is complete.
    """
    assert self._copy_stream is not None
    self._copy_stream.wait_stream(torch.npu.current_stream())

    with torch.npu.stream(self._copy_stream):
        self._aggregate_num_accepted_tokens.copy_(
            self.spec_decode_sampler.num_accepted_tokens, non_blocking=True)
        self._aggregate_num_emitted_tokens.copy_(
            self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
        # Number of draft tokens is calculated on CPU, so no copy is
        # required.
        self._aggregate_num_draft_tokens = (
            self.spec_decode_sampler.num_draft_tokens)

    aggregate_metrics_ready = torch.npu.Event()
    aggregate_metrics_ready.record(self._copy_stream)

    return aggregate_metrics_ready


AsyncMetricsCollector.init_tensors = init_tensors
AsyncMetricsCollector.maybe_collect_rejsample_metrics = maybe_collect_rejsample_metrics
AsyncMetricsCollector._copy_rejsample_metrics_async = _copy_rejsample_metrics_async

vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py

@@ -0,0 +1,87 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Set, Tuple

import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.multi_step_worker import MultiStepWorker

from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner


def sampler_output(
    self,
    execute_model_req: ExecuteModelRequest,
    sample_len: int,
    seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
    """Run the model forward pass sample_len times. Returns the list of
    sampler output, one per model forward pass, along with indicator of
    whether torch tensor in sampler output need to be transposed in latter
    sampler_output_to_torch logic.

    For multi step worker, this indicator shall be True.
    """
    self._raise_if_unsupported(execute_model_req)
    # Expand the batch for sequences with a bonus token.
    # Perform a forward pass on the expanded batch and filter the
    # response to retain only the original sequences' responses.
    expanded_request, indices_of_seq_with_bonus_tokens =\
        self._expand_execute_model_request(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    # Run model sample_len times.
    model_outputs: List[SamplerOutput] = []
    # TODO: supports_gpu_multi_step is False in ASCEND
    if isinstance(self.model_runner, TP1DraftModelRunner) and \
            self.model_runner.supports_gpu_multi_step(expanded_request):
        # Here we run the draft_model_runner with multi-step prepare
        # on the NPU directly
        expanded_request.num_steps = sample_len
        self.model_runner.set_indices_of_seq_with_bonus_tokens(
            indices_of_seq_with_bonus_tokens)
        model_outputs = self.execute_model(execute_model_req=expanded_request)
    else:
        # Here we run multi-step directly, with every step prepared
        # on the CPU.
        # TODO: Remove this branch once DraftModelRunner supports TP>1
        # and other restrictions that are part of DraftModelRunner's
        # supports_gpu_multi_step(..)
        for _ in range(sample_len):
            model_output: List[SamplerOutput] = self.worker.execute_model(
                execute_model_req=expanded_request)
            assert (len(model_output) == 1
                    ), "composing multistep workers not supported"
            model_output = model_output[0]
            self._append_new_tokens(model_output,
                                    expanded_request.seq_group_metadata_list,
                                    indices_of_seq_with_bonus_tokens)
            model_outputs.append(model_output)

    # move indices to device to avoid stream sync
    indices_of_seq_with_bonus_tokens = torch.tensor(
        indices_of_seq_with_bonus_tokens, device=self.device)

    filtered_model_outputs = self._filter_model_output(
        model_outputs, indices_of_seq_with_bonus_tokens)
    return filtered_model_outputs, True


MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)

vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py

@@ -0,0 +1,151 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Dict, Optional

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import \
    SpecDecodeBaseSampler
from vllm.model_executor.layers.typical_acceptance_sampler import \
    TypicalAcceptanceSampler
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.worker.worker_base import WorkerBase

from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner

logger = init_logger(__name__)


def create_worker(
    cls,
    scorer_worker: WorkerBase,
    draft_worker_kwargs: Dict[str, Any],
    disable_mqa_scorer: bool,
    disable_by_batch_size: Optional[int],
    draft_token_acceptance_method: str,
    typical_acceptance_sampler_posterior_threshold: float,
    typical_acceptance_sampler_posterior_alpha: float,
    disable_logprobs: bool,
    disable_log_stats: bool,
    num_speculative_tokens: int,
) -> "SpecDecodeWorker":
    allow_zero_draft_token_step = True
    enable_lm_head_weight_load = False
    num_spec_prefill_steps = 1
    ngram_prompt_lookup_max = (
        draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
    ngram_prompt_lookup_min = (
        draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
    draft_model_config = draft_worker_kwargs["vllm_config"].model_config
    draft_parallel_config: ParallelConfig = draft_worker_kwargs[
        'vllm_config'].parallel_config
    if ngram_prompt_lookup_max > 0:
        draft_worker_kwargs[
            "device_type"] = scorer_worker.device_config.device.type
        proposer_worker = NGramWorker(**draft_worker_kwargs)
        proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                              ngram_prompt_lookup_max)
    else:
        draft_tp = draft_parallel_config.tensor_parallel_size
        target_tp = scorer_worker.parallel_config.tensor_parallel_size

        if draft_model_config.hf_config.model_type == "mlp_speculator":
            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
        elif draft_model_config.hf_config.model_type == "medusa":
            proposer_worker = MedusaWorker(**draft_worker_kwargs)
        else:
            # Note: The current version of the MTP module does not support
            # the use of TP1DraftModelRunner
            if draft_tp == 1 and draft_model_config.hf_config.model_type !=\
                    "deepseek_mtp":
                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
            else:
                if draft_model_config.hf_config.model_type == "eagle":
                    raise NotImplementedError(
                        f"{draft_model_config.hf_config.model_type} "
                        "does not support TP > 1 yet")

                allow_zero_draft_token_step = False

            # Load lm_head weight for eagle in init_device
            if draft_model_config.hf_config.model_type == "eagle":
                enable_lm_head_weight_load = True

            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
            if draft_model_config.hf_config.model_type == "deepseek_mtp":
                num_spec_prefill_steps = num_speculative_tokens

        proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
            proposer_worker, draft_tp, target_tp)

    logger.info("Configuring SpecDecodeWorker with proposer=%s",
                type(proposer_worker))

    spec_decode_sampler: SpecDecodeBaseSampler = None
    if draft_token_acceptance_method == "rejection_sampler":
        spec_decode_sampler = RejectionSampler()
    elif draft_token_acceptance_method == "typical_acceptance_sampler":
        spec_decode_sampler = TypicalAcceptanceSampler(
            posterior_threshold=\
                typical_acceptance_sampler_posterior_threshold,
            posterior_alpha=typical_acceptance_sampler_posterior_alpha,
        )
    logger.info(
        "[Speculative Decoding] Configuring"
        " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))

    if not disable_mqa_scorer:
        if scorer_worker.model_runner.attn_backend.get_name() != "FLASH_ATTN":
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "MQA is only available with flash attn backend.")

        if draft_model_config and \
                draft_model_config.max_model_len < \
                scorer_worker.model_config.max_model_len:
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "draft model max_model_len is smaller than the target "
                        "model max_model_len.")

        if not scorer_worker.model_runner.model_config.enforce_eager:
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "target model is not running in eager mode.")

    return SpecDecodeWorker(
        proposer_worker,
        scorer_worker,
        disable_mqa_scorer=disable_mqa_scorer,
        disable_logprobs=disable_logprobs,
        disable_log_stats=disable_log_stats,
        disable_by_batch_size=disable_by_batch_size,
        spec_decode_sampler=spec_decode_sampler,
        allow_zero_draft_token_step=allow_zero_draft_token_step,
        enable_lm_head_weight_load=enable_lm_head_weight_load,
        num_spec_prefill_steps=num_spec_prefill_steps)


SpecDecodeWorker.create_worker = classmethod(create_worker)
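
For reference, an n-gram configuration exercises the `NGramWorker` branch in
`create_worker` above (`ngram_prompt_lookup_max > 0`). A hedged sketch; the
model name and config keys are illustrative and version-dependent:
```
from vllm import LLM

# Hypothetical sketch: n-gram speculation needs no draft model, so
# create_worker takes the NGramWorker branch (ngram_prompt_lookup_max > 0).
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    speculative_config={
        "method": "ngram",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 4,  # maps to ngram_prompt_lookup_max
        "prompt_lookup_min": 2,  # maps to ngram_prompt_lookup_min
    },
)
```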