[perf][bugfix] improve performance of rejection sampler and eliminate HD synchronize in TopKTopPSampler (#4154)
### What this PR does / why we need it?
1. Use optimized apply_top_k_top_p for NPU platfrom in rejection
sampler; (avoid scatter elements which can reduce ~26ms TPOT with bs=24
per DP)
2. <del>Avoid D2H Synchronization before calling npu_top_k_top_p
introduced by parameter validation which improves inference speed with
`async_scheduling` enabled;</del> In order to elminate the D2H
synchronization introduced by parameter validation before calling
`npu_top_k_top_p`, we directly drop this fused operator since the
performance improvement is not significant compared to async_scheduling
and may bring potential accuracy problem.
3. Refactor the implementation of AscendTopKTopPSampler to align that of
vLLM.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
E2E serving test with combinations of `k=500` and `p=0.95` with
async_scheduling in single node and wide-EP scenarios.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -2,13 +2,11 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.sample.rejection_sampler import generate_uniform_probs
|
||||
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
from vllm_ascend.sample.sampler import apply_top_k_top_p
|
||||
|
||||
PLACEHOLDER_TOKEN_ID = -1
|
||||
GREEDY_TEMPERATURE = -1
|
||||
@@ -80,14 +78,9 @@ def apply_sampling_constraints(
|
||||
num_tokens,
|
||||
)
|
||||
|
||||
if get_ascend_device_type(
|
||||
) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
|
||||
top_k.max()) <= 1024:
|
||||
return torch_npu.npu_top_k_top_p(logits, top_p.to(logits.dtype), top_k)
|
||||
else:
|
||||
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
|
||||
# which is slow for large vocab sizes. This may cause performance issues.
|
||||
return apply_top_k_top_p(logits, top_k, top_p)
|
||||
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
|
||||
# which is slow for large vocab sizes. This may cause performance issues.
|
||||
return apply_top_k_top_p(logits, top_k, top_p)
|
||||
|
||||
|
||||
def rejection_sample(
|
||||
|
||||
Reference in New Issue
Block a user