From 515267de22dd499a768fdcadae79f45e4a58ba49 Mon Sep 17 00:00:00 2001 From: linfeng-yuan <1102311262@qq.com> Date: Wed, 24 Dec 2025 19:10:33 +0800 Subject: [PATCH] [perf][bugfix] improve performance of rejection sampler and eliminate HD synchronize in TopKTopPSampler (#4154) ### What this PR does / why we need it? 1. Use optimized apply_top_k_top_p for NPU platfrom in rejection sampler; (avoid scatter elements which can reduce ~26ms TPOT with bs=24 per DP) 2. Avoid D2H Synchronization before calling npu_top_k_top_p introduced by parameter validation which improves inference speed with `async_scheduling` enabled; In order to elminate the D2H synchronization introduced by parameter validation before calling `npu_top_k_top_p`, we directly drop this fused operator since the performance improvement is not significant compared to async_scheduling and may bring potential accuracy problem. 3. Refactor the implementation of AscendTopKTopPSampler to align that of vLLM. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? E2E serving test with combinations of `k=500` and `p=0.95` with async_scheduling in single node and wide-EP scenarios. - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac --------- Signed-off-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: realliujiaxu --- tests/ut/sample/test_sampler.py | 24 ------- vllm_ascend/sample/rejection_sampler.py | 15 ++--- vllm_ascend/sample/sampler.py | 90 ++++++++++++------------- 3 files changed, 46 insertions(+), 83 deletions(-) diff --git a/tests/ut/sample/test_sampler.py b/tests/ut/sample/test_sampler.py index 682aa12a..3b58cf2d 100644 --- a/tests/ut/sample/test_sampler.py +++ b/tests/ut/sample/test_sampler.py @@ -1,7 +1,3 @@ -from unittest import mock - -import torch - from tests.ut.base import TestBase from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler @@ -13,23 +9,3 @@ class TestAscendSampler(TestBase): self.assertEqual(sampler.logprobs_mode, "raw_logprobs") self.assertTrue(hasattr(sampler, 'topk_topp_sampler')) self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler) - - -class TestAscendTopKTopPSampler(TestBase): - - @mock.patch("vllm_ascend.sample.sampler.random_sample") - @mock.patch("torch_npu.npu_top_k_top_p") - def test_npu_topk_topp_called_when_optimized(self, mock_npu_op, - mock_random_sample): - mock_npu_op.return_value = (torch.randn(1, 3)) - mock_random_sample.return_value = torch.randn(3) - sampler = AscendTopKTopPSampler() - - logits = torch.tensor([[1.0, 2.0, 3.0]]) - k = torch.tensor([2]) - p = torch.tensor([0.9]) - generators = {0: torch.Generator()} - generators[0].manual_seed(42) - - sampler.forward_native(logits, generators, k, p) - mock_npu_op.assert_called_once_with(logits, p, k) diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index 44bf7264..b0e6f848 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -2,13 +2,11 @@ from typing import Optional import torch -import torch_npu from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.rejection_sampler import generate_uniform_probs -from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type +from vllm_ascend.sample.sampler import apply_top_k_top_p PLACEHOLDER_TOKEN_ID = -1 GREEDY_TEMPERATURE = -1 @@ -80,14 +78,9 @@ def apply_sampling_constraints( num_tokens, ) - if get_ascend_device_type( - ) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int( - top_k.max()) <= 1024: - return torch_npu.npu_top_k_top_p(logits, top_p.to(logits.dtype), top_k) - else: - # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, - # which is slow for large vocab sizes. This may cause performance issues. - return apply_top_k_top_p(logits, top_k, top_p) + # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, + # which is slow for large vocab sizes. This may cause performance issues. + return apply_top_k_top_p(logits, top_k, top_p) def rejection_sample( diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index 3d4fbe22..de043e95 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -1,11 +1,9 @@ import torch -import torch_npu from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler from vllm.v1.sample.sampler import Sampler from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, - global_stream, npu_stream_switch) +from vllm_ascend.utils import global_stream, npu_stream_switch DEFAULT_LOGPROBS_MODE = "raw_logprobs" @@ -65,59 +63,19 @@ class AscendSampler(Sampler): class AscendTopKTopPSampler(TopKTopPSampler): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.apply_top_k_top_p = apply_top_k_top_p + def set_q_event(self, q, event): # Pass in async exponential results. # Also pass in event to prevent synchronize errors. self.q = q self.async_event = event - def _apply_top_k_top_p( - self, - logits: torch.Tensor, - k: torch.Tensor, - p: torch.Tensor, - ) -> torch.Tensor: - # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P - if get_ascend_device_type( - ) != AscendDeviceType._310P and p is not None and k is not None and 1 <= int( - k.max()) <= 1024: - # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) - return torch_npu.npu_top_k_top_p(logits, p, k) - - if p is None and k is None: - return logits - - probs = logits.softmax(dim=-1) - probs_sort, _ = probs.sort(dim=-1, descending=False) - - if k is not None: - top_k_count = probs_sort.size(1) - k.to( - torch.long) # shape: (batch, ) - top_k_count = top_k_count.unsqueeze(dim=1) - top_k_cutoff = probs_sort.gather(-1, top_k_count) - - # Make sure the no top-k rows are no-op. - no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) - top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) - - elements_to_discard = probs < top_k_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - if p is not None: - cumprob = torch.cumsum(probs_sort, dim=-1) - top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) - top_p_mask[:, -1] = False # at least one - - top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) - top_p_cutoff = probs_sort.gather(-1, top_p_count) - elements_to_discard = probs < top_p_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - return logits - def forward_native(self, logits, generators, k, p): """Override pytorch native implementation to torch_npu""" - logits = self._apply_top_k_top_p(logits, k, p) + logits = self.apply_top_k_top_p(logits, k, p) logits_to_return = None if self.logprobs_mode == "processed_logits": logits_to_return = logits @@ -130,3 +88,39 @@ class AscendTopKTopPSampler(TopKTopPSampler): self.async_event.synchronize() return probs.div_(self.q).argmax(dim=-1).view(-1), logits_to_return return random_sample(probs, generators), logits_to_return + + +def apply_top_k_top_p( + logits: torch.Tensor, + k: torch.Tensor, + p: torch.Tensor, +) -> torch.Tensor: + if p is None and k is None: + return logits + + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + + if k is not None: + top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no top-k rows are no-op. + no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + if p is not None: + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + return logits