[ops] support advanced apply_top_k_top_p without top_k constraint (#6098)

### What this PR does / why we need it?
Implement `apply_top_k_top_p` via AscendC to remove the constraint that `k` must lie in [1, 1024]. This enables high-performance top-k/top-p calculation and avoids the D2H synchronization introduced by the `k` validation.
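For reference, a minimal sketch of calling the new fused op directly (shapes, dtypes, and device placement here are illustrative assumptions; the op is registered by this PR and dispatched as shown in the diff below):

```python
import torch
import torch_npu  # noqa: F401  # Ascend backend, assumed installed
import vllm_ascend  # noqa: F401  # assumed to load the library that registers the _C_ascend ops

# Illustrative shapes: logits [batch, vocab], per-request k and p of shape [batch].
logits = torch.randn(4, 32000, device="npu")
k = torch.full((4,), 4096, dtype=torch.int32, device="npu")  # k > 1024 now supported
p = torch.full((4,), 0.95, device="npu")

# Fused AscendC kernel: applies top-k and top-p masking on device,
# without the host-side k validation that forced a D2H sync.
masked_logits = torch.ops._C_ascend.npu_apply_top_k_top_p(logits, k=k, p=p)
```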

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
E2E serving with `k=4096` and `p=0.95`.
- vLLM version: v0.13.0
- vLLM main: d68209402d
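Not part of the patch, but one way to drive the E2E check above against an already running OpenAI-compatible vLLM server (endpoint and model name are placeholders):

```python
import requests

# Assumes `vllm serve <model>` is already running on an Ascend device.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "placeholder-model",
        "prompt": "The quick brown fox",
        "max_tokens": 32,
        "top_k": 4096,   # the k value exercised in the E2E test
        "top_p": 0.95,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["text"])
```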

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
Author: linfeng-yuan
Date: 2026-01-26 09:08:42 +08:00
Committed by: GitHub
Parent: 4e3919e965
Commit: 96309e2b79
16 changed files with 2208 additions and 3 deletions


```diff
@@ -3,7 +3,7 @@
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 from vllm.v1.sample.sampler import Sampler
 
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import global_stream, npu_stream_switch
+from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, global_stream, npu_stream_switch
 
 DEFAULT_LOGPROBS_MODE = "raw_logprobs"
@@ -90,7 +90,7 @@ class AscendTopKTopPSampler(TopKTopPSampler):
         return random_sample(probs, generators), logits_to_return
 
 
-def apply_top_k_top_p(
+def _apply_top_k_top_p_pytorch(
     logits: torch.Tensor,
     k: torch.Tensor,
     p: torch.Tensor,
@@ -124,3 +124,15 @@ def apply_top_k_top_p(
     logits.masked_fill_(elements_to_discard, -float("inf"))
     return logits
 
+
+def _apply_top_k_top_p_ascendc(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    if p is None and k is None:
+        return logits
+    return torch.ops._C_ascend.npu_apply_top_k_top_p(logits, k=k, p=p)
+
+
+apply_top_k_top_p = _apply_top_k_top_p_ascendc if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3] else _apply_top_k_top_p_pytorch
```
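For context, a rough PyTorch sketch of the top-k/top-p masking semantics that `_apply_top_k_top_p_pytorch` retains and that the fused `npu_apply_top_k_top_p` kernel is expected to match (a simplified reference assuming both `k` and `p` are provided and valid, not the exact code from the repository):

```python
import torch


def top_k_top_p_reference(logits: torch.Tensor, k: torch.Tensor,
                          p: torch.Tensor) -> torch.Tensor:
    """Rough reference for per-row top-k then top-p masking of logits."""
    # Sort ascending: the largest logit of each row ends up in the last column.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    vocab_size = logits_sort.size(-1)

    # Top-k: everything strictly below the k-th largest logit is discarded.
    kth_largest = logits_sort.gather(-1, (vocab_size - k.long()).unsqueeze(-1))
    logits_sort.masked_fill_(logits_sort < kth_largest, -float("inf"))

    # Top-p: drop the low-probability tail whose cumulative mass is <= 1 - p,
    # always keeping at least the single most likely token.
    cum_probs = logits_sort.softmax(dim=-1).cumsum(dim=-1)
    top_p_mask = cum_probs <= (1.0 - p).unsqueeze(-1)
    top_p_mask[..., -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Scatter the masked logits back into the original vocabulary order.
    return torch.full_like(logits, -float("inf")).scatter(-1, logits_idx, logits_sort)
```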