diff --git a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py index e175040..e45c426 100644 --- a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py +++ b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py @@ -151,7 +151,12 @@ def random_sample( # not have its own seed. Then, we overwrite the values for the requests # that have their own seeds. if len(generators) != probs.shape[0]: - q.exponential_() + if os.getenv('FAST_RANDOM_SAMPLE') == "1": + q.uniform_() + q = -torch.log(q) + q = q.clamp(min=1e-4) + else: + q.exponential_() if generators: # TODO(woosuk): This can be slow because we handle each request # one by one. Optimize this.