### What this PR does / why we need it? Backport: https://github.com/vllm-project/vllm-ascend/pull/252 This supports speculative decoding on Ascend, including speculating with a draft model, by matching n-grams in the prompt, using MLP speculators, and using EAGLE-based draft models. Backport: https://github.com/vllm-project/vllm-ascend/pull/423 Spec decode `MultiStepWorker` now fully supports `TP1DraftModelRunner`, supports running the draft_model_runner with multi-step prepare directly on the NPU, and supports draft_model_runner using MLA. 1. Before this PR, `MultiStepWorker` would not step into the branch using NPU prepare, but only into the branch using CPU prepare (`line 52` of `vllm_ascend/patch/patch_multi_step_worker.py`). Although this has `no effect` on the `correct operation` of speculative decoding, and the performance of the two branches is basically the same in the current version, this PR enables entering the NPU branch. In general, there are two main changes in `patch_multi_step_worker.py`: first, the `is_cuda_like()` check is removed and the `TP1DraftModelRunner` rewritten in vllm_ascend is used; second, the `supports_gpu_multi_step()` function is made to return true on NPU devices when the outer `MultiStepWorker` can work correctly. 2. Before this PR, `TP1DraftModelRunner` only supported Attention on NPU, but not MLA. The relevant adaptation is in `vllm_ascend/worker/draft_model_runner.py`. Although I don’t know why the `input_positions` of `model_input.attn_metadata` in vllm-ascend needs to be added in `execute_model`, it is done in `model_runner.py`, so I made the corresponding changes as well. Otherwise, when attn_backend is MLA, an error reports that input_positions cannot be found. 3. I commented out two lines in `draft_model_runner.py` at `line118` to support the scenario of K>1. ``` # lora_mapping=model_input.lora_mapping, # lora_requests=model_input.lora_requests, ``` I added comments explaining this. 
In the future, when vllm-ascend supports the LoRA feature, the changes here can be restored. TODO: - [ ] revert the patch when the related issues are addressed in vllm ### How was this patch tested? CI passed with newly added tests. - e2e test for medusa proposer: tests/singlecard/spec_decode/e2e/test_medusa_correctness.py - e2e test for mlp proposer: tests/singlecard/spec_decode/e2e/test_mlp_correctness.py - e2e test for n-gram proposer: tests/singlecard/spec_decode/e2e/test_ngram_correctness.py Tests for patched files: - tests/singlecard/spec_decode/test_dynamic_spec_decode.py - tests/singlecard/spec_decode/test_multi_step_worker.py - tests/singlecard/spec_decode/test_ngram_worker.py - tests/singlecard/spec_decode/test_spec_decode_worker.py --------- Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: mengwei805 <mengwei25@huawei.com>
89 lines
3.2 KiB
Python
89 lines
3.2 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from typing import Callable, Optional, Union
|
|
|
|
import torch
|
|
import torch_npu
|
|
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
|
SpecDecodeWorkerMetrics)
|
|
|
|
# Type alias for a zero-argument callable that returns a float.
# NOTE(review): nothing visible in this module references `Timer`; it
# presumably mirrors the alias in vllm.spec_decode.metrics describing the
# collector's `_timer` (called with no arguments below) — confirm and drop
# it if unused.
Timer = Callable[[], float]
|
|
|
|
# TODO: revert this patch once the hard-coded CUDA logic is removed from vllm.
# init_tensors: changed the hard-coded CUDA device check to an NPU check.
# maybe_collect_rejsample_metrics: removed the current_platform.is_cuda_alike() check.
# _copy_rejsample_metrics_async: uses torch.npu streams/events instead of torch.cuda.
|
|
|
|
|
|
def init_tensors(self,
                 rank: int,
                 device_type: Union[torch.device, str] = 'npu') -> None:
    """Store the worker rank and, on NPU, create the metrics copy stream.

    Replacement for ``AsyncMetricsCollector.init_tensors`` that checks for
    ``'npu'`` where the upstream version hard-codes a CUDA check.

    Args:
        rank: Rank of this worker; saved as ``self._rank``.
        device_type: Either a ``torch.device`` or a device-type string. A
            ``torch.device`` is reduced to its ``.type`` before comparing.
            Only the exact value ``'npu'`` creates ``self._copy_stream``;
            for any other value the attribute is left untouched.
    """
    self._rank = rank
    kind = (device_type.type
            if isinstance(device_type, torch.device) else device_type)
    if kind != 'npu':
        return
    # Dedicated side stream on which the async metric copies are enqueued.
    self._copy_stream = torch_npu.npu.Stream()
|
|
|
|
|
|
def maybe_collect_rejsample_metrics(
        self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
    """Collect metrics from a pending async copy, or possibly start one.

    Replacement for ``AsyncMetricsCollector.maybe_collect_rejsample_metrics``
    with the ``current_platform.is_cuda_alike()`` check removed.

    Behaviour per call:
      * If a previous call left a copy in flight, clear it and return the
        result of ``self._collect_rejsample_metrics(k, <ready event>)``.
      * Otherwise, if ``self._should_collect_rejsample_metrics(self._timer())``
        is truthy, start a new copy via
        ``self._copy_rejsample_metrics_async()`` and return ``None``. The
        result of that copy is collected on the following call.

    Args:
        k: Forwarded to ``self._collect_rejsample_metrics``
            (presumably the number of speculative tokens — TODO confirm).

    Returns:
        The collected metrics, or ``None`` when nothing was collected on
        this call.
    """
    pending = self._in_flight_copy
    if pending is not None:
        # A copy started on the previous call is outstanding: consume it.
        self._in_flight_copy = None
        return self._collect_rejsample_metrics(k, pending)

    # Nothing in flight; decide whether to kick off a copy for next time.
    if self._should_collect_rejsample_metrics(self._timer()):
        assert self._in_flight_copy is None
        self._in_flight_copy = self._copy_rejsample_metrics_async()
    return None
|
|
|
|
|
|
def _copy_rejsample_metrics_async(self) -> torch.npu.Event:
    """Enqueue an asynchronous copy of sampling metrics; return its event.

    Replacement for ``AsyncMetricsCollector._copy_rejsample_metrics_async``
    that uses the ``torch.npu`` stream/event APIs instead of ``torch.cuda``.

    On ``self._copy_stream``, the sampler's accepted- and emitted-token counts
    are copied with ``non_blocking=True`` into the collector's
    ``_aggregate_*`` tensors. The draft-token count is assigned directly,
    because it is already computed on CPU.

    Returns:
        An NPU event recorded on ``self._copy_stream`` after the copies were
        enqueued. It signals when the copy is complete.
    """
    copy_stream = self._copy_stream
    assert copy_stream is not None
    # Make the copy stream wait for work already queued on the current
    # stream before it reads the sampler counters.
    copy_stream.wait_stream(torch.npu.current_stream())

    sampler = self.spec_decode_sampler
    with torch.npu.stream(copy_stream):
        self._aggregate_num_accepted_tokens.copy_(
            sampler.num_accepted_tokens, non_blocking=True)
        self._aggregate_num_emitted_tokens.copy_(
            sampler.num_emitted_tokens, non_blocking=True)
        # The draft-token count lives on CPU already; no device copy needed.
        self._aggregate_num_draft_tokens = sampler.num_draft_tokens

    done = torch.npu.Event()
    done.record(copy_stream)
    return done
|
|
|
|
|
|
# Install the NPU-aware implementations on vllm's AsyncMetricsCollector.
for _patch_name, _patch_fn in (
    ("init_tensors", init_tensors),
    ("maybe_collect_rejsample_metrics", maybe_collect_rejsample_metrics),
    ("_copy_rejsample_metrics_async", _copy_rejsample_metrics_async),
):
    setattr(AsyncMetricsCollector, _patch_name, _patch_fn)
# Do not leak the loop variables into the module namespace.
del _patch_name, _patch_fn
|