[SpecDecode] Add spec decode support (#500)
### What this PR does / why we need it? Backport: https://github.com/vllm-project/vllm-ascend/pull/252 This support speculative decoding in Ascend, including speculating with a draft model、by matching n-grams in the prompt、using MLP speculators and using EAGLE based draft models. Backport: https://github.com/vllm-project/vllm-ascend/pull/423 spec decode MultiStepWorker support TP1DraftModelRunner fully, support run the draft_model_runner with multi-step prepare on the NPU directly and support draft_model_runner use MLA. 1. before this pr, `MultiStepWorker` would not step into the branch using NPU prepare, but only into the branch using CPU prepare (`line 52` of `vllm_ascend/patch/patch_multi_step_worker.py`). Although this has `no effect` on the `correct operation` of speculative decoding and the performance of the two branches is basically the same as of the current version, I support entering this branch in this PR. In general, there are two main changes in `patch_multi_step_worker.py`: first, the `is_cuda_like()` check is removed and the `TP1DraftModelRunner` rewritten in vllm_ascend is used; second, the `supports_gpu_multi_step()` function is made to return true on NPU devices when outer Multi_step_worker could work correct. 3. before this pr, `TP1DraftModelRunner` only supports Attention on NPU, but not MLA. The relevant adaptation is in `vllm_ascend/worker/draft_model_runner.py`. Although I don’t know why the `input_positions` of `model_input.attn_metadata` in vllm-ascend needs to be added in `execute_model`, it is done in `model_runner.py`, so I also made corresponding changes. Otherwise, when atten_backend is MLA, it will prompt that input_positions cannot be found. 4. I commented out two lines in `draft_model_runner.py` in `line118` to support the scenario of K>1. ``` # lora_mapping=model_input.lora_mapping, # lora_requests=model_input.lora_requests, ``` I added comments. In the future, when vllm-ascend supports lora feature, the changes here can be restored. TODO: - [ ] revert the patch when the related issues are addressed in vllm ### How was this patch tested? CI passed with new added test. - e2e test for medusa proposer: tests/singlecard/spec_decode/e2e/test_medusa_correctness.py - e2e test for mlp proposer: tests/singlecard/spec_decode/e2e/test_mlp_correctness.py - e2e test for n-gram proposer: tests/singlecard/spec_decode/e2e/test_ngram_correctness.py Tests for patched files: - tests/singlecard/spec_decode/test_dynamic_spec_decode.py - tests/singlecard/spec_decode/test_multi_step_worker.py - tests/singlecard/spec_decode/test_ngram_worker.py - tests/singlecard/spec_decode/test_spec_decode_worker.py --------- Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: mengwei805 <mengwei25@huawei.com>
This commit is contained in:
@@ -13,4 +13,68 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
#
|
||||
|
||||
# What's Patched and how it works:
|
||||
# ** File: worker/patch_common/patch_metrics.py **
|
||||
# 1. `vllm.spec_decode.metrics.AsyncMetricsCollector.init_tensors` and
|
||||
# `vllm.spec_decode.metrics.AsyncMetricsCollector._copy_rejsample_metrics_async`
|
||||
# Why:
|
||||
# There are cuda hard code (torch.cuda.Stream) in `AsyncMetricsCollector.init_tensors` and
|
||||
# `AsyncMetricsCollector._copy_rejsample_metrics_async`
|
||||
# How:
|
||||
# Replace it with the corresponding npu method
|
||||
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
|
||||
# https://github.com/vllm-project/vllm/pull/14411
|
||||
# Future Plan:
|
||||
# Revert it when the related pr is merged in vllm.
|
||||
#
|
||||
# 2. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics`
|
||||
# Why:
|
||||
# There are cuda hard code (current_platform.is_cuda_alike()) in
|
||||
# `AsyncMetricsCollector.maybe_collect_rejsample_metrics`
|
||||
# How:
|
||||
# Change to use `current_platform.Event` to determine whether to return None
|
||||
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
|
||||
# https://github.com/vllm-project/vllm/pull/14411
|
||||
# Future Plan:
|
||||
# Revert it when the related pr is merged in vllm.
|
||||
#
|
||||
# ** File: worker/patch_common/patch_multi_step_worker.py **
|
||||
# 1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
|
||||
# Why:
|
||||
# There are cuda hard code (current_platform.is_cuda_alike()) in
|
||||
# `MultiStepWorker.sampler_output`, and we need to use the patched `TP1DraftModelRunner` in it.
|
||||
# How:
|
||||
# Make speculative decoding extensible to different backends.
|
||||
# - support attention metadata register to the set supported spec decode
|
||||
# - offer a api in platform to determine whether spec decode is supported,
|
||||
# and deprecate is_cuda_alike in it.
|
||||
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
|
||||
# - https://github.com/vllm-project/vllm/pull/15195
|
||||
# - https://github.com/vllm-project/vllm-ascend/pull/395
|
||||
# Future Plan:
|
||||
# Revert it when the related pr is merged in vllm and vllm-ascend.
|
||||
#
|
||||
# ** File: worker/patch_common/patch_multi_step_worker.py **
|
||||
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
|
||||
# Why:
|
||||
# We need to use the patched `TP1DraftModelRunner` in `SpecDecodeWorker.create_worker`.
|
||||
# The mainly reason to overwrite `TP1DraftModelRunner`is the hard code of
|
||||
# `FlashAttentionMetadata`
|
||||
# How:
|
||||
# ditto
|
||||
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
|
||||
# - https://github.com/vllm-project/vllm/pull/15195
|
||||
# - https://github.com/vllm-project/vllm-ascend/pull/395
|
||||
# Future Plan:
|
||||
# Revert it when the related pr is merged in vllm and vllm-ascend.
|
||||
|
||||
# current_platform.is_cuda_alike()
|
||||
# 0.8.4 patch doc:
|
||||
# platform-0.8.4 + platform-common + worker-0.8.4 + worker-common
|
||||
# ...
|
||||
|
||||
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
|
||||
|
||||
88
vllm_ascend/patch/worker/patch_common/patch_metrics.py
Normal file
88
vllm_ascend/patch/worker/patch_common/patch_metrics.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
||||
SpecDecodeWorkerMetrics)
|
||||
|
||||
Timer = Callable[[], float]
|
||||
|
||||
# TODO: revert this patch when the cuda hard code is removed in vllm
|
||||
# init_tensors: Modified the hard-coded cuda judgment logic to npu;
|
||||
# maybe_collect_rejsample_metrics: Removed the check for current_platform.is_cuda_alike()
|
||||
|
||||
|
||||
def init_tensors(self,
|
||||
rank: int,
|
||||
device_type: Union[torch.device, str] = 'npu') -> None:
|
||||
self._rank = rank
|
||||
if isinstance(device_type, torch.device):
|
||||
device_type = device_type.type
|
||||
if device_type == 'npu':
|
||||
self._copy_stream = torch_npu.npu.Stream()
|
||||
|
||||
|
||||
def maybe_collect_rejsample_metrics(
|
||||
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
|
||||
|
||||
# If a copy was initiated in the previous call, collect and return.
|
||||
if self._in_flight_copy is not None:
|
||||
ready_event = self._in_flight_copy
|
||||
self._in_flight_copy = None
|
||||
return self._collect_rejsample_metrics(k, ready_event)
|
||||
|
||||
# Otherwise, check if we should start a new copy.
|
||||
if self._should_collect_rejsample_metrics(self._timer()):
|
||||
assert self._in_flight_copy is None
|
||||
self._in_flight_copy = self._copy_rejsample_metrics_async()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _copy_rejsample_metrics_async(self) -> torch.npu.Event:
|
||||
"""
|
||||
TODO: torch.cuda.xxx --> torch.npu.xxx
|
||||
Copy rejection/typical-acceptance sampling metrics
|
||||
(number of accepted tokens, etc) to CPU asynchronously.
|
||||
|
||||
Returns a NPU event recording when the copy is complete.
|
||||
"""
|
||||
assert self._copy_stream is not None
|
||||
self._copy_stream.wait_stream(torch.npu.current_stream())
|
||||
|
||||
with torch.npu.stream(self._copy_stream):
|
||||
self._aggregate_num_accepted_tokens.copy_(
|
||||
self.spec_decode_sampler.num_accepted_tokens, non_blocking=True)
|
||||
self._aggregate_num_emitted_tokens.copy_(
|
||||
self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
|
||||
# Number of draft tokens is calculated on CPU, so no copy is
|
||||
# required.
|
||||
self._aggregate_num_draft_tokens = (
|
||||
self.spec_decode_sampler.num_draft_tokens)
|
||||
|
||||
aggregate_metrics_ready = torch.npu.Event()
|
||||
aggregate_metrics_ready.record(self._copy_stream)
|
||||
|
||||
return aggregate_metrics_ready
|
||||
|
||||
|
||||
AsyncMetricsCollector.init_tensors = init_tensors
|
||||
AsyncMetricsCollector.maybe_collect_rejsample_metrics = maybe_collect_rejsample_metrics
|
||||
AsyncMetricsCollector._copy_rejsample_metrics_async = _copy_rejsample_metrics_async
|
||||
@@ -0,0 +1,87 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
|
||||
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
|
||||
|
||||
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass sample_len times. Returns the list of
|
||||
sampler output, one per model forward pass, along with indicator of
|
||||
whether torch tensor in sampler output need to be transposed in latter
|
||||
sampler_output_to_torch logic.
|
||||
|
||||
For multi step worker, this indicator shall be True.
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
# Expand the batch for sequences with a bonus token.
|
||||
# Perform a forward pass on the expanded batch and filter the
|
||||
# response to retain only the original sequences' responses.
|
||||
expanded_request, indices_of_seq_with_bonus_tokens =\
|
||||
self._expand_execute_model_request(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
# Run model sample_len times.
|
||||
model_outputs: List[SamplerOutput] = []
|
||||
|
||||
# TODO: supports_gpu_multi_step is False in ASCEND
|
||||
if isinstance(self.model_runner, TP1DraftModelRunner) and \
|
||||
self.model_runner.supports_gpu_multi_step(expanded_request):
|
||||
# Here we run the draft_model_runner with multi-step prepare
|
||||
# on the GPU directly
|
||||
expanded_request.num_steps = sample_len
|
||||
self.model_runner.set_indices_of_seq_with_bonus_tokens(
|
||||
indices_of_seq_with_bonus_tokens)
|
||||
model_outputs = self.execute_model(execute_model_req=expanded_request)
|
||||
else:
|
||||
# Here we run multi-step directly, with every step prepared
|
||||
# on the CPU.
|
||||
# TODO: Remove this branch once DraftModelRunner supports TP>1
|
||||
# and other restrictions that are part of DraftModelRunner's
|
||||
# supports_gpu_multi_step(..)
|
||||
for _ in range(sample_len):
|
||||
model_output: List[SamplerOutput] = self.worker.execute_model(
|
||||
execute_model_req=expanded_request)
|
||||
assert (len(model_output) == 1
|
||||
), "composing multistep workers not supported"
|
||||
model_output = model_output[0]
|
||||
|
||||
self._append_new_tokens(model_output,
|
||||
expanded_request.seq_group_metadata_list,
|
||||
indices_of_seq_with_bonus_tokens)
|
||||
model_outputs.append(model_output)
|
||||
|
||||
# move indices to device to avoid stream sync
|
||||
indices_of_seq_with_bonus_tokens = torch.tensor(
|
||||
indices_of_seq_with_bonus_tokens, device=self.device)
|
||||
filtered_model_outputs = self._filter_model_output(
|
||||
model_outputs, indices_of_seq_with_bonus_tokens)
|
||||
return filtered_model_outputs, True
|
||||
|
||||
|
||||
MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
|
||||
@@ -0,0 +1,151 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import \
|
||||
SpecDecodeBaseSampler
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import \
|
||||
TypicalAcceptanceSampler
|
||||
from vllm.spec_decode.medusa_worker import MedusaWorker
|
||||
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
|
||||
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def create_worker(
|
||||
cls,
|
||||
scorer_worker: WorkerBase,
|
||||
draft_worker_kwargs: Dict[str, Any],
|
||||
disable_mqa_scorer: bool,
|
||||
disable_by_batch_size: Optional[int],
|
||||
draft_token_acceptance_method: str,
|
||||
typical_acceptance_sampler_posterior_threshold: float,
|
||||
typical_acceptance_sampler_posterior_alpha: float,
|
||||
disable_logprobs: bool,
|
||||
disable_log_stats: bool,
|
||||
num_speculative_tokens: int,
|
||||
) -> "SpecDecodeWorker":
|
||||
|
||||
allow_zero_draft_token_step = True
|
||||
enable_lm_head_weight_load = False
|
||||
num_spec_prefill_steps = 1
|
||||
ngram_prompt_lookup_max = (
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
|
||||
ngram_prompt_lookup_min = (
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
|
||||
draft_model_config = draft_worker_kwargs["vllm_config"].model_config
|
||||
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
|
||||
'vllm_config'].parallel_config
|
||||
if ngram_prompt_lookup_max > 0:
|
||||
draft_worker_kwargs[
|
||||
"device_type"] = scorer_worker.device_config.device.type
|
||||
proposer_worker = NGramWorker(**draft_worker_kwargs)
|
||||
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
|
||||
ngram_prompt_lookup_max)
|
||||
else:
|
||||
draft_tp = draft_parallel_config.tensor_parallel_size
|
||||
target_tp = scorer_worker.parallel_config.tensor_parallel_size
|
||||
|
||||
if draft_model_config.hf_config.model_type == "mlp_speculator":
|
||||
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
|
||||
elif draft_model_config.hf_config.model_type == "medusa":
|
||||
proposer_worker = MedusaWorker(**draft_worker_kwargs)
|
||||
else:
|
||||
# Note: The current version of the MTP module doer not support
|
||||
# the use of TP1DraftModelRunner
|
||||
if draft_tp == 1 and draft_model_config.hf_config.model_type !=\
|
||||
"deepseek_mtp":
|
||||
draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
|
||||
else:
|
||||
if draft_model_config.hf_config.model_type == "eagle":
|
||||
raise NotImplementedError(
|
||||
f"{draft_model_config.hf_config.model_type} "
|
||||
"does not support TP > 1 yet")
|
||||
|
||||
allow_zero_draft_token_step = False
|
||||
|
||||
# Load lm_head weight for eagle in init_device
|
||||
if draft_model_config.hf_config.model_type == "eagle":
|
||||
enable_lm_head_weight_load = True
|
||||
|
||||
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
||||
if draft_model_config.hf_config.model_type == "deepseek_mtp":
|
||||
num_spec_prefill_steps = num_speculative_tokens
|
||||
|
||||
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
|
||||
proposer_worker, draft_tp, target_tp)
|
||||
|
||||
logger.info("Configuring SpecDecodeWorker with proposer=%s",
|
||||
type(proposer_worker))
|
||||
|
||||
spec_decode_sampler: SpecDecodeBaseSampler = None
|
||||
if draft_token_acceptance_method == "rejection_sampler":
|
||||
spec_decode_sampler = RejectionSampler()
|
||||
elif draft_token_acceptance_method == "typical_acceptance_sampler":
|
||||
spec_decode_sampler = TypicalAcceptanceSampler(
|
||||
posterior_threshold=\
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
|
||||
)
|
||||
logger.info(
|
||||
"[Speculative Decoding] Configuring"
|
||||
" SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))
|
||||
|
||||
if not disable_mqa_scorer:
|
||||
if scorer_worker.model_runner.attn_backend.get_name() != "FLASH_ATTN":
|
||||
disable_mqa_scorer = True
|
||||
logger.info("[Speculative Decoding] Disabling MQA scorer as the "
|
||||
"MQA is only available with flash attn backend.")
|
||||
|
||||
if draft_model_config and \
|
||||
draft_model_config.max_model_len < \
|
||||
scorer_worker.model_config.max_model_len:
|
||||
disable_mqa_scorer = True
|
||||
logger.info("[Speculative Decoding] Disabling MQA scorer as the "
|
||||
"draft model max_model_len is smaller than the target "
|
||||
"model max_model_len.")
|
||||
|
||||
if not scorer_worker.model_runner.model_config.enforce_eager:
|
||||
disable_mqa_scorer = True
|
||||
logger.info("[Speculative Decoding] Disabling MQA scorer as the "
|
||||
"target model is not running in eager mode.")
|
||||
|
||||
return SpecDecodeWorker(
|
||||
proposer_worker,
|
||||
scorer_worker,
|
||||
disable_mqa_scorer=disable_mqa_scorer,
|
||||
disable_logprobs=disable_logprobs,
|
||||
disable_log_stats=disable_log_stats,
|
||||
disable_by_batch_size=disable_by_batch_size,
|
||||
spec_decode_sampler=spec_decode_sampler,
|
||||
allow_zero_draft_token_step=allow_zero_draft_token_step,
|
||||
enable_lm_head_weight_load=enable_lm_head_weight_load,
|
||||
num_spec_prefill_steps=num_spec_prefill_steps)
|
||||
|
||||
|
||||
SpecDecodeWorker.create_worker = classmethod(create_worker)
|
||||
Reference in New Issue
Block a user