Fixed the performance degradation issue in post-processing in speculative decoding scenarios. (#4849)

…

### What this PR does / why we need it?
When speculative decoding is enabled and temperature > 0, bonus_logits
and target_logits are sampled separately:
1. bonus_logits are sampled using a fused torch_npu.npu_top_k_top_p
operator invoked inside the main sampler,
2. while target_logits are sampled within the rejection sampler using a
less-optimized implementation composed of smaller operators.
Consequently, the cumsum operation in the top-p sampling for
target_logits becomes especially time-consuming, leading to performance
degradation.
<img width="1029" height="623" alt="image"
src="https://github.com/user-attachments/assets/1969f561-6aa5-41b3-9a87-1f64d4321cbf"
/>

Apply the fused operator to the sampling of target_logits as well to
reduce overhead
<img width="1039" height="572" alt="image"
src="https://github.com/user-attachments/assets/1e6563da-3418-405d-b657-7bbe10dd0924"
/>

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: funanyang <985619145@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
FuNanyang
2025-12-10 20:32:44 +08:00
committed by GitHub
parent 5b179c53f1
commit 8bb028424b

View File

@@ -3,14 +3,17 @@ from typing import Optional
import torch
import torch.nn as nn
import torch_npu
import vllm.v1.sample.rejection_sampler as rs
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
apply_sampling_constraints,
generate_uniform_probs)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = -1
# Maximum number of speculative draft tokens allowed per request in a single
@@ -104,6 +107,70 @@ class AscendRejectionSampler(RejectionSampler, nn.Module):
return output_token_ids
def apply_sampling_constraints(
logits: torch.Tensor, # [num_tokens, vocab_size]
cu_num_draft_tokens: torch.Tensor, # [batch_size]
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
"""Process logits based on sampling metadata.
This function applies temperature scaling to the logits,
as well as top-k and top-p. For greedy decoding, it returns
the original logits.
Args:
logits: Input logits tensor to be processed.
cu_num_draft_tokens: Cumulative number of draft tokens.
sampling_metadata: Metadata containing sampling parameters such as
temperature and whether greedy sampling is used.
Returns:
torch.Tensor: Processed logits if non-greedy sampling is used,
otherwise returns the original logits.
"""
assert logits.ndim == 2
assert cu_num_draft_tokens.ndim == 1
if sampling_metadata.all_greedy:
return logits
num_tokens = logits.shape[0]
temperature = expand_batch_to_tokens(
sampling_metadata.temperature,
cu_num_draft_tokens,
num_tokens,
replace_from=GREEDY_TEMPERATURE,
replace_to=1,
)
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
logits.div_(temperature.unsqueeze(-1))
# Get expanded top_k and top_p tensors.
top_k = None
if sampling_metadata.top_k is not None:
top_k = expand_batch_to_tokens(
sampling_metadata.top_k,
cu_num_draft_tokens,
num_tokens,
)
top_p = None
if sampling_metadata.top_p is not None:
top_p = expand_batch_to_tokens(
sampling_metadata.top_p,
cu_num_draft_tokens,
num_tokens,
)
if get_ascend_device_type(
) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
top_k.max()) <= 1024:
return torch_npu.npu_top_k_top_p(logits, top_p.to(torch.bfloat16),
top_k)
else:
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues.
return apply_top_k_top_p(logits, top_k, top_p)
def rejection_sample(
# [num_tokens]
draft_token_ids: torch.Tensor,