[perf][bugfix] improve performance of rejection sampler and eliminate D2H synchronization in TopKTopPSampler (#4154)

### What this PR does / why we need it?
1. Use the optimized `apply_top_k_top_p` for the NPU platform in the rejection sampler. Avoiding the scatter of sorted elements back into the original layout reduces TPOT by ~26 ms with bs=24 per DP (see the first sketch after this list).
2. <del>Avoid the D2H synchronization before calling `npu_top_k_top_p` introduced by parameter validation, which improves inference speed with `async_scheduling` enabled.</del> To eliminate the D2H synchronization that parameter validation introduces before calling `npu_top_k_top_p`, we drop this fused operator entirely: its performance gain is not significant compared to `async_scheduling`, and it may introduce accuracy problems (see the second sketch after this list).
3. Refactor `AscendTopKTopPSampler` to align its implementation with vLLM's `TopKTopPSampler` (a sketch of the aligned structure follows the diff below).
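
For context on item 1: the optimization pattern masks logits in their original layout, using per-row cutoff values gathered from the sorted probabilities, instead of scattering the sorted-and-masked tensor back into place. A minimal sketch of that pattern (the function name and details here are illustrative, not the PR's actual code):

```python
import torch


def apply_top_k_top_p_no_scatter(logits: torch.Tensor, k: torch.Tensor,
                                 p: torch.Tensor) -> torch.Tensor:
    """Mask logits outside top-k/top-p without a scatter back to the
    original layout (illustrative sketch)."""
    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)  # ascending

    # Top-k: the cutoff is each row's k-th largest probability; comparing
    # the unsorted probs against it replaces the scatter with a gather.
    top_k_idx = (probs_sort.size(-1) - k.to(torch.long)).clamp_(min=0)
    top_k_cutoff = probs_sort.gather(-1, top_k_idx.unsqueeze(-1))
    logits.masked_fill_(probs < top_k_cutoff, float("-inf"))

    # Top-p: the cutoff is the smallest probability kept by the nucleus.
    cumprob = probs_sort.cumsum(dim=-1)
    discard = cumprob <= 1 - p.unsqueeze(-1)
    discard[:, -1] = False  # always keep the most likely token
    top_p_idx = discard.sum(dim=-1, keepdim=True)
    top_p_cutoff = probs_sort.gather(-1, top_p_idx)
    logits.masked_fill_(probs < top_p_cutoff, float("-inf"))
    return logits
```

The scatter this avoids materializes a fully re-ordered copy of the logits, which is where the ~26 ms TPOT cost came from.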

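On item 2: the synchronization comes from evaluating a data-dependent condition on device tensors in host Python, which forces a device-to-host copy and a stream sync before the host can branch. A hypothetical illustration of the removed pattern (the specific guard is an assumption, not the deleted code; the `(logits, p, k)` argument order matches the operator's signature as exercised by the deleted test below):

```python
import torch
import torch_npu  # Ascend PyTorch adapter, provides the fused operator


def guarded_fused_top_k_top_p(logits: torch.Tensor, p: torch.Tensor,
                              k: torch.Tensor) -> torch.Tensor:
    # Evaluating a device tensor in `if` calls bool() on it, which copies
    # the value device-to-host and blocks until queued NPU kernels finish.
    # Under async_scheduling this stall erases the fused op's advantage.
    if bool(((p > 0) & (p <= 1)).all()) and bool((k > 0).all()):
        # Note: npu_top_k_top_p takes p before k.
        return torch_npu.npu_top_k_top_p(logits, p, k)
    raise ValueError("invalid top-k/top-p parameters")
```

Dropping the fused operator removes the guard and, with it, the synchronization.
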
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
E2E serving tests with `k=500` and `p=0.95`, with `async_scheduling` enabled, in single-node and wide-EP scenarios.
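
For reference, a request of the kind used in this E2E test might look like the following; the endpoint, model name, and prompt are placeholders (vLLM's OpenAI-compatible server accepts `top_k` as an extra sampling parameter):

```python
import requests

# Hypothetical deployment; adjust host, port, and model to your setup.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "your-model",
        "prompt": "Hello",
        "max_tokens": 64,
        "top_k": 500,   # k used in the E2E test
        "top_p": 0.95,  # p used in the E2E test
    },
)
print(resp.json())
```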

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
Commit 515267de22 (parent 2f03a2f4a4), authored by linfeng-yuan on 2025-12-24 19:10:33 +08:00, committed by GitHub.
3 changed files with 46 additions and 83 deletions.


```diff
@@ -1,7 +1,3 @@
-from unittest import mock
-
-import torch
-
 from tests.ut.base import TestBase
 from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler
 
@@ -13,23 +9,3 @@ class TestAscendSampler(TestBase):
         self.assertEqual(sampler.logprobs_mode, "raw_logprobs")
         self.assertTrue(hasattr(sampler, 'topk_topp_sampler'))
         self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler)
-
-
-class TestAscendTopKTopPSampler(TestBase):
-
-    @mock.patch("vllm_ascend.sample.sampler.random_sample")
-    @mock.patch("torch_npu.npu_top_k_top_p")
-    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op,
-                                                 mock_random_sample):
-        mock_npu_op.return_value = (torch.randn(1, 3))
-        mock_random_sample.return_value = torch.randn(3)
-
-        sampler = AscendTopKTopPSampler()
-        logits = torch.tensor([[1.0, 2.0, 3.0]])
-        k = torch.tensor([2])
-        p = torch.tensor([0.9])
-        generators = {0: torch.Generator()}
-        generators[0].manual_seed(42)
-
-        sampler.forward_native(logits, generators, k, p)
-        mock_npu_op.assert_called_once_with(logits, p, k)
```
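
The deleted test above exercised the fused `torch_npu.npu_top_k_top_p` path that item 2 removes. After the refactor, `AscendTopKTopPSampler` presumably keeps vLLM's `forward_native` flow and swaps in NPU-friendly masking; a sketch under that assumption (the base class and `random_sample` come from vLLM's v1 sampler module; the masking helper is the first sketch above):

```python
import torch
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample


class AscendTopKTopPSampler(TopKTopPSampler):
    """Sketch: align with vLLM's sampler structure, overriding only the
    top-k/top-p masking step with an NPU-friendly variant."""

    def forward_native(self, logits, generators, k, p):
        # Same flow as vLLM's forward_native: mask, softmax, sample.
        logits = apply_top_k_top_p_no_scatter(logits, k, p)  # first sketch
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)
```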