Files
xc-llm-ascend/vllm_ascend/sample/sampler.py
weijinqian0 a6ef3ac4e4 [Performance] Pre-issued exponential distribution operator. (#4908)
Pre-issued exponential distribution operator.

Result:
Single inference saves 200-300 microseconds.
before:

<img width="2257" height="1058" alt="2"
src="https://github.com/user-attachments/assets/c1da19e2-a439-42cb-9d7c-c0218e61fd4c"
/>

After:

<img width="2211" height="342" alt="image"
src="https://github.com/user-attachments/assets/03c84292-c802-4755-949c-4266a9a72fc0"
/>


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
2025-12-11 23:02:51 +08:00

103 lines
3.9 KiB
Python

import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.sampler import Sampler
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
global_stream, npu_stream_switch)
DEFAULT_LOGPROBS_MODE = "raw_logprobs"
def random_sample(
probs: torch.Tensor,
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial
causes CPU-NPU synchronization.
"""
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
with npu_stream_switch(global_stream()):
q = torch.empty_like(probs)
if len(generators) != probs.shape[0]:
q.exponential_()
if generators:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
torch.npu.current_stream().wait_stream(global_stream())
return probs.div_(q).argmax(dim=-1).view(-1)
class AscendSampler(Sampler):
def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
# TODO: support logprobs_mode in vllm-ascend
super().__init__(logprobs_mode=logprobs_mode)
self.topk_topp_sampler = AscendTopKTopPSampler()
class AscendTopKTopPSampler(TopKTopPSampler):
def _apply_top_k_top_p(
self,
logits: torch.Tensor,
k: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
# npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P
if get_ascend_device_type(
) != AscendDeviceType._310P and p is not None and k is not None and 1 <= int(
k.max()) <= 1024:
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
return torch_npu.npu_top_k_top_p(logits, p, k)
if p is None and k is None:
return logits
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
if k is not None:
top_k_count = probs_sort.size(1) - k.to(
torch.long) # shape: (batch, )
top_k_count = top_k_count.unsqueeze(dim=1)
top_k_cutoff = probs_sort.gather(-1, top_k_count)
# Make sure the no top-k rows are no-op.
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
elements_to_discard = probs < top_k_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
if p is not None:
cumprob = torch.cumsum(probs_sort, dim=-1)
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False # at least one
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
top_p_cutoff = probs_sort.gather(-1, top_p_count)
elements_to_discard = probs < top_p_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
return logits
def forward_native(self, logits, generators, k, p):
"""Override pytorch native implementation to torch_npu"""
logits = self._apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators), logits_to_return