Move sampling logits to float32 (#773)
This commit is contained in:
@@ -136,7 +136,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
last_logits = torch.matmul(last_hidden, weight.T)
|
last_logits = torch.matmul(last_hidden, weight.T)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||||
last_logits = last_logits[:, : self.config.vocab_size]
|
last_logits = last_logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
if hasattr(self.config, "final_logit_softcapping"):
|
if hasattr(self.config, "final_logit_softcapping"):
|
||||||
last_logits /= self.config.final_logit_softcapping
|
last_logits /= self.config.final_logit_softcapping
|
||||||
@@ -161,9 +161,9 @@ class LogitsProcessor(nn.Module):
|
|||||||
all_logits = torch.matmul(hidden_states, weight.T)
|
all_logits = torch.matmul(hidden_states, weight.T)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||||
all_logits = all_logits[:, : self.config.vocab_size]
|
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
all_logprobs = all_logits.float()
|
all_logprobs = all_logits
|
||||||
del all_logits
|
del all_logits
|
||||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||||
|
|
||||||
|
|||||||
@@ -687,13 +687,21 @@ class Batch:
|
|||||||
# TODO(lmzheng): apply penalty
|
# TODO(lmzheng): apply penalty
|
||||||
probs = torch.softmax(logits, dim=-1)
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
|
||||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
if True:
|
||||||
uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
uniform_samples = torch.rand(
|
||||||
probs, uniform_samples, self.top_ks, self.top_ps
|
(max_top_k_round, batch_size), device=probs.device
|
||||||
)
|
)
|
||||||
|
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||||
|
probs, uniform_samples, self.top_ks, self.top_ps
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Here we provide a slower fallback implementation.
|
||||||
|
batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
|
||||||
|
probs, self.top_ks, self.top_ps
|
||||||
|
)
|
||||||
|
|
||||||
if torch.any(~success):
|
if not torch.all(success):
|
||||||
warnings.warn("Sampling failed, fallback to top_k=1 strategy")
|
warnings.warn("Sampling failed, fallback to top_k=1 strategy")
|
||||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||||
argmax_ids = torch.argmax(probs, dim=-1)
|
argmax_ids = torch.argmax(probs, dim=-1)
|
||||||
@@ -933,3 +941,29 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
|
|||||||
max_extend_len = int(torch.max(extend_seq_lens))
|
max_extend_len = int(torch.max(extend_seq_lens))
|
||||||
|
|
||||||
return max_seq_len, max_extend_len, start_loc, prefix_lens
|
return max_seq_len, max_extend_len, start_loc, prefix_lens
|
||||||
|
|
||||||
|
|
||||||
|
def top_k_top_p_sampling_from_probs_torch(
    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
):
    """A top-k and top-p sampling implementation with native pytorch operations.

    Slower fallback for when the fused sampling kernel is unavailable.

    Args:
        probs: (batch, vocab) per-row probabilities. Rows need not sum to 1;
            ``torch.multinomial`` renormalizes internally.
        top_ks: (batch,) integer tensor; keep only the k most probable tokens
            per row.
        top_ps: (batch,) float tensor; nucleus (top-p) threshold per row.

    Returns:
        A ``(batch_next_token_ids, success)`` pair: int64 sampled token ids of
        shape (batch,) and a bool mask. If multinomial sampling fails (e.g. a
        NaN/invalid row), ids are all zeros and success is all False so the
        caller can fall back to an argmax (top_k=1) strategy.
    """
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Top-p: drop tokens whose *preceding* cumulative mass already exceeds
    # top_p. The top-1 token always survives (its preceding mass is 0).
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # Top-k: drop everything past the k-th sorted position in each row.
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
        >= top_ks.view(-1, 1)
    ] = 0.0
    # Rescale by the row max for numerical stability; exact normalization is
    # unnecessary because multinomial renormalizes.
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError:
        # Invalid distribution (e.g. NaN row): signal failure to the caller.
        batch_next_token_ids = torch.zeros(
            (probs.shape[0],), dtype=torch.int64, device=probs.device
        )
        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
        return batch_next_token_ids, success

    # Map sorted positions back to original vocabulary ids.
    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return batch_next_token_ids, success
|
||||||
|
|||||||
Reference in New Issue
Block a user