[Performance] Pre-issued exponential distribution operator. (#4908)

Pre-issued exponential distribution operator.

Result:
A single inference saves 200–300 microseconds.
Before:

<img width="2257" height="1058" alt="2"
src="https://github.com/user-attachments/assets/c1da19e2-a439-42cb-9d7c-c0218e61fd4c"
/>

After:

<img width="2211" height="342" alt="image"
src="https://github.com/user-attachments/assets/03c84292-c802-4755-949c-4266a9a72fc0"
/>


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-12-11 23:02:51 +08:00
committed by GitHub
parent 0fbe0831ec
commit a6ef3ac4e4
3 changed files with 43 additions and 3 deletions

View File

@@ -1,13 +1,40 @@
import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.sampler import Sampler
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
global_stream, npu_stream_switch)
DEFAULT_LOGPROBS_MODE = "raw_logprobs"
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample one token id per row of ``probs`` without a host sync.

    Dividing the probabilities by i.i.d. exponential noise and taking the
    argmax is equivalent to multinomial sampling, but unlike
    ``torch.multinomial`` it does not force a CPU-NPU synchronization.
    The exponential noise is issued on the secondary (global) stream so it
    can be pre-computed while other work runs on the current stream.

    Args:
        probs: Per-request probability distributions; modified in place
            (divided by the noise tensor).
        generators: Maps a request's row index to its dedicated seeded
            generator; rows absent from this dict use the default RNG.

    Returns:
        A 1-D tensor of sampled token indices, one per row.
    """
    with npu_stream_switch(global_stream()):
        noise = torch.empty_like(probs)
        # Common case: most requests are unseeded, so fill the whole
        # tensor in one batched kernel, then overwrite the seeded rows.
        # Skip the batched fill only when every request has its own seed.
        if len(generators) != probs.shape[0]:
            noise.exponential_()
        if generators:
            # TODO(woosuk): This can be slow because we handle each request
            # one by one. Optimize this.
            for row, gen in generators.items():
                noise[row].exponential_(generator=gen)
    # Make the consuming stream wait for the noise issued on the
    # secondary stream before reading it.
    torch.npu.current_stream().wait_stream(global_stream())
    return probs.div_(noise).argmax(dim=-1).view(-1)
class AscendSampler(Sampler):
def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):