[SpecDecode] Add spec decode support (#500)
### What this PR does / why we need it?

Backport: https://github.com/vllm-project/vllm-ascend/pull/252

This adds speculative decoding support on Ascend, including speculating with a draft model, by matching n-grams in the prompt, with MLP speculators, and with EAGLE-based draft models.

Backport: https://github.com/vllm-project/vllm-ascend/pull/423

Make the spec decode `MultiStepWorker` fully support `TP1DraftModelRunner`: the draft model runner can now run multi-step prepare directly on the NPU, and it can use MLA.

1. Before this PR, `MultiStepWorker` never stepped into the branch using NPU prepare, only into the branch using CPU prepare (`line 52` of `vllm_ascend/patch/patch_multi_step_worker.py`). Although this has no effect on the correctness of speculative decoding, and the performance of the two branches is basically the same in the current version, I enable entering the NPU-prepare branch in this PR. There are two main changes in `patch_multi_step_worker.py`: first, the `is_cuda_like()` check is removed and the `TP1DraftModelRunner` rewritten in vllm_ascend is used; second, `supports_gpu_multi_step()` is made to return True on NPU devices when the outer `MultiStepWorker` can work correctly.
2. Before this PR, `TP1DraftModelRunner` only supported Attention on NPU, not MLA. The adaptation is in `vllm_ascend/worker/draft_model_runner.py`. Although it is unclear why the `input_positions` of `model_input.attn_metadata` needs to be set in `execute_model`, `model_runner.py` does it, so I made the corresponding change here as well. Otherwise, when the attention backend is MLA, it reports that `input_positions` cannot be found.
3. I commented out two lines in `draft_model_runner.py` at `line118` to support the K>1 scenario:
```
# lora_mapping=model_input.lora_mapping,
# lora_requests=model_input.lora_requests,
```
I added comments there; in the future, when vllm-ascend supports the LoRA feature, these changes can be restored.

TODO:
- [ ] revert the patch when the related issues are addressed in vllm

### How was this patch tested?

CI passed with newly added tests.
- e2e test for medusa proposer: tests/singlecard/spec_decode/e2e/test_medusa_correctness.py
- e2e test for mlp proposer: tests/singlecard/spec_decode/e2e/test_mlp_correctness.py
- e2e test for n-gram proposer: tests/singlecard/spec_decode/e2e/test_ngram_correctness.py

Tests for patched files:
- tests/singlecard/spec_decode/test_dynamic_spec_decode.py
- tests/singlecard/spec_decode/test_multi_step_worker.py
- tests/singlecard/spec_decode/test_ngram_worker.py
- tests/singlecard/spec_decode/test_spec_decode_worker.py

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
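For reference, below is a minimal sketch of how n-gram speculative decoding can be enabled from the offline `LLM` API once this support is in place. It is illustrative only: the `speculative_config` keys shown follow the vLLM 0.8.x style and may differ between versions, and the target model name is just a placeholder.

```
# Hedged sketch: enable n-gram speculative decoding with the offline LLM API.
# The speculative_config keys are assumptions based on the vLLM 0.8.x interface;
# the model name below is only a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder target model
    speculative_config={
        "method": "ngram",             # propose draft tokens by matching n-grams in the prompt
        "num_speculative_tokens": 5,   # number of draft tokens proposed per step
        "prompt_lookup_max": 4,        # longest n-gram to match against the prompt
    },
)

outputs = llm.generate(["The future of AI on Ascend NPUs is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```

With a draft model instead of n-gram lookup, the same dict would name the draft checkpoint (e.g. `{"model": "<draft-model>", "num_speculative_tokens": 5}`), again subject to the exact schema of the vLLM version in use.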
@@ -511,9 +511,6 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)
        else:
            assert query_len == 1, (
                "seq_len: {}, context_len: {}, query_len: {}".format(
                    seq_len, context_len, query_len))
            self.num_decode_tokens += query_len
            self.curr_seq_lens.append(curr_seq_len)
@@ -13,4 +13,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#

# What's Patched and how it works:
# ** File: worker/patch_common/patch_metrics.py **
#    1. `vllm.spec_decode.metrics.AsyncMetricsCollector.init_tensors` and
#       `vllm.spec_decode.metrics.AsyncMetricsCollector._copy_rejsample_metrics_async`
#    Why:
#       There are cuda hard code (torch.cuda.Stream) in `AsyncMetricsCollector.init_tensors` and
#       `AsyncMetricsCollector._copy_rejsample_metrics_async`
#    How:
#       Replace it with the corresponding npu method
#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
#       https://github.com/vllm-project/vllm/pull/14411
#    Future Plan:
#       Revert it when the related pr is merged in vllm.
#
#    2. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics`
#    Why:
#       There are cuda hard code (current_platform.is_cuda_alike()) in
#       `AsyncMetricsCollector.maybe_collect_rejsample_metrics`
#    How:
#       Change to use `current_platform.Event` to determine whether to return None
#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
#       https://github.com/vllm-project/vllm/pull/14411
#    Future Plan:
#       Revert it when the related pr is merged in vllm.
#
# ** File: worker/patch_common/patch_multi_step_worker.py **
#    1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
#    Why:
#       There are cuda hard code (current_platform.is_cuda_alike()) in
#       `MultiStepWorker.sampler_output`, and we need to use the patched `TP1DraftModelRunner` in it.
#    How:
#       Make speculative decoding extensible to different backends.
#       - support attention metadata register to the set supported spec decode
#       - offer an api in platform to determine whether spec decode is supported,
#         and deprecate is_cuda_alike in it.
#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
#       - https://github.com/vllm-project/vllm/pull/15195
#       - https://github.com/vllm-project/vllm-ascend/pull/395
#    Future Plan:
#       Revert it when the related pr is merged in vllm and vllm-ascend.
#
# ** File: worker/patch_common/patch_spec_decode_worker.py **
#    1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
#    Why:
#       We need to use the patched `TP1DraftModelRunner` in `SpecDecodeWorker.create_worker`.
#       The main reason to overwrite `TP1DraftModelRunner` is the hard code of
#       `FlashAttentionMetadata`
#    How:
#       ditto
#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
#       - https://github.com/vllm-project/vllm/pull/15195
#       - https://github.com/vllm-project/vllm-ascend/pull/395
#    Future Plan:
#       Revert it when the related pr is merged in vllm and vllm-ascend.

# current_platform.is_cuda_alike()
# 0.8.4 patch doc:
#   platform-0.8.4 + platform-common + worker-0.8.4 + worker-common
#   ...

import vllm_ascend.patch.worker.patch_common.patch_metrics  # noqa
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker  # noqa
vllm_ascend/patch/worker/patch_common/patch_metrics.py (new file, 88 lines)
@@ -0,0 +1,88 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Callable, Optional, Union

import torch
import torch_npu
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                      SpecDecodeWorkerMetrics)

Timer = Callable[[], float]

# TODO: revert this patch when the cuda hard code is removed in vllm
# init_tensors: Modified the hard-coded cuda judgment logic to npu;
# maybe_collect_rejsample_metrics: Removed the check for current_platform.is_cuda_alike()


def init_tensors(self,
                 rank: int,
                 device_type: Union[torch.device, str] = 'npu') -> None:
    self._rank = rank
    if isinstance(device_type, torch.device):
        device_type = device_type.type
    if device_type == 'npu':
        self._copy_stream = torch_npu.npu.Stream()


def maybe_collect_rejsample_metrics(
        self, k: int) -> Optional[SpecDecodeWorkerMetrics]:

    # If a copy was initiated in the previous call, collect and return.
    if self._in_flight_copy is not None:
        ready_event = self._in_flight_copy
        self._in_flight_copy = None
        return self._collect_rejsample_metrics(k, ready_event)

    # Otherwise, check if we should start a new copy.
    if self._should_collect_rejsample_metrics(self._timer()):
        assert self._in_flight_copy is None
        self._in_flight_copy = self._copy_rejsample_metrics_async()

    return None


def _copy_rejsample_metrics_async(self) -> torch.npu.Event:
    """
    TODO: torch.cuda.xxx --> torch.npu.xxx
    Copy rejection/typical-acceptance sampling metrics
    (number of accepted tokens, etc) to CPU asynchronously.

    Returns a NPU event recording when the copy is complete.
    """
    assert self._copy_stream is not None
    self._copy_stream.wait_stream(torch.npu.current_stream())

    with torch.npu.stream(self._copy_stream):
        self._aggregate_num_accepted_tokens.copy_(
            self.spec_decode_sampler.num_accepted_tokens, non_blocking=True)
        self._aggregate_num_emitted_tokens.copy_(
            self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
        # Number of draft tokens is calculated on CPU, so no copy is
        # required.
        self._aggregate_num_draft_tokens = (
            self.spec_decode_sampler.num_draft_tokens)

    aggregate_metrics_ready = torch.npu.Event()
    aggregate_metrics_ready.record(self._copy_stream)

    return aggregate_metrics_ready


AsyncMetricsCollector.init_tensors = init_tensors
AsyncMetricsCollector.maybe_collect_rejsample_metrics = maybe_collect_rejsample_metrics
AsyncMetricsCollector._copy_rejsample_metrics_async = _copy_rejsample_metrics_async
vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py (new file, 87 lines)
@@ -0,0 +1,87 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Set, Tuple

import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.multi_step_worker import MultiStepWorker

from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner


def sampler_output(
    self,
    execute_model_req: ExecuteModelRequest,
    sample_len: int,
    seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
    """Run the model forward pass sample_len times. Returns the list of
    sampler output, one per model forward pass, along with indicator of
    whether torch tensor in sampler output need to be transposed in latter
    sampler_output_to_torch logic.

    For multi step worker, this indicator shall be True.
    """
    self._raise_if_unsupported(execute_model_req)
    # Expand the batch for sequences with a bonus token.
    # Perform a forward pass on the expanded batch and filter the
    # response to retain only the original sequences' responses.
    expanded_request, indices_of_seq_with_bonus_tokens =\
        self._expand_execute_model_request(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    # Run model sample_len times.
    model_outputs: List[SamplerOutput] = []

    # TODO: supports_gpu_multi_step is False in ASCEND
    if isinstance(self.model_runner, TP1DraftModelRunner) and \
            self.model_runner.supports_gpu_multi_step(expanded_request):
        # Here we run the draft_model_runner with multi-step prepare
        # on the GPU directly
        expanded_request.num_steps = sample_len
        self.model_runner.set_indices_of_seq_with_bonus_tokens(
            indices_of_seq_with_bonus_tokens)
        model_outputs = self.execute_model(execute_model_req=expanded_request)
    else:
        # Here we run multi-step directly, with every step prepared
        # on the CPU.
        # TODO: Remove this branch once DraftModelRunner supports TP>1
        # and other restrictions that are part of DraftModelRunner's
        # supports_gpu_multi_step(..)
        for _ in range(sample_len):
            model_output: List[SamplerOutput] = self.worker.execute_model(
                execute_model_req=expanded_request)
            assert (len(model_output) == 1
                    ), "composing multistep workers not supported"
            model_output = model_output[0]

            self._append_new_tokens(model_output,
                                    expanded_request.seq_group_metadata_list,
                                    indices_of_seq_with_bonus_tokens)
            model_outputs.append(model_output)

    # move indices to device to avoid stream sync
    indices_of_seq_with_bonus_tokens = torch.tensor(
        indices_of_seq_with_bonus_tokens, device=self.device)
    filtered_model_outputs = self._filter_model_output(
        model_outputs, indices_of_seq_with_bonus_tokens)
    return filtered_model_outputs, True


MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py (new file, 151 lines)
@@ -0,0 +1,151 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Optional

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import \
    SpecDecodeBaseSampler
from vllm.model_executor.layers.typical_acceptance_sampler import \
    TypicalAcceptanceSampler
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.worker.worker_base import WorkerBase

from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner

logger = init_logger(__name__)


def create_worker(
    cls,
    scorer_worker: WorkerBase,
    draft_worker_kwargs: Dict[str, Any],
    disable_mqa_scorer: bool,
    disable_by_batch_size: Optional[int],
    draft_token_acceptance_method: str,
    typical_acceptance_sampler_posterior_threshold: float,
    typical_acceptance_sampler_posterior_alpha: float,
    disable_logprobs: bool,
    disable_log_stats: bool,
    num_speculative_tokens: int,
) -> "SpecDecodeWorker":

    allow_zero_draft_token_step = True
    enable_lm_head_weight_load = False
    num_spec_prefill_steps = 1
    ngram_prompt_lookup_max = (
        draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
    ngram_prompt_lookup_min = (
        draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
    draft_model_config = draft_worker_kwargs["vllm_config"].model_config
    draft_parallel_config: ParallelConfig = draft_worker_kwargs[
        'vllm_config'].parallel_config
    if ngram_prompt_lookup_max > 0:
        draft_worker_kwargs[
            "device_type"] = scorer_worker.device_config.device.type
        proposer_worker = NGramWorker(**draft_worker_kwargs)
        proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                              ngram_prompt_lookup_max)
    else:
        draft_tp = draft_parallel_config.tensor_parallel_size
        target_tp = scorer_worker.parallel_config.tensor_parallel_size

        if draft_model_config.hf_config.model_type == "mlp_speculator":
            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
        elif draft_model_config.hf_config.model_type == "medusa":
            proposer_worker = MedusaWorker(**draft_worker_kwargs)
        else:
            # Note: The current version of the MTP module does not support
            # the use of TP1DraftModelRunner
            if draft_tp == 1 and draft_model_config.hf_config.model_type !=\
                    "deepseek_mtp":
                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
            else:
                if draft_model_config.hf_config.model_type == "eagle":
                    raise NotImplementedError(
                        f"{draft_model_config.hf_config.model_type} "
                        "does not support TP > 1 yet")

                allow_zero_draft_token_step = False

            # Load lm_head weight for eagle in init_device
            if draft_model_config.hf_config.model_type == "eagle":
                enable_lm_head_weight_load = True

            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
            if draft_model_config.hf_config.model_type == "deepseek_mtp":
                num_spec_prefill_steps = num_speculative_tokens

        proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
            proposer_worker, draft_tp, target_tp)

    logger.info("Configuring SpecDecodeWorker with proposer=%s",
                type(proposer_worker))

    spec_decode_sampler: SpecDecodeBaseSampler = None
    if draft_token_acceptance_method == "rejection_sampler":
        spec_decode_sampler = RejectionSampler()
    elif draft_token_acceptance_method == "typical_acceptance_sampler":
        spec_decode_sampler = TypicalAcceptanceSampler(
            posterior_threshold=\
            typical_acceptance_sampler_posterior_threshold,
            posterior_alpha=typical_acceptance_sampler_posterior_alpha,
        )
    logger.info(
        "[Speculative Decoding] Configuring"
        " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))

    if not disable_mqa_scorer:
        if scorer_worker.model_runner.attn_backend.get_name() != "FLASH_ATTN":
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "MQA is only available with flash attn backend.")

        if draft_model_config and \
                draft_model_config.max_model_len < \
                scorer_worker.model_config.max_model_len:
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "draft model max_model_len is smaller than the target "
                        "model max_model_len.")

        if not scorer_worker.model_runner.model_config.enforce_eager:
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "target model is not running in eager mode.")

    return SpecDecodeWorker(
        proposer_worker,
        scorer_worker,
        disable_mqa_scorer=disable_mqa_scorer,
        disable_logprobs=disable_logprobs,
        disable_log_stats=disable_log_stats,
        disable_by_batch_size=disable_by_batch_size,
        spec_decode_sampler=spec_decode_sampler,
        allow_zero_draft_token_step=allow_zero_draft_token_step,
        enable_lm_head_weight_load=enable_lm_head_weight_load,
        num_spec_prefill_steps=num_spec_prefill_steps)


SpecDecodeWorker.create_worker = classmethod(create_worker)
vllm_ascend/worker/draft_model_runner.py (new file, 320 lines)
@@ -0,0 +1,320 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Optional

import torch
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.model_runner_base import (ModelRunnerBase,
                                           ModelRunnerInputBase,
                                           ModelRunnerWrapperBase)

from vllm_ascend.attention.attention import AscendMetadata

logger = init_logger(__name__)

# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True


class TP1DraftModelRunner(ModelRunnerWrapperBase):
    """Specialized model runner for speculative decoding draft model.
    Since the draft model always execute k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we could get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on GPU all the time.

    TODOs:
    1. Currently supports only flash-attn, add support for other attn_backends.
    2. Support TP > 1 (this requires some designs because we do not expect
       any broadcasting inside execute_model).
    """

    def __init__(self, model_runner: ModelRunnerBase):
        if hasattr(
                model_runner,
                "return_hidden_states") and model_runner.return_hidden_states:
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )
        super().__init__(model_runner)

        self.indices_of_seq_with_bonus_tokens = None

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):

        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed  # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple

    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
                          last_output: SamplerOutput) -> ModelRunnerInputBase:
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        # Get num_seqs
        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Get output tokens GPU tensor
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Update attn_metadata
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, AscendMetadata)

        attn_metadata.advance_step(model_input, sampled_token_ids,
                                   self.block_size, num_seqs, num_queries)

        # Update sampling_metadata
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        # Create new input
        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            # Notes: If vllm_ascend supports LORA, we need to
            # add the following two params.
            # lora_mapping=model_input.lora_mapping,
            # lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
        # We can reuse sampling tensors since every decode iteration is the same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug("  input_tokens = %s", new_model_input.input_tokens)
            logger.debug("  input_positions = %s",
                         new_model_input.input_positions)
            logger.debug("  seq_lens = %d", new_model_input.seq_lens)
            logger.debug("  query_lens = %d", new_model_input.query_lens)
            logger.debug("  attn_metadata:")
            logger.debug("    seq_lens_tensor: %s",
                         attn_metadata.seq_lens_tensor)
            logger.debug("    slot_mapping: %s", attn_metadata.slot_mapping)
            logger.debug("    block_tables: %s", attn_metadata.block_tables)

        return new_model_input

    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        """Determines if draft_model_runner GPU multi-step can be used.
        Currently required conditions are:
            1. Only decodes
            2. Only flash-attn
            3. No LORA
            4. No prompt_adapter_config
        """
        if not allow_gpu_advance_step:
            return False

        # We allow multi-step GPU only in decode mode
        for seq_group in execute_model_req.seq_group_metadata_list:
            if seq_group.is_prompt:
                return False

        # TODO: Add support for ASCEND when outer multi_step_worker
        # could work correct.
        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
            return False

        # TODO: Add support for LORA
        if self.lora_config:
            return False

        # TODO: Add soft-tuning prompt adapter support
        return not self.prompt_adapter_config

    def set_indices_of_seq_with_bonus_tokens(self,
                                             indices_of_seq_with_bonus_tokens):
        self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelRunnerInputBase,
        kv_caches: List[torch.Tensor],
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        """Executes num_steps forward passes with advancement of input tensors
        on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.

        Optimizations used:
            1. Input tensors are updated on the GPU directly
            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
               them since we do batch expansion later that uses GPU outputs)
            3. Reuses sampling tensors (since we run only decodes and they have
               a repeating sampling logic)
        """

        # When num_steps == 1, we execute the fallback here for the GPU
        # advance_step, which runs prepare_inputs on CPU and for each spec
        # iteration invokes this function only once
        # (Look at multi-step-worker code)
        is_fallback = num_steps == 1
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
            # case, because we will at least need to broadcast the sampled
            # tokens to all workers.
            if not self.is_driver_worker:
                raise ValueError("TP1DraftModelRunner only supports TP=1.")

            # Sanity
            if self.lora_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for LORA")
            if self.prompt_adapter_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "prompt_adapter_config")
            if model_input.multi_modal_kwargs:
                raise ValueError(
                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
                )
        else:
            if self.lora_config:
                assert model_input.lora_requests is not None
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)

            if self.prompt_adapter_config:
                assert model_input.prompt_adapter_requests is not None
                assert model_input.prompt_adapter_mapping is not None
                self.set_active_prompt_adapters(
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

            self.attn_state.begin_forward(model_input)

        # Detect exec mode
        assert model_input.attn_metadata is not None
        if model_input.attn_metadata.num_prefills > 0:
            # In this case, execute_model(..) was called directly
            if num_steps > 1:
                raise ValueError(
                    "execute_model(..) of draft_model_runner can be called "
                    "directly only with a single-step prefill")
        else:
            # We can skip CPU samples for spec token generation.
            # (We do allow CPU samples for num_steps == 1 to support the
            # fallback case, where supports_gpu_multi_step(..) does not pass)
            model_input.sampling_metadata.skip_sampler_cpu_output = (
                not is_fallback)

        model_executable = self.model
        hidden_states = previous_hidden_states

        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

            model_execute_kwargs = {"previous_hidden_states": hidden_states} \
                if previous_hidden_states is not None else {}

            compute_logits_kwargs = {}
            # Run model
            if hasattr(self.model.config, "num_nextn_predict_layers"):
                # for DeepSeek MTP only to use the corresponding layer for
                # each step
                spec_step_idx = kwargs.get("spec_step_idx", step)
                model_execute_kwargs["spec_step_idx"] = spec_step_idx
                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
            with set_forward_context(model_input.attn_metadata,
                                     self.vllm_config):

                if model_input.attn_metadata is not None:
                    model_input.attn_metadata.input_positions = model_input.input_positions

                hidden_states = model_executable(
                    input_ids=model_input.input_tokens,
                    positions=model_input.input_positions,
                    intermediate_tensors=intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                 device=self.device),
                    **model_execute_kwargs,
                )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata,
                                               **compute_logits_kwargs)
            if not self.is_driver_worker:
                return []
            # Sample the next token.
            output = self.model.sample(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata,
            )
            outputs.append(output)

            if model_input.attn_metadata.num_prefills == 0 \
                    and self.indices_of_seq_with_bonus_tokens is not None:
                assert output.sampled_token_ids is not None
                # output.sampled_token_ids should be of shape (num_seqs, 1)
                nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
                assert num_tokens_per_seq == 1
                count = 0
                for i in range(nums_seqs):
                    bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
                        count]
                    if i != bonus_seq_idx:
                        # The following might cause a cpu->gpu sync
                        # However, the performance impact is negligible as we
                        # benchmarked on H100.
                        output.sampled_token_ids[
                            i, :] = model_input.input_tokens[bonus_seq_idx]
                    else:
                        count += 1

            # Prepare inputs for the next step
            if step != num_steps - 1:
                model_input = self._gpu_advance_step(model_input, outputs[-1])

        return outputs