[Kernel] Enable fast random sample on Kunlun3 Platform with generators (#73)
Co-authored-by: Xinyu Dong <dongxinyu03@baidu.com>
This commit is contained in:
@@ -37,7 +37,7 @@ class TopKTopPSampler(nn.Module):
|
|||||||
generators: dict[int, torch.Generator],
|
generators: dict[int, torch.Generator],
|
||||||
k: Optional[torch.Tensor],
|
k: Optional[torch.Tensor],
|
||||||
p: Optional[torch.Tensor],
|
p: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
"""
|
"""
|
||||||
PyTorch-native implementation of top-k and top-p sampling.
|
PyTorch-native implementation of top-k and top-p sampling.
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ class TopKTopPSampler(nn.Module):
|
|||||||
generators: dict[int, torch.Generator],
|
generators: dict[int, torch.Generator],
|
||||||
k: Optional[torch.Tensor],
|
k: Optional[torch.Tensor],
|
||||||
p: Optional[torch.Tensor],
|
p: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
"""More optimized implementation for top-k and top-p sampling."""
|
"""More optimized implementation for top-k and top-p sampling."""
|
||||||
if (k is None and p is None) or generators:
|
if (k is None and p is None) or generators:
|
||||||
if generators:
|
if generators:
|
||||||
@@ -158,14 +158,21 @@ def random_sample(
|
|||||||
if os.getenv('FAST_RANDOM_SAMPLE') == "1":
|
if os.getenv('FAST_RANDOM_SAMPLE') == "1":
|
||||||
q.uniform_()
|
q.uniform_()
|
||||||
q = -torch.log(q)
|
q = -torch.log(q)
|
||||||
q = q.clamp(min=1e-4)
|
q = q.clamp(min=1e-12)
|
||||||
else:
|
else:
|
||||||
q.exponential_()
|
q.exponential_()
|
||||||
if generators:
|
if generators:
|
||||||
# TODO(woosuk): This can be slow because we handle each request
|
# TODO(woosuk): This can be slow because we handle each request
|
||||||
# one by one. Optimize this.
|
# one by one. Optimize this.
|
||||||
for i, generator in generators.items():
|
if os.getenv('FAST_RANDOM_SAMPLE') == "1":
|
||||||
q[i].exponential_(generator=generator)
|
for i, generator in generators.items():
|
||||||
|
q[i].uniform_(generator=generator)
|
||||||
|
q = -torch.log(q)
|
||||||
|
q = q.clamp(min=1e-12)
|
||||||
|
else:
|
||||||
|
for i, generator in generators.items():
|
||||||
|
q[i].exponential_(generator=generator)
|
||||||
|
|
||||||
return probs.div_(q).argmax(dim=-1).view(-1)
|
return probs.div_(q).argmax(dim=-1).view(-1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user