[Kernel] Enable fast random sample on Kunlun3 Platform with generators (#73)

Co-authored-by: Xinyu Dong <dongxinyu03@baidu.com>
This commit is contained in:
yuqilinaa
2026-01-20 21:49:33 +08:00
committed by GitHub
parent c404af3a41
commit c9f00c132c

View File

@@ -37,7 +37,7 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
PyTorch-native implementation of top-k and top-p sampling.
@@ -58,7 +58,7 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""More optimized implementation for top-k and top-p sampling."""
if (k is None and p is None) or generators:
if generators:
@@ -158,14 +158,21 @@ def random_sample(
if os.getenv('FAST_RANDOM_SAMPLE') == "1":
q.uniform_()
q = -torch.log(q)
q = q.clamp(min=1e-4)
q = q.clamp(min=1e-12)
else:
q.exponential_()
if generators:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
if os.getenv('FAST_RANDOM_SAMPLE') == "1":
for i, generator in generators.items():
q[i].uniform_(generator=generator)
q = -torch.log(q)
q = q.clamp(min=1e-12)
else:
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)