diff --git a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py
index 4aa5012..8d904e5 100644
--- a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py
@@ -37,7 +37,7 @@ class TopKTopPSampler(nn.Module):
         generators: dict[int, torch.Generator],
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """
         PyTorch-native implementation of top-k and top-p sampling.

@@ -58,7 +58,7 @@ class TopKTopPSampler(nn.Module):
         generators: dict[int, torch.Generator],
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """More optimized implementation for top-k and top-p sampling."""
         if (k is None and p is None) or generators:
             if generators:
@@ -158,14 +158,21 @@ def random_sample(
     if os.getenv('FAST_RANDOM_SAMPLE') == "1":
         q.uniform_()
         q = -torch.log(q)
-        q = q.clamp(min=1e-4)
+        q = q.clamp(min=1e-12)
     else:
         q.exponential_()
     if generators:
         # TODO(woosuk): This can be slow because we handle each request
         # one by one. Optimize this.
-        for i, generator in generators.items():
-            q[i].exponential_(generator=generator)
+        if os.getenv('FAST_RANDOM_SAMPLE') == "1":
+            for i, generator in generators.items():
+                # Transform per row so unseeded rows, which were already
+                # log-transformed above, are not log-transformed twice.
+                q[i].uniform_(generator=generator)
+                q[i] = -torch.log(q[i])
+                q[i] = q[i].clamp(min=1e-12)
+        else:
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
+
     return probs.div_(q).argmax(dim=-1).view(-1)
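
Note for reviewers: the `FAST_RANDOM_SAMPLE` path rests on two standard facts. First, inverse-transform sampling: if `U ~ Uniform[0, 1)`, then `-log(U)` is `Exponential(1)`-distributed, so `uniform_()` followed by `-torch.log` is a drop-in replacement for `exponential_()` on backends where the exponential kernel is slow or unsupported; the clamp keeps the divisor away from zero so `probs.div_(q)` stays finite. Second, the exponential-clock (Gumbel-max) trick: `argmax_i p_i / E_i` with i.i.d. `E_i ~ Exp(1)` is a single draw from the categorical distribution `p`. Below is a minimal standalone sketch of both, not part of the patch; the batch shape, the seeded row index, and the seed value are illustrative assumptions.

```python
import torch

probs = torch.rand(4, 32).softmax(dim=-1)  # 4 requests over a 32-token vocab

# Exp(1) noise via inverse transform: -log(U) for U ~ Uniform[0, 1) is
# Exponential(1)-distributed. The clamp keeps q away from zero so that
# probs / q stays finite when U rounds up toward 1.0.
q = torch.empty_like(probs)
q.uniform_()
q = -torch.log(q)
q = q.clamp(min=1e-12)

# Per-request seeds: overwrite only the seeded rows, applying the same
# uniform -> -log -> clamp transform row by row so that unseeded rows
# are not log-transformed a second time.
generators = {2: torch.Generator().manual_seed(42)}  # hypothetical seed
for i, gen in generators.items():
    q[i].uniform_(generator=gen)
    q[i] = -torch.log(q[i])
    q[i] = q[i].clamp(min=1e-12)

# Exponential-clock sampling: the per-row argmax of probs / q is one
# categorical draw per row from probs[row].
sampled = probs.div_(q).argmax(dim=-1).view(-1)
print(sampled.shape)  # torch.Size([4]); row 2 reproduces across runs
```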