Fixed the performance degradation issue in post-processing in speculative decoding scenarios. (#4849)
…
### What this PR does / why we need it?
When speculative decoding is enabled and temperature > 0, bonus_logits
and target_logits are sampled separately:
1. bonus_logits are sampled using a fused torch_npu.npu_top_k_top_p
operator invoked inside the main sampler,
2. while target_logits are sampled within the rejection sampler using a
less-optimized implementation composed of smaller operators.
Consequently, the cumsum operation in the top-p sampling for
target_logits becomes especially time-consuming, leading to performance
degradation.
<img width="1029" height="623" alt="image"
src="https://github.com/user-attachments/assets/1969f561-6aa5-41b3-9a87-1f64d4321cbf"
/>
Apply the fused operator to the sampling of target_logits as well to
reduce overhead
<img width="1039" height="572" alt="image"
src="https://github.com/user-attachments/assets/1e6563da-3418-405d-b657-7bbe10dd0924"
/>
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: funanyang <985619145@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -3,14 +3,17 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
import vllm.v1.sample.rejection_sampler as rs
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
|
||||
apply_sampling_constraints,
|
||||
generate_uniform_probs)
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
PLACEHOLDER_TOKEN_ID = -1
|
||||
GREEDY_TEMPERATURE = -1
|
||||
# Maximum number of speculative draft tokens allowed per request in a single
|
||||
@@ -104,6 +107,70 @@ class AscendRejectionSampler(RejectionSampler, nn.Module):
|
||||
return output_token_ids
|
||||
|
||||
|
||||
def apply_sampling_constraints(
|
||||
logits: torch.Tensor, # [num_tokens, vocab_size]
|
||||
cu_num_draft_tokens: torch.Tensor, # [batch_size]
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Process logits based on sampling metadata.
|
||||
|
||||
This function applies temperature scaling to the logits,
|
||||
as well as top-k and top-p. For greedy decoding, it returns
|
||||
the original logits.
|
||||
|
||||
Args:
|
||||
logits: Input logits tensor to be processed.
|
||||
cu_num_draft_tokens: Cumulative number of draft tokens.
|
||||
sampling_metadata: Metadata containing sampling parameters such as
|
||||
temperature and whether greedy sampling is used.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Processed logits if non-greedy sampling is used,
|
||||
otherwise returns the original logits.
|
||||
"""
|
||||
assert logits.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
if sampling_metadata.all_greedy:
|
||||
return logits
|
||||
|
||||
num_tokens = logits.shape[0]
|
||||
temperature = expand_batch_to_tokens(
|
||||
sampling_metadata.temperature,
|
||||
cu_num_draft_tokens,
|
||||
num_tokens,
|
||||
replace_from=GREEDY_TEMPERATURE,
|
||||
replace_to=1,
|
||||
)
|
||||
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
|
||||
logits.div_(temperature.unsqueeze(-1))
|
||||
|
||||
# Get expanded top_k and top_p tensors.
|
||||
top_k = None
|
||||
if sampling_metadata.top_k is not None:
|
||||
top_k = expand_batch_to_tokens(
|
||||
sampling_metadata.top_k,
|
||||
cu_num_draft_tokens,
|
||||
num_tokens,
|
||||
)
|
||||
top_p = None
|
||||
if sampling_metadata.top_p is not None:
|
||||
top_p = expand_batch_to_tokens(
|
||||
sampling_metadata.top_p,
|
||||
cu_num_draft_tokens,
|
||||
num_tokens,
|
||||
)
|
||||
|
||||
if get_ascend_device_type(
|
||||
) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
|
||||
top_k.max()) <= 1024:
|
||||
return torch_npu.npu_top_k_top_p(logits, top_p.to(torch.bfloat16),
|
||||
top_k)
|
||||
else:
|
||||
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
|
||||
# which is slow for large vocab sizes. This may cause performance issues.
|
||||
return apply_top_k_top_p(logits, top_k, top_p)
|
||||
|
||||
|
||||
def rejection_sample(
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user