From a6ef3ac4e4433fd3069c9d76013659c41dfeb2d5 Mon Sep 17 00:00:00 2001
From: weijinqian0 <1184188277@qq.com>
Date: Thu, 11 Dec 2025 23:02:51 +0800
Subject: [PATCH] [Performance] Pre-issued exponential distribution operator.
(#4908)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Pre-issue the exponential distribution operator on a dedicated NPU stream so
that generating the exponential noise used for sampling overlaps with work on
the default stream.

Result: a single inference step saves 200-300 microseconds.
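Why pre-issuing is safe: the sampler relies on the exponential-clocks form of
the Gumbel-max trick. If q_i are i.i.d. Exponential(1), then argmax_i(p_i / q_i)
selects index i with probability p_i, and the noise q does not depend on the
probabilities, so it can be generated ahead of time on another stream. A quick
CPU-side check (illustrative only, not part of this patch):

    import torch

    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.3, 0.6])
    # 100k draws of exponential noise; argmax(p / q) per row is one sample.
    q = torch.empty(100_000, 3).exponential_()
    idx = (probs / q).argmax(dim=-1)
    freq = torch.bincount(idx, minlength=3).float() / idx.numel()
    print(freq)  # approximately tensor([0.1, 0.3, 0.6])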
- vLLM version: v0.12.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9
---------
Signed-off-by: weijinqian_v1
Co-authored-by: weijinqian_v1
---
tests/ut/sample/test_sampler.py | 5 ++++-
vllm_ascend/sample/sampler.py | 31 +++++++++++++++++++++++++++++--
vllm_ascend/utils.py | 10 ++++++++++
3 files changed, 43 insertions(+), 3 deletions(-)
diff --git a/tests/ut/sample/test_sampler.py b/tests/ut/sample/test_sampler.py
index 98a83e6f..682aa12a 100644
--- a/tests/ut/sample/test_sampler.py
+++ b/tests/ut/sample/test_sampler.py
@@ -17,9 +17,12 @@ class TestAscendSampler(TestBase):
class TestAscendTopKTopPSampler(TestBase):
+ @mock.patch("vllm_ascend.sample.sampler.random_sample")
@mock.patch("torch_npu.npu_top_k_top_p")
- def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
+ def test_npu_topk_topp_called_when_optimized(self, mock_npu_op,
+ mock_random_sample):
mock_npu_op.return_value = (torch.randn(1, 3))
+ mock_random_sample.return_value = torch.randn(3)
sampler = AscendTopKTopPSampler()
logits = torch.tensor([[1.0, 2.0, 3.0]])
diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py
index 6c9f37c6..1ea661cf 100644
--- a/vllm_ascend/sample/sampler.py
+++ b/vllm_ascend/sample/sampler.py
@@ -1,13 +1,40 @@
import torch
import torch_npu
-from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.sampler import Sampler
-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
+                               global_stream, npu_stream_switch)
DEFAULT_LOGPROBS_MODE = "raw_logprobs"
+def random_sample(
+    probs: torch.Tensor,
+    generators: dict[int, torch.Generator],
+) -> torch.Tensor:
+    """Randomly sample from the probabilities.
+
+    We use this function instead of torch.multinomial because torch.multinomial
+    causes CPU-NPU synchronization.
+    """
+    # NOTE(woosuk): To batch-process the requests without their own seeds,
+    # which is the common case, we first assume that every request does
+    # not have its own seed. Then, we overwrite the values for the requests
+    # that have their own seeds.
+    with npu_stream_switch(global_stream()):
+        q = torch.empty_like(probs)
+        if len(generators) != probs.shape[0]:
+            q.exponential_()
+        if generators:
+            # TODO(woosuk): This can be slow because we handle each request
+            # one by one. Optimize this.
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
+    torch.npu.current_stream().wait_stream(global_stream())
+    return probs.div_(q).argmax(dim=-1).view(-1)
+
+
class AscendSampler(Sampler):
    def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
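Note on the override above: it matches upstream vLLM's random_sample except
that the noise kernels are issued inside npu_stream_switch(global_stream()),
so they can be launched while the default stream is still executing earlier
kernels, and wait_stream() orders probs.div_(q) after the noise is ready. A
minimal standalone sketch of the same overlap pattern (assumes torch.npu
mirrors the torch.cuda stream API, which torch_npu provides):

    import torch
    import torch_npu  # registers the torch.npu namespace on Ascend hosts

    side = torch.npu.Stream()
    with torch.npu.stream(side):
        # Data-independent work: launches while the default stream is busy.
        noise = torch.empty(8, 1024, device="npu").exponential_()
    # Make the default stream wait before it consumes `noise`.
    torch.npu.current_stream().wait_stream(side)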
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 0b3deb67..f69c6bcc 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -51,6 +51,7 @@ ACL_FORMAT_FRACTAL_NZ = 29
_CUSTOM_OP_ENABLED = None
_CURRENT_STREAM = None
_PREFETCH_STREAM = None
+_GLOBAL_STREAM = None
_SHARED_EXPERTS_CALCULATION_STREAM = None
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
_DEFAULT_BUFFER_SIZE = 200
@@ -292,6 +293,15 @@ def prefetch_stream() -> torch.npu.Stream:
    return _PREFETCH_STREAM
+def global_stream() -> torch.npu.Stream:
+    global _GLOBAL_STREAM
+    if _GLOBAL_STREAM is None:
+        # Lazily create the dedicated global stream the first time
+        # this function is called.
+        _GLOBAL_STREAM = torch_npu.npu.Stream()
+    return _GLOBAL_STREAM
+
+
def shared_experts_calculation_stream() -> torch.npu.Stream:
    global _SHARED_EXPERTS_CALCULATION_STREAM
    if _SHARED_EXPERTS_CALCULATION_STREAM is None:
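Usage note for the seeded path in random_sample: requests that carry their own
torch.Generator get their row of noise redrawn reproducibly, overwriting the
shared unseeded draw. The same logic reproduced standalone on CPU (hypothetical
example, no NPU required):

    import torch

    probs = torch.softmax(torch.randn(3, 50), dim=-1)  # 3 requests, 50 tokens
    q = torch.empty_like(probs)
    q.exponential_()                           # shared draw for unseeded rows
    gen = torch.Generator().manual_seed(1234)
    q[1].exponential_(generator=gen)           # request 1 redrawn with its seed
    token_ids = probs.div_(q).argmax(dim=-1)   # one token id per request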