Refactor logprob computation to return the real logprob used in sampling (#2664)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Union
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -28,13 +28,12 @@ class Sampler(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: Union[torch.Tensor, LogitsProcessorOutput],
|
||||
logits_output: LogitsProcessorOutput,
|
||||
sampling_info: SamplingBatchInfo,
|
||||
return_logprob: bool,
|
||||
top_logprobs_nums: List[int],
|
||||
):
|
||||
if isinstance(logits, LogitsProcessorOutput):
|
||||
logits = logits.next_token_logits
|
||||
|
||||
logits = logits.contiguous()
|
||||
logits = logits_output.next_token_logits
|
||||
|
||||
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||
@@ -47,6 +46,8 @@ class Sampler(nn.Module):
|
||||
if sampling_info.is_all_greedy:
|
||||
# Use torch.argmax if all requests use greedy sampling
|
||||
batch_next_token_ids = torch.argmax(logits, -1)
|
||||
if return_logprob:
|
||||
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
else:
|
||||
# Post process logits
|
||||
logits.div_(sampling_info.temperatures)
|
||||
@@ -54,6 +55,12 @@ class Sampler(nn.Module):
|
||||
del logits
|
||||
|
||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||
if return_logprob:
|
||||
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems
|
||||
logprobs = torch.log(
|
||||
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
||||
)
|
||||
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand(
|
||||
(max_top_k_round, batch_size), device=probs.device
|
||||
@@ -76,6 +83,7 @@ class Sampler(nn.Module):
|
||||
if self.use_nan_detectioin and not torch.all(success):
|
||||
logger.warning("Detected errors during sampling!")
|
||||
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
||||
|
||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||
# A slower fallback implementation with torch native operations.
|
||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
@@ -85,12 +93,31 @@ class Sampler(nn.Module):
|
||||
sampling_info.min_ps,
|
||||
sampling_info.need_min_p_sampling,
|
||||
)
|
||||
if return_logprob:
|
||||
logprobs = torch.log(
|
||||
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||
)
|
||||
|
||||
return batch_next_token_ids.to(torch.int32)
|
||||
batch_next_token_ids = batch_next_token_ids.to(torch.int32)
|
||||
|
||||
# Attach logprobs to logits_output (in-place modification)
|
||||
if return_logprob:
|
||||
if any(x > 0 for x in top_logprobs_nums):
|
||||
(
|
||||
logits_output.next_token_top_logprobs_val,
|
||||
logits_output.next_token_top_logprobs_idx,
|
||||
) = get_top_logprobs(logprobs, top_logprobs_nums)
|
||||
|
||||
logits_output.next_token_logprobs = logprobs[
|
||||
torch.arange(len(batch_next_token_ids), device=sampling_info.device),
|
||||
batch_next_token_ids,
|
||||
]
|
||||
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
@@ -120,20 +147,27 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
def top_p_normalize_probs(
|
||||
def top_p_normalize_probs_torch(
|
||||
probs: torch.Tensor,
|
||||
top_ps: torch.Tensor,
|
||||
):
|
||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||
return top_p_renorm_prob(probs, top_ps)
|
||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||
# See also top_k_top_p_min_p_sampling_from_probs_torch
|
||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||
)
|
||||
# See also top_k_top_p_min_p_sampling_from_probs_torch
|
||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
|
||||
|
||||
|
||||
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
|
||||
max_k = max(top_logprobs_nums)
|
||||
ret = logprobs.topk(max_k, dim=1)
|
||||
values = ret.values.tolist()
|
||||
indices = ret.indices.tolist()
|
||||
|
||||
output_top_logprobs_val = []
|
||||
output_top_logprobs_idx = []
|
||||
for i, k in enumerate(top_logprobs_nums):
|
||||
output_top_logprobs_val.append(values[i][:k])
|
||||
output_top_logprobs_idx.append(indices[i][:k])
|
||||
return output_top_logprobs_val, output_top_logprobs_idx
|
||||
|
||||
Reference in New Issue
Block a user