diff --git a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py index 64516ff..db5fa2f 100644 --- a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py +++ b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional - +import os import torch import torch.nn as nn from packaging import version @@ -150,7 +150,12 @@ def random_sample( # not have its own seed. Then, we overwrite the values for the requests # that have their own seeds. if len(generators) != probs.shape[0]: - q.exponential_() + if os.getenv('FAST_RANDOM_SAMPLE') == "1": + q.uniform_() + q = -torch.log(q) + q = q.clamp(min=1e-4) + else: + q.exponential_() if generators: # TODO(woosuk): This can be slow because we handle each request # one by one. Optimize this.