Move sampling logits to float32 (#773)
This commit is contained in:
@@ -136,7 +136,7 @@ class LogitsProcessor(nn.Module):
|
||||
last_logits = torch.matmul(last_hidden, weight.T)
|
||||
if self.tp_size > 1:
|
||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||
last_logits = last_logits[:, : self.config.vocab_size]
|
||||
last_logits = last_logits[:, : self.config.vocab_size].float()
|
||||
|
||||
if hasattr(self.config, "final_logit_softcapping"):
|
||||
last_logits /= self.config.final_logit_softcapping
|
||||
@@ -161,9 +161,9 @@ class LogitsProcessor(nn.Module):
|
||||
all_logits = torch.matmul(hidden_states, weight.T)
|
||||
if self.tp_size > 1:
|
||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||
all_logits = all_logits[:, : self.config.vocab_size]
|
||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||
|
||||
all_logprobs = all_logits.float()
|
||||
all_logprobs = all_logits
|
||||
del all_logits
|
||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||
|
||||
|
||||
@@ -687,13 +687,21 @@ class Batch:
|
||||
# TODO(lmzheng): apply penalty
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
probs, uniform_samples, self.top_ks, self.top_ps
|
||||
)
|
||||
if True:
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand(
|
||||
(max_top_k_round, batch_size), device=probs.device
|
||||
)
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
probs, uniform_samples, self.top_ks, self.top_ps
|
||||
)
|
||||
else:
|
||||
# Here we provide a slower fallback implementation.
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
|
||||
probs, self.top_ks, self.top_ps
|
||||
)
|
||||
|
||||
if torch.any(~success):
|
||||
if not torch.all(success):
|
||||
warnings.warn("Sampling failed, fallback to top_k=1 strategy")
|
||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||
argmax_ids = torch.argmax(probs, dim=-1)
|
||||
@@ -933,3 +941,29 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
|
||||
max_extend_len = int(torch.max(extend_seq_lens))
|
||||
|
||||
return max_seq_len, max_extend_len, start_loc, prefix_lens
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_probs_torch(
    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
):
    """A top-k and top-p sampling implementation with native pytorch operations.

    Each row of ``probs`` is sorted in descending order, filtered by that
    row's top-p and top-k limits, and one index is drawn from what remains
    with ``torch.multinomial``.

    Args:
        probs: 2-D tensor of sampling weights, one row per sample. It is not
            modified; all filtering happens on the sorted copy.
        top_ks: Per-row top-k limits; reshaped to a column and compared with
            the rank of each sorted entry.
        top_ps: Per-row top-p thresholds; reshaped to a column and compared
            with the cumulative mass preceding each sorted entry.

    Returns:
        A ``(batch_next_token_ids, success)`` tuple. ``batch_next_token_ids``
        is a 1-D int64 tensor of sampled column indices into ``probs``;
        ``success`` is a 1-D bool tensor. Rows whose filtered weights are
        unusable (NaN/inf/negative entries, or nothing left to sample from)
        get id 0 and ``success == False``; all other rows are still sampled
        normally. If ``torch.multinomial`` itself raises ``RuntimeError``,
        every row is reported as failed with id 0.
    """
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)

    # Top-p: drop an entry once the mass strictly before it exceeds top_p, so
    # the highest-ranked entry of a row always survives this step.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # Top-k: drop every entry whose rank is >= top_k.
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
        >= top_ks.view(-1, 1)
    ] = 0.0

    # Validate each row up front instead of relying on multinomial raising:
    # one bad row must not discard the samples of the valid rows.
    row_max = probs_sort.max(dim=-1)[0]
    entries_ok = torch.isfinite(probs_sort) & (probs_sort >= 0)
    success = entries_ok.all(dim=-1) & (row_max > 0)
    bad_rows = ~success

    # Give bad rows a harmless placeholder distribution (all mass on the first
    # sorted slot) so multinomial can run on the whole batch at once.
    probs_sort.masked_fill_(bad_rows.view(-1, 1), 0.0)
    probs_sort[:, 0].masked_fill_(bad_rows, 1.0)

    # Rescale by the row max, as before; multinomial only needs non-negative
    # weights with a positive sum, not a normalized distribution.
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError:
        # Last-resort net: report every row as failed so the caller can fall
        # back to its own strategy.
        batch_next_token_ids = torch.zeros(
            (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
        )
        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
        return batch_next_token_ids, success

    # Map sampled ranks back to the original column indices, and zero the ids
    # of failed rows so they match the all-failed path above.
    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    batch_next_token_ids = batch_next_token_ids.masked_fill(bad_rows, 0)
    return batch_next_token_ids, success
|
||||
|
||||
Reference in New Issue
Block a user