From 515267de22dd499a768fdcadae79f45e4a58ba49 Mon Sep 17 00:00:00 2001
From: linfeng-yuan <1102311262@qq.com>
Date: Wed, 24 Dec 2025 19:10:33 +0800
Subject: [PATCH] [perf][bugfix] improve performance of rejection sampler and
 eliminate D2H synchronization in TopKTopPSampler (#4154)
### What this PR does / why we need it?
1. Use an optimized `apply_top_k_top_p` for the NPU platform in the rejection
sampler. Avoiding the element scatter reduces TPOT by ~26ms with bs=24 per DP.
2. Avoid the D2H synchronization before calling `npu_top_k_top_p` that was
introduced by parameter validation, which improves inference speed with
`async_scheduling` enabled. To eliminate this synchronization we drop the
fused operator entirely, since its performance gain is not significant
compared with `async_scheduling` and it may introduce accuracy problems. A
sketch of the removed blocking pattern is shown after this list.
3. Refactor the implementation of `AscendTopKTopPSampler` to align with that
of vLLM.
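As an illustration of item 2, the D2H synchronization came from validating the
top-k values on the host before dispatching the fused kernel. A minimal sketch
of the removed blocking pattern (the tensor values here are hypothetical):

```python
import torch

# `top_k` lives on the NPU; turning its max into a Python int copies the
# value back to the host and blocks until the device catches up, which
# stalls the host thread when `async_scheduling` is enabled.
top_k = torch.tensor([500, 200, 50], device="npu")

k_max = int(top_k.max())  # D2H synchronization happens here
if 1 <= k_max <= 1024:
    ...  # only then was the fused npu_top_k_top_p kernel dispatched
```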
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
E2E serving tests with combinations of `k=500` and `p=0.95`, with
`async_scheduling` enabled, in single-node and wide-EP scenarios; an
illustrative sketch of the sampling parameters follows.
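For reference, the tested combination corresponds roughly to the sampling
parameters sketched below (illustrative only, not the actual test script):

```python
from vllm import SamplingParams

# Illustrative sampling parameters matching the tested combination; the
# real E2E test drives these values through the serving API.
params = SamplingParams(top_k=500, top_p=0.95)
```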
- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: realliujiaxu
---
tests/ut/sample/test_sampler.py | 24 -------
vllm_ascend/sample/rejection_sampler.py | 15 ++---
vllm_ascend/sample/sampler.py | 90 ++++++++++++-------------
3 files changed, 46 insertions(+), 83 deletions(-)
diff --git a/tests/ut/sample/test_sampler.py b/tests/ut/sample/test_sampler.py
index 682aa12a..3b58cf2d 100644
--- a/tests/ut/sample/test_sampler.py
+++ b/tests/ut/sample/test_sampler.py
@@ -1,7 +1,3 @@
-from unittest import mock
-
-import torch
-
from tests.ut.base import TestBase
from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler
@@ -13,23 +9,3 @@ class TestAscendSampler(TestBase):
        self.assertEqual(sampler.logprobs_mode, "raw_logprobs")
        self.assertTrue(hasattr(sampler, 'topk_topp_sampler'))
        self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler)
-
-
-class TestAscendTopKTopPSampler(TestBase):
-
-    @mock.patch("vllm_ascend.sample.sampler.random_sample")
-    @mock.patch("torch_npu.npu_top_k_top_p")
-    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op,
-                                                 mock_random_sample):
-        mock_npu_op.return_value = (torch.randn(1, 3))
-        mock_random_sample.return_value = torch.randn(3)
-        sampler = AscendTopKTopPSampler()
-
-        logits = torch.tensor([[1.0, 2.0, 3.0]])
-        k = torch.tensor([2])
-        p = torch.tensor([0.9])
-        generators = {0: torch.Generator()}
-        generators[0].manual_seed(42)
-
-        sampler.forward_native(logits, generators, k, p)
-        mock_npu_op.assert_called_once_with(logits, p, k)
diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index 44bf7264..b0e6f848 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -2,13 +2,11 @@
from typing import Optional
import torch
-import torch_npu
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.rejection_sampler import generate_uniform_probs
-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+from vllm_ascend.sample.sampler import apply_top_k_top_p
PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = -1
@@ -80,14 +78,9 @@ def apply_sampling_constraints(
        num_tokens,
    )
-    if get_ascend_device_type(
-    ) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
-            top_k.max()) <= 1024:
-        return torch_npu.npu_top_k_top_p(logits, top_p.to(logits.dtype), top_k)
-    else:
-        # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
-        # which is slow for large vocab sizes. This may cause performance issues.
-        return apply_top_k_top_p(logits, top_k, top_p)
+    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
+    # which is slow for large vocab sizes. This may cause performance issues.
+    return apply_top_k_top_p(logits, top_k, top_p)
def rejection_sample(
diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py
index 3d4fbe22..de043e95 100644
--- a/vllm_ascend/sample/sampler.py
+++ b/vllm_ascend/sample/sampler.py
@@ -1,11 +1,9 @@
import torch
-import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.sampler import Sampler
from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
-                               global_stream, npu_stream_switch)
+from vllm_ascend.utils import global_stream, npu_stream_switch
DEFAULT_LOGPROBS_MODE = "raw_logprobs"
@@ -65,59 +63,19 @@ class AscendSampler(Sampler):
class AscendTopKTopPSampler(TopKTopPSampler):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.apply_top_k_top_p = apply_top_k_top_p
+
    def set_q_event(self, q, event):
        # Pass in async exponential results.
        # Also pass in event to prevent synchronize errors.
        self.q = q
        self.async_event = event
-    def _apply_top_k_top_p(
-        self,
-        logits: torch.Tensor,
-        k: torch.Tensor,
-        p: torch.Tensor,
-    ) -> torch.Tensor:
-        # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P
-        if get_ascend_device_type(
-        ) != AscendDeviceType._310P and p is not None and k is not None and 1 <= int(
-                k.max()) <= 1024:
-            # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
-            return torch_npu.npu_top_k_top_p(logits, p, k)
-
-        if p is None and k is None:
-            return logits
-
-        probs = logits.softmax(dim=-1)
-        probs_sort, _ = probs.sort(dim=-1, descending=False)
-
-        if k is not None:
-            top_k_count = probs_sort.size(1) - k.to(
-                torch.long)  # shape: (batch, )
-            top_k_count = top_k_count.unsqueeze(dim=1)
-            top_k_cutoff = probs_sort.gather(-1, top_k_count)
-
-            # Make sure the no top-k rows are no-op.
-            no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
-            top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
-
-            elements_to_discard = probs < top_k_cutoff
-            logits.masked_fill_(elements_to_discard, -float("inf"))
-
-        if p is not None:
-            cumprob = torch.cumsum(probs_sort, dim=-1)
-            top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
-            top_p_mask[:, -1] = False  # at least one
-
-            top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
-            top_p_cutoff = probs_sort.gather(-1, top_p_count)
-            elements_to_discard = probs < top_p_cutoff
-            logits.masked_fill_(elements_to_discard, -float("inf"))
-
-        return logits
-
    def forward_native(self, logits, generators, k, p):
        """Override pytorch native implementation to torch_npu"""
-        logits = self._apply_top_k_top_p(logits, k, p)
+        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
@@ -130,3 +88,39 @@ class AscendTopKTopPSampler(TopKTopPSampler):
            self.async_event.synchronize()
            return probs.div_(self.q).argmax(dim=-1).view(-1), logits_to_return
        return random_sample(probs, generators), logits_to_return
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    if p is None and k is None:
+        return logits
+
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
+    if k is not None:
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no top-k rows are no-op.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    if p is not None:
+        cumprob = torch.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # at least one
+
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    return logits