[OAI] Support non-normalized logprobs in OpenAI server (#5961)
This commit is contained in:
@@ -86,11 +86,9 @@ class Sampler(nn.Module):
|
|||||||
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
|
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
|
||||||
# https://github.com/flashinfer-ai/flashinfer/issues/708
|
# https://github.com/flashinfer-ai/flashinfer/issues/708
|
||||||
# so we use the torch implementation.
|
# so we use the torch implementation.
|
||||||
|
# NOTE: OpenAI's logprobs is independent of top-p, we use the
|
||||||
# clamp to avoid -inf
|
# same rule.
|
||||||
logprobs = torch.log(
|
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
|
||||||
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
|
||||||
).clamp(min=torch.finfo(probs.dtype).min)
|
|
||||||
|
|
||||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
if sampling_info.need_min_p_sampling:
|
if sampling_info.need_min_p_sampling:
|
||||||
@@ -121,10 +119,7 @@ class Sampler(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
# clamp to avoid -inf
|
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
|
||||||
logprobs = torch.log(
|
|
||||||
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
|
||||||
).clamp(min=torch.finfo(probs.dtype).min)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||||
|
|||||||
Reference in New Issue
Block a user