From c9f00c132c821e9264ca3efa23796ea3f08104ca Mon Sep 17 00:00:00 2001 From: yuqilinaa <38693908+yuqilinaa@users.noreply.github.com> Date: Tue, 20 Jan 2026 21:49:33 +0800 Subject: [PATCH] [Kernel] Enable fast random sample on Kunlun3 Platform with generators (#73) Co-authored-by: Xinyu Dong --- vllm_kunlun/v1/sample/ops/topk_topp_sampler.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py index 4aa5012..8d904e5 100644 --- a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py +++ b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py @@ -37,7 +37,7 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor | None]: """ PyTorch-native implementation of top-k and top-p sampling. @@ -58,7 +58,7 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor | None]: """More optimized implementation for top-k and top-p sampling.""" if (k is None and p is None) or generators: if generators: @@ -158,14 +158,21 @@ def random_sample( if os.getenv('FAST_RANDOM_SAMPLE') == "1": q.uniform_() q = -torch.log(q) - q = q.clamp(min=1e-4) + q = q.clamp(min=1e-12) else: q.exponential_() if generators: # TODO(woosuk): This can be slow because we handle each request # one by one. Optimize this. - for i, generator in generators.items(): - q[i].exponential_(generator=generator) + if os.getenv('FAST_RANDOM_SAMPLE') == "1": + for i, generator in generators.items(): + q[i].uniform_(generator=generator) + q = -torch.log(q) + q = q.clamp(min=1e-12) + else: + for i, generator in generators.items(): + q[i].exponential_(generator=generator) + return probs.div_(q).argmax(dim=-1).view(-1)