Sync changes on io_struct.py and deterministic ops (#11498)

This commit is contained in:
Lianmin Zheng
2025-10-12 16:03:10 -07:00
committed by GitHub
parent 0aa65f94f1
commit 2ac46e94ef
11 changed files with 73 additions and 25 deletions

View File

@@ -91,7 +91,6 @@ class Sampler(nn.Module):
batch_next_token_ids = torch.argmax(logits, -1)
if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else:
# If requested, cache probabilities from original logits before temperature scaling.
if return_logprob and RETURN_ORIGINAL_LOGPROB:
@@ -288,21 +287,29 @@ def multinomial_with_seed(
"""
n, m = inputs.shape
col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
step_seed = seed * 19349663 ^ positions * 73856093
step_seed = (seed * 19349663) ^ (positions * 73856093)
seed_expanded = step_seed.unsqueeze(-1)
hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
hashed = (seed_expanded * 8589934591) ^ (col_indices * 479001599)
uniform_samples = (hashed % (2**24)).float() / (2**24)
epsilon = 1e-9
gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
epsilon = 1e-10
uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon)
gumbel_noise = -torch.log(-torch.log(uniform_samples))
log_probs = torch.log(inputs + epsilon)
perturbed_log_probs = log_probs + gumbel_noise
return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
def sampling_from_probs_torch(probs: torch.Tensor):
def sampling_from_probs_torch(
probs: torch.Tensor,
sampling_seed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
):
"""A sampling implementation with native pytorch operations, without
top-k, top-p, or min-p filtering."""
sampled_index = torch.multinomial(probs, num_samples=1)
if sampling_seed is not None:
sampled_index = multinomial_with_seed(probs, sampling_seed, positions)
else:
sampled_index = torch.multinomial(probs, num_samples=1)
batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
return batch_next_token_ids