From a6ef3ac4e4433fd3069c9d76013659c41dfeb2d5 Mon Sep 17 00:00:00 2001 From: weijinqian0 <1184188277@qq.com> Date: Thu, 11 Dec 2025 23:02:51 +0800 Subject: [PATCH] [Performance] Pre-issued exponential distribution operator. (#4908) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pre-issued exponential distribution operator. Result: Single inference saves 200-300 microseconds. before: 2 After: image - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 --------- Signed-off-by: weijinqian_v1 Co-authored-by: weijinqian_v1 --- tests/ut/sample/test_sampler.py | 5 ++++- vllm_ascend/sample/sampler.py | 31 +++++++++++++++++++++++++++++-- vllm_ascend/utils.py | 10 ++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/tests/ut/sample/test_sampler.py b/tests/ut/sample/test_sampler.py index 98a83e6f..682aa12a 100644 --- a/tests/ut/sample/test_sampler.py +++ b/tests/ut/sample/test_sampler.py @@ -17,9 +17,12 @@ class TestAscendSampler(TestBase): class TestAscendTopKTopPSampler(TestBase): + @mock.patch("vllm_ascend.sample.sampler.random_sample") @mock.patch("torch_npu.npu_top_k_top_p") - def test_npu_topk_topp_called_when_optimized(self, mock_npu_op): + def test_npu_topk_topp_called_when_optimized(self, mock_npu_op, + mock_random_sample): mock_npu_op.return_value = (torch.randn(1, 3)) + mock_random_sample.return_value = torch.randn(3) sampler = AscendTopKTopPSampler() logits = torch.tensor([[1.0, 2.0, 3.0]]) diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index 6c9f37c6..1ea661cf 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -1,13 +1,40 @@ import torch import torch_npu -from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler from vllm.v1.sample.sampler import Sampler -from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type +from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, + global_stream, npu_stream_switch) DEFAULT_LOGPROBS_MODE = "raw_logprobs" +def random_sample( + probs: torch.Tensor, + generators: dict[int, torch.Generator], +) -> torch.Tensor: + """Randomly sample from the probabilities. + + We use this function instead of torch.multinomial because torch.multinomial + causes CPU-NPU synchronization. + """ + # NOTE(woosuk): To batch-process the requests without their own seeds, + # which is the common case, we first assume that every request does + # not have its own seed. Then, we overwrite the values for the requests + # that have their own seeds. + with npu_stream_switch(global_stream()): + q = torch.empty_like(probs) + if len(generators) != probs.shape[0]: + q.exponential_() + if generators: + # TODO(woosuk): This can be slow because we handle each request + # one by one. Optimize this. + for i, generator in generators.items(): + q[i].exponential_(generator=generator) + torch.npu.current_stream().wait_stream(global_stream()) + return probs.div_(q).argmax(dim=-1).view(-1) + + class AscendSampler(Sampler): def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE): diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 0b3deb67..f69c6bcc 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -51,6 +51,7 @@ ACL_FORMAT_FRACTAL_NZ = 29 _CUSTOM_OP_ENABLED = None _CURRENT_STREAM = None _PREFETCH_STREAM = None +_GLOBAL_STREAM = None _SHARED_EXPERTS_CALCULATION_STREAM = None _ASCEND_CUSTOMOP_IS_REIGISTERED = False _DEFAULT_BUFFER_SIZE = 200 @@ -292,6 +293,15 @@ def prefetch_stream() -> torch.npu.Stream: return _PREFETCH_STREAM +def global_stream() -> torch.npu.Stream: + global _GLOBAL_STREAM + if _GLOBAL_STREAM is None: + # when this function is called before any stream is set, + # we return the default stream. + _GLOBAL_STREAM = torch_npu.npu.Stream() + return _GLOBAL_STREAM + + def shared_experts_calculation_stream() -> torch.npu.Stream: global _SHARED_EXPERTS_CALCULATION_STREAM if _SHARED_EXPERTS_CALCULATION_STREAM is None: