Merge pull request #3 from xyDong0223/main
[Kernel] Enable fast random sample on Kunlun3 Platform
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -150,7 +150,12 @@ def random_sample(
|
|||||||
# not have its own seed. Then, we overwrite the values for the requests
|
# not have its own seed. Then, we overwrite the values for the requests
|
||||||
# that have their own seeds.
|
# that have their own seeds.
|
||||||
if len(generators) != probs.shape[0]:
|
if len(generators) != probs.shape[0]:
|
||||||
q.exponential_()
|
if os.getenv('FAST_RANDOM_SAMPLE') == "1":
|
||||||
|
q.uniform_()
|
||||||
|
q = -torch.log(q)
|
||||||
|
q = q.clamp(min=1e-4)
|
||||||
|
else:
|
||||||
|
q.exponential_()
|
||||||
if generators:
|
if generators:
|
||||||
# TODO(woosuk): This can be slow because we handle each request
|
# TODO(woosuk): This can be slow because we handle each request
|
||||||
# one by one. Optimize this.
|
# one by one. Optimize this.
|
||||||
|
|||||||
Reference in New Issue
Block a user