[Kernel] Enable fast random sample on Kunlun3 Platform with generators (#73)
Co-authored-by: Xinyu Dong <dongxinyu03@baidu.com>
This commit is contained in:
@@ -37,7 +37,7 @@ class TopKTopPSampler(nn.Module):
         generators: dict[int, torch.Generator],
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """
         PyTorch-native implementation of top-k and top-p sampling.
@@ -58,7 +58,7 @@ class TopKTopPSampler(nn.Module):
         generators: dict[int, torch.Generator],
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """More optimized implementation for top-k and top-p sampling."""
         if (k is None and p is None) or generators:
             if generators:
@@ -158,14 +158,21 @@ def random_sample(
     if os.getenv('FAST_RANDOM_SAMPLE') == "1":
         q.uniform_()
         q = -torch.log(q)
-        q = q.clamp(min=1e-4)
+        q = q.clamp(min=1e-12)
     else:
         q.exponential_()
     if generators:
         # TODO(woosuk): This can be slow because we handle each request
         # one by one. Optimize this.
-        for i, generator in generators.items():
-            q[i].exponential_(generator=generator)
+        if os.getenv('FAST_RANDOM_SAMPLE') == "1":
+            for i, generator in generators.items():
+                q[i].uniform_(generator=generator)
+            q = -torch.log(q)
+            q = q.clamp(min=1e-12)
+        else:
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)

     return probs.div_(q).argmax(dim=-1).view(-1)

(NOTE: reconstructed from a garbled diff rendering — old/new line assignment inferred
from the duplicated clamp values and generator loops; context lines are best-effort.)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user