Sync changes on io_struct.py and deterministic ops (#11498)
This commit is contained in:
@@ -91,7 +91,6 @@ class Sampler(nn.Module):
|
||||
batch_next_token_ids = torch.argmax(logits, -1)
|
||||
if return_logprob:
|
||||
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
|
||||
else:
|
||||
# If requested, cache probabilities from original logits before temperature scaling.
|
||||
if return_logprob and RETURN_ORIGINAL_LOGPROB:
|
||||
@@ -288,21 +287,29 @@ def multinomial_with_seed(
|
||||
"""
|
||||
n, m = inputs.shape
|
||||
col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
|
||||
step_seed = seed * 19349663 ^ positions * 73856093
|
||||
step_seed = (seed * 19349663) ^ (positions * 73856093)
|
||||
seed_expanded = step_seed.unsqueeze(-1)
|
||||
hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
|
||||
hashed = (seed_expanded * 8589934591) ^ (col_indices * 479001599)
|
||||
uniform_samples = (hashed % (2**24)).float() / (2**24)
|
||||
epsilon = 1e-9
|
||||
gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
|
||||
epsilon = 1e-10
|
||||
uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon)
|
||||
gumbel_noise = -torch.log(-torch.log(uniform_samples))
|
||||
log_probs = torch.log(inputs + epsilon)
|
||||
perturbed_log_probs = log_probs + gumbel_noise
|
||||
return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
|
||||
|
||||
|
||||
def sampling_from_probs_torch(probs: torch.Tensor):
|
||||
def sampling_from_probs_torch(
|
||||
probs: torch.Tensor,
|
||||
sampling_seed: Optional[torch.Tensor] = None,
|
||||
positions: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""A sampling implementation with native pytorch operations, without
|
||||
top-k, top-p, or min-p filtering."""
|
||||
sampled_index = torch.multinomial(probs, num_samples=1)
|
||||
if sampling_seed is not None:
|
||||
sampled_index = multinomial_with_seed(probs, sampling_seed, positions)
|
||||
else:
|
||||
sampled_index = torch.multinomial(probs, num_samples=1)
|
||||
batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user