[Perf] Use fused ops npu_top_k_top_p (#1308)

### What this PR does / why we need it?
Use fused ops torch_npu.npu_top_k_top_p(logits, p, k) when p and k are
not None, otherwise fallback to the original one. The replacement will
take place automatically when `VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE=1` .

This patch are using `npu_top_k_top_p` which required
torch_npu>=2.5.1.post1.dev20250619

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Tested by DeepSeek R1 and UT passed

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
Pr0Wh1teGivee
2025-06-25 20:59:06 +08:00
committed by GitHub
parent e7efc7e7e7
commit 2fda60464c
2 changed files with 34 additions and 1 deletions

View File

@@ -0,0 +1,28 @@
import importlib
import os
import unittest
from unittest import mock
import torch
from vllm.v1.sample.ops import topk_topp_sampler
class TestTopKTopPSamplerOptimize(unittest.TestCase):
@mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"})
@mock.patch("torch_npu.npu_top_k_top_p")
def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
import vllm_ascend.patch.worker.patch_common.patch_sampler
importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler)
mock_npu_op.return_value = (torch.randn(1, 3))
sampler = topk_topp_sampler.TopKTopPSampler()
logits = torch.tensor([[1.0, 2.0, 3.0]])
k = torch.tensor([2])
p = torch.tensor([0.9])
generators = {0: torch.Generator()}
generators[0].manual_seed(42)
sampler.forward_native(logits, generators, k, p)
mock_npu_op.assert_called_once_with(logits, p, k)

View File

@@ -19,6 +19,7 @@
from typing import Optional
import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.v1.sample.sampler import Sampler
@@ -48,9 +49,13 @@ def apply_min_p(
def _apply_top_k_top_p(
logits: torch.Tensor,
p: torch.Tensor,
k: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
if p is not None and k is not None:
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
return torch_npu.npu_top_k_top_p(logits, p, k)
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)