From 8bb028424bee2dc7e7d260103f0f62d356042141 Mon Sep 17 00:00:00 2001
From: FuNanyang <43992549+coder-fny@users.noreply.github.com>
Date: Wed, 10 Dec 2025 20:32:44 +0800
Subject: [PATCH] Fix the post-processing performance degradation in
 speculative decoding scenarios. (#4849)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
…
### What this PR does / why we need it?
When speculative decoding is enabled and temperature > 0, bonus_logits
and target_logits are sampled separately:
1. bonus_logits are sampled using a fused torch_npu.npu_top_k_top_p
operator invoked inside the main sampler,
2. target_logits are sampled within the rejection sampler using a
less-optimized implementation composed of smaller operators.
Consequently, the cumsum operation in the top-p sampling for
target_logits becomes especially time-consuming, leading to performance
degradation.
This PR applies the fused operator to the sampling of target_logits as
well, reducing this overhead. A minimal sketch of the idea is shown below.
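
For reference, a minimal sketch of the approach (not the exact code; the full
change, including the 310P device check and per-token expansion of the
sampling parameters, is in the diff below):

```python
# Sketch: route target_logits through the same fused NPU operator that the
# main sampler already uses for bonus_logits, instead of the generic
# sort/cumsum-based top-k/top-p masking. `logits`, `top_k`, and `top_p` are
# assumed to be per-token tensors, as produced by expand_batch_to_tokens.
import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p


def apply_top_k_top_p_sketch(logits, top_k, top_p):
    if top_p is not None and top_k is not None and 1 <= int(top_k.max()) <= 1024:
        # Fused kernel: avoids the expensive cumsum over the sorted vocab.
        return torch_npu.npu_top_k_top_p(logits, top_p.to(torch.bfloat16), top_k)
    # Fallback: generic implementation composed of smaller operators.
    return apply_top_k_top_p(logits, top_k, top_p)
```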
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9
---------
Signed-off-by: funanyang <985619145@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
---
vllm_ascend/sample/rejection_sampler.py | 69 ++++++++++++++++++++++++-
1 file changed, 68 insertions(+), 1 deletion(-)
diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index b7905373..c1ef10db 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -3,14 +3,17 @@ from typing import Optional
import torch
import torch.nn as nn
+import torch_npu
import vllm.v1.sample.rejection_sampler as rs
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
- apply_sampling_constraints,
generate_uniform_probs)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+
PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = -1
# Maximum number of speculative draft tokens allowed per request in a single
@@ -104,6 +107,70 @@ class AscendRejectionSampler(RejectionSampler, nn.Module):
return output_token_ids
+def apply_sampling_constraints(
+ logits: torch.Tensor, # [num_tokens, vocab_size]
+ cu_num_draft_tokens: torch.Tensor, # [batch_size]
+ sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+ """Process logits based on sampling metadata.
+
+ This function applies temperature scaling to the logits,
+ as well as top-k and top-p. For greedy decoding, it returns
+ the original logits.
+
+ Args:
+ logits: Input logits tensor to be processed.
+ cu_num_draft_tokens: Cumulative number of draft tokens.
+ sampling_metadata: Metadata containing sampling parameters such as
+ temperature and whether greedy sampling is used.
+
+ Returns:
+ torch.Tensor: Processed logits if non-greedy sampling is used,
+ otherwise returns the original logits.
+ """
+ assert logits.ndim == 2
+ assert cu_num_draft_tokens.ndim == 1
+ if sampling_metadata.all_greedy:
+ return logits
+
+ num_tokens = logits.shape[0]
+ temperature = expand_batch_to_tokens(
+ sampling_metadata.temperature,
+ cu_num_draft_tokens,
+ num_tokens,
+ replace_from=GREEDY_TEMPERATURE,
+ replace_to=1,
+ )
+ # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
+ logits.div_(temperature.unsqueeze(-1))
+
+ # Get expanded top_k and top_p tensors.
+ top_k = None
+ if sampling_metadata.top_k is not None:
+ top_k = expand_batch_to_tokens(
+ sampling_metadata.top_k,
+ cu_num_draft_tokens,
+ num_tokens,
+ )
+ top_p = None
+ if sampling_metadata.top_p is not None:
+ top_p = expand_batch_to_tokens(
+ sampling_metadata.top_p,
+ cu_num_draft_tokens,
+ num_tokens,
+ )
+
+ if get_ascend_device_type(
+ ) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
+ top_k.max()) <= 1024:
+ return torch_npu.npu_top_k_top_p(logits, top_p.to(torch.bfloat16),
+ top_k)
+ else:
+ # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
+ # which is slow for large vocab sizes. This may cause performance issues.
+ return apply_top_k_top_p(logits, top_k, top_p)
+
+
def rejection_sample(
# [num_tokens]
draft_token_ids: torch.Tensor,