[SpecDecode] Add spec decode support (#500)

### What this PR does / why we need it?
Backport: https://github.com/vllm-project/vllm-ascend/pull/252
This support speculative decoding in Ascend, including speculating with
a draft model、by matching n-grams in the prompt、using MLP speculators
and using EAGLE based draft models.

Backport: https://github.com/vllm-project/vllm-ascend/pull/423
spec decode MultiStepWorker support TP1DraftModelRunner fully, support
run the draft_model_runner with multi-step prepare on the NPU directly
and support draft_model_runner use MLA.

1. before this pr, `MultiStepWorker` would not step into the branch
using NPU prepare, but only into the branch using CPU prepare (`line 52`
of `vllm_ascend/patch/patch_multi_step_worker.py`). Although this has
`no effect` on the `correct operation` of speculative decoding and the
performance of the two branches is basically the same as of the current
version, I support entering this branch in this PR. In general, there
are two main changes in `patch_multi_step_worker.py`: first, the
`is_cuda_like()` check is removed and the `TP1DraftModelRunner`
rewritten in vllm_ascend is used; second, the
`supports_gpu_multi_step()` function is made to return true on NPU
devices when outer Multi_step_worker could work correct.

3. before this pr, `TP1DraftModelRunner` only supports Attention on NPU,
but not MLA. The relevant adaptation is in
`vllm_ascend/worker/draft_model_runner.py`. Although I don’t know why
the `input_positions` of `model_input.attn_metadata` in vllm-ascend
needs to be added in `execute_model`, it is done in `model_runner.py`,
so I also made corresponding changes. Otherwise, when atten_backend is
MLA, it will prompt that input_positions cannot be found.

4. I commented out two lines in `draft_model_runner.py` in `line118` to
support the scenario of K>1.
  ```
  # lora_mapping=model_input.lora_mapping,
  # lora_requests=model_input.lora_requests,
  ```
I added comments. In the future, when vllm-ascend supports lora feature,
the changes here can be restored.

TODO:
- [ ] revert the patch when the related issues are addressed in vllm

### How was this patch tested?
CI passed with new added test.
- e2e test for medusa proposer:
tests/singlecard/spec_decode/e2e/test_medusa_correctness.py
- e2e test for mlp proposer:
tests/singlecard/spec_decode/e2e/test_mlp_correctness.py
- e2e test for n-gram proposer:
tests/singlecard/spec_decode/e2e/test_ngram_correctness.py

Tests for patched files:
- tests/singlecard/spec_decode/test_dynamic_spec_decode.py
- tests/singlecard/spec_decode/test_multi_step_worker.py
- tests/singlecard/spec_decode/test_ngram_worker.py
- tests/singlecard/spec_decode/test_spec_decode_worker.py

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
This commit is contained in:
Mengqing Cao
2025-04-17 20:16:32 +08:00
committed by GitHub
parent b71f193cb0
commit 6ee7f5cf71
27 changed files with 5813 additions and 11 deletions

View File

@@ -511,9 +511,6 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)

View File

@@ -13,4 +13,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# What's Patched and how it works:
# ** File: worker/patch_common/patch_metrics.py **
# 1. `vllm.spec_decode.metrics.AsyncMetricsCollector.init_tensors` and
# `vllm.spec_decode.metrics.AsyncMetricsCollector._copy_rejsample_metrics_async`
# Why:
# There are cuda hard code (torch.cuda.Stream) in `AsyncMetricsCollector.init_tensors` and
# `AsyncMetricsCollector._copy_rejsample_metrics_async`
# How
# Replace it with the corresponding npu method
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# https://github.com/vllm-project/vllm/pull/14411
# Future Plan:
# Revert it when the related pr is merged in vllm.
#
# 2. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics`
# Why:
# There are cuda hard code (current_platform.is_cuda_alike()) in
# `AsyncMetricsCollector.maybe_collect_rejsample_metrics`
# How
# Change to use `current_platform.Event` to determine whether to return None
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# https://github.com/vllm-project/vllm/pull/14411
# Future Plan:
# Revert it when the related pr is merged in vllm.
#
# ** File: worker/patch_common/patch_multi_step_worker.py **
# 1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
# Why:
# There are cuda hard code (current_platform.is_cuda_alike()) in
# `MultiStepWorker.sampler_output`, and we need to use the patched `TP1DraftModelRunner` in it.
# How
# Make speculative decoding extensible to different backends.
# - support attention metadata register to the set supported spec decode
# - offer a api in platform to determine whether spec decode is supported,
# and deprecate is_cuda_alike in it.
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# - https://github.com/vllm-project/vllm/pull/15195
# - https://github.com/vllm-project/vllm-ascend/pull/395
# Future Plan:
# Revert it when the related pr is merged in vllm and vllm-ascend.
#
# ** File: worker/patch_common/patch_multi_step_worker.py **
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
# Why:
# We need to use the patched `TP1DraftModelRunner` in `SpecDecodeWorker.create_worker`.
# The mainly reason to overwrite `TP1DraftModelRunner`is the hard code of
# `FlashAttentionMetadata`
# How
# ditto
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# - https://github.com/vllm-project/vllm/pull/15195
# - https://github.com/vllm-project/vllm-ascend/pull/395
# Future Plan:
# Revert it when the related pr is merged in vllm and vllm-ascend.
# current_platform.is_cuda_alike()
# 0.8.4 patch doc:
# platform-0.8.4 + platform-common + worker-0.8.4 + worker-common
# ...
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa

View File

@@ -0,0 +1,88 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Optional, Union
import torch
import torch_npu
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
Timer = Callable[[], float]
# TODO: revert this patch when the cuda hard code is removed in vllm
# init_tensors: Modified the hard-coded cuda judgment logic to npu;
# maybe_collect_rejsample_metrics: Removed the check for current_platform.is_cuda_alike()
def init_tensors(self,
rank: int,
device_type: Union[torch.device, str] = 'npu') -> None:
self._rank = rank
if isinstance(device_type, torch.device):
device_type = device_type.type
if device_type == 'npu':
self._copy_stream = torch_npu.npu.Stream()
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
ready_event = self._in_flight_copy
self._in_flight_copy = None
return self._collect_rejsample_metrics(k, ready_event)
# Otherwise, check if we should start a new copy.
if self._should_collect_rejsample_metrics(self._timer()):
assert self._in_flight_copy is None
self._in_flight_copy = self._copy_rejsample_metrics_async()
return None
def _copy_rejsample_metrics_async(self) -> torch.npu.Event:
"""
TODO: torch.cuda.xxx --> torch.npu.xxx
Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.
Returns a NPU event recording when the copy is complete.
"""
assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_(
self.spec_decode_sampler.num_accepted_tokens, non_blocking=True)
self._aggregate_num_emitted_tokens.copy_(
self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
# Number of draft tokens is calculated on CPU, so no copy is
# required.
self._aggregate_num_draft_tokens = (
self.spec_decode_sampler.num_draft_tokens)
aggregate_metrics_ready = torch.npu.Event()
aggregate_metrics_ready.record(self._copy_stream)
return aggregate_metrics_ready
AsyncMetricsCollector.init_tensors = init_tensors
AsyncMetricsCollector.maybe_collect_rejsample_metrics = maybe_collect_rejsample_metrics
AsyncMetricsCollector._copy_rejsample_metrics_async = _copy_rejsample_metrics_async

View File

@@ -0,0 +1,87 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
# TODO: supports_gpu_multi_step is False in ASCEND
if isinstance(self.model_runner, TP1DraftModelRunner) and \
self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
self.model_runner.set_indices_of_seq_with_bonus_tokens(
indices_of_seq_with_bonus_tokens)
model_outputs = self.execute_model(execute_model_req=expanded_request)
else:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
self._append_new_tokens(model_output,
expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens = torch.tensor(
indices_of_seq_with_bonus_tokens, device=self.device)
filtered_model_outputs = self._filter_model_output(
model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True
MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)

View File

@@ -0,0 +1,151 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Dict, Optional
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import \
SpecDecodeBaseSampler
from vllm.model_executor.layers.typical_acceptance_sampler import \
TypicalAcceptanceSampler
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.worker.worker_base import WorkerBase
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
logger = init_logger(__name__)
def create_worker(
cls,
scorer_worker: WorkerBase,
draft_worker_kwargs: Dict[str, Any],
disable_mqa_scorer: bool,
disable_by_batch_size: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
num_speculative_tokens: int,
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
enable_lm_head_weight_load = False
num_spec_prefill_steps = 1
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
draft_model_config = draft_worker_kwargs["vllm_config"].model_config
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'vllm_config'].parallel_config
if ngram_prompt_lookup_max > 0:
draft_worker_kwargs[
"device_type"] = scorer_worker.device_config.device.type
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
else:
draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = scorer_worker.parallel_config.tensor_parallel_size
if draft_model_config.hf_config.model_type == "mlp_speculator":
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
elif draft_model_config.hf_config.model_type == "medusa":
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
# Note: The current version of the MTP module doer not support
# the use of TP1DraftModelRunner
if draft_tp == 1 and draft_model_config.hf_config.model_type !=\
"deepseek_mtp":
draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
f"{draft_model_config.hf_config.model_type} "
"does not support TP > 1 yet")
allow_zero_draft_token_step = False
# Load lm_head weight for eagle in init_device
if draft_model_config.hf_config.model_type == "eagle":
enable_lm_head_weight_load = True
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
if draft_model_config.hf_config.model_type == "deepseek_mtp":
num_spec_prefill_steps = num_speculative_tokens
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
proposer_worker, draft_tp, target_tp)
logger.info("Configuring SpecDecodeWorker with proposer=%s",
type(proposer_worker))
spec_decode_sampler: SpecDecodeBaseSampler = None
if draft_token_acceptance_method == "rejection_sampler":
spec_decode_sampler = RejectionSampler()
elif draft_token_acceptance_method == "typical_acceptance_sampler":
spec_decode_sampler = TypicalAcceptanceSampler(
posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
)
logger.info(
"[Speculative Decoding] Configuring"
" SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))
if not disable_mqa_scorer:
if scorer_worker.model_runner.attn_backend.get_name() != "FLASH_ATTN":
disable_mqa_scorer = True
logger.info("[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend.")
if draft_model_config and \
draft_model_config.max_model_len < \
scorer_worker.model_config.max_model_len:
disable_mqa_scorer = True
logger.info("[Speculative Decoding] Disabling MQA scorer as the "
"draft model max_model_len is smaller than the target "
"model max_model_len.")
if not scorer_worker.model_runner.model_config.enforce_eager:
disable_mqa_scorer = True
logger.info("[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.")
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
SpecDecodeWorker.create_worker = classmethod(create_worker)

View File

@@ -0,0 +1,320 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional
import torch
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
ModelRunnerWrapperBase)
from vllm_ascend.attention.attention import AscendMetadata
logger = init_logger(__name__)
# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True
class TP1DraftModelRunner(ModelRunnerWrapperBase):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always execute k forward passes consecutively to
generate k speculative tokens in a single speculative decoding step,
we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time.
TODOs:
1. Currently supports only flash-attn, add support for other attn_backends.
2. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
"""
def __init__(self, model_runner: ModelRunnerBase):
if hasattr(
model_runner,
"return_hidden_states") and model_runner.return_hidden_states:
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)
super().__init__(model_runner)
self.indices_of_seq_with_bonus_tokens = None
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries
assert sampling_metadata.selected_token_indices.shape == (
num_queries, )
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
# Verify that all sequences are decodes
for i in range(num_queries):
seq_group = sampling_metadata.seq_groups[i]
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
last_output: SamplerOutput) -> ModelRunnerInputBase:
# Currently, we expect "decode mode" only
assert not model_input.is_prompt
# Get num_seqs
num_seqs = len(model_input.seq_lens)
num_queries = len(model_input.query_lens)
# Get output tokens GPU tensor
sampled_token_ids = last_output.sampled_token_ids
assert sampled_token_ids is not None
# Update attn_metadata
attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, AscendMetadata)
attn_metadata.advance_step(model_input, sampled_token_ids,
self.block_size, num_seqs, num_queries)
# Update sampling_metadata
sampling_metadata = model_input.sampling_metadata
self._update_sampling_metadata(sampling_metadata, num_seqs,
num_queries)
# Create new input
new_model_input = self._model_input_cls(
input_tokens=model_input.input_tokens,
input_positions=model_input.input_positions,
attn_metadata=attn_metadata,
seq_lens=attn_metadata.seq_lens,
query_lens=model_input.query_lens,
# Notes: If vllm_ascend supports LORA, we need to
# add the following two params.
# lora_mapping=model_input.lora_mapping,
# lora_requests=model_input.lora_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
sampling_metadata=model_input.sampling_metadata,
is_prompt=False,
)
# Ensure we skip CPU samples
assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
# We can reuse sampling tensors since every decode iteration is the same
new_model_input.sampling_metadata.reuse_sampling_tensors = True
if debug_advance_input:
logger.debug("NEW INPUT: ")
logger.debug(" input_tokens = %s", new_model_input.input_tokens)
logger.debug(" input_positions = %s",
new_model_input.input_positions)
logger.debug(" seq_lens = %d", new_model_input.seq_lens)
logger.debug(" query_lens = %d", new_model_input.query_lens)
logger.debug(" attn_metadata:")
logger.debug(" seq_lens_tensor: %s",
attn_metadata.seq_lens_tensor)
logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping)
logger.debug(" block_tables: %s", attn_metadata.block_tables)
return new_model_input
def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
"""Determines if draft_model_runner GPU multi-step can be used.
Currently required conditions are:
1. Only decodes
2. Only flash-attn
3. No LORA
4. No prompt_adapter_config
"""
if not allow_gpu_advance_step:
return False
# We allow multi-step GPU only in decode mode
for seq_group in execute_model_req.seq_group_metadata_list:
if seq_group.is_prompt:
return False
# TODO: Add support for ASCEND when outer multi_step_worker
# could work correct.
if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
return False
# TODO: Add support for LORA
if self.lora_config:
return False
# TODO: Add soft-tuning prompt adapter support
return not self.prompt_adapter_config
def set_indices_of_seq_with_bonus_tokens(self,
indices_of_seq_with_bonus_tokens):
self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens
@torch.inference_mode()
def execute_model(
self,
model_input: ModelRunnerInputBase,
kv_caches: List[torch.Tensor],
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[List[SamplerOutput]]:
"""Executes num_steps forward passes with advacement of input tensors
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
Optimizations used:
1. Input tensors are updated on the GPU directly
2. Skips GPU=>CPU serialization of sampler outputs (we don't need
them since we do batch expansion later that uses GPU outputs)
3. Reuses sampling tensors (since we run only decodes and they have
a repeating sampling logic)
"""
# When num_steps == 1, we execute the fallback here for the GPU
# advance_step, which runs prepare_inputs on CPU and for each spec
# iteration invokes this function only once
# (Look at multi-step-worker code)
is_fallback = num_steps == 1
if not is_fallback:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")
# Sanity
if self.lora_config is not None:
raise ValueError("TP1DraftModelRunner has no support for LORA")
if self.prompt_adapter_config is not None:
raise ValueError("TP1DraftModelRunner has no support for "
"prompt_adapter_config")
if model_input.multi_modal_kwargs:
raise ValueError(
"TP1DraftModelRunner has no support for multi_modal_kwargs"
)
else:
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
self.attn_state.begin_forward(model_input)
# Detect exec mode
assert model_input.attn_metadata is not None
if model_input.attn_metadata.num_prefills > 0:
# In this case, execute_model(..) was called directly
if num_steps > 1:
raise ValueError(
"execute_model(..) of draft_model_runner can be called "
"directly only with a single-step prefill")
else:
# We can skip CPU samples for spec token generation.
# (We do allow CPU samples for num_steps == 1 to support the
# fallback case, where supports_gpu_multi_step(..) does not pass)
model_input.sampling_metadata.skip_sampler_cpu_output = (
not is_fallback)
model_executable = self.model
hidden_states = previous_hidden_states
outputs: List[SamplerOutput] = []
for step in range(num_steps):
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
model_execute_kwargs = {"previous_hidden_states": hidden_states} \
if previous_hidden_states is not None else {}
compute_logits_kwargs = {}
# Run model
if hasattr(self.model.config, "num_nextn_predict_layers"):
# for DeepSeek MTP only to use the corresponding layer for
# each step
spec_step_idx = kwargs.get("spec_step_idx", step)
model_execute_kwargs["spec_step_idx"] = spec_step_idx
compute_logits_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
if model_input.attn_metadata is not None:
model_input.attn_metadata.input_positions = model_input.input_positions
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**model_execute_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata,
**compute_logits_kwargs)
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
outputs.append(output)
if model_input.attn_metadata.num_prefills == 0 \
and self.indices_of_seq_with_bonus_tokens is not None:
assert output.sampled_token_ids is not None
# output.sampled_token_ids should be of shape (num_seqs, 1)
nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
assert num_tokens_per_seq == 1
count = 0
for i in range(nums_seqs):
bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
count]
if i != bonus_seq_idx:
# The following might cause a cpu->gpu sync
# However, the performance impact is negligible as we
# benchmarked on H100.
output.sampled_token_ids[
i, :] = model_input.input_tokens[bonus_seq_idx]
else:
count += 1
# Prepare inputs for the next step
if step != num_steps - 1:
model_input = self._gpu_advance_step(model_input, outputs[-1])
return outputs