[fix] Clamp logprob with dtype min to prevent -inf (#3224)
@@ -72,9 +72,11 @@ class Sampler(nn.Module):
                     # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
                     # https://github.com/flashinfer-ai/flashinfer/issues/708
                     # so we use the torch implementation.
+
+                    # clamp to avoid -inf
                     logprobs = torch.log(
                         top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                    )
+                    ).clamp(min=torch.finfo(probs.dtype).min)
 
                 max_top_k_round, batch_size = 32, probs.shape[0]
                 uniform_samples = torch.rand(
@@ -109,9 +111,10 @@ class Sampler(nn.Module):
                     sampling_info.need_min_p_sampling,
                 )
                 if return_logprob:
+                    # clamp to avoid -inf
                     logprobs = torch.log(
                         top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                    )
+                    ).clamp(min=torch.finfo(probs.dtype).min)
             else:
                 raise ValueError(
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
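
Why the clamp matters, as a minimal sketch: top-p renormalization zeroes the probabilities of tokens outside the nucleus, and torch.log(0.0) evaluates to -inf, which can later turn into NaN (for example via 0 * -inf) in downstream logprob arithmetic. The top_p_normalize helper below is a hypothetical stand-in for sglang's top_p_normalize_probs_torch, written only to reproduce the failure mode; clamping with torch.finfo(dtype).min, as this commit does, replaces -inf with the most negative finite value of the dtype.

import torch

def top_p_normalize(probs: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for sglang's top_p_normalize_probs_torch, used only
    # to reproduce the failure mode: keep the smallest set of tokens whose
    # cumulative probability covers top_p, zero out the rest, and renormalize.
    sorted_probs, indices = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Drop every token whose preceding cumulative mass already exceeds top_p.
    sorted_probs[(cumulative - sorted_probs) > top_p.unsqueeze(-1)] = 0.0
    renormed = torch.zeros_like(probs).scatter_(-1, indices, sorted_probs)
    return renormed / renormed.sum(dim=-1, keepdim=True)

probs = torch.softmax(torch.tensor([[4.0, 1.0, 0.5, -2.0]]), dim=-1)
top_ps = torch.tensor([0.8])

logprobs = torch.log(top_p_normalize(probs, top_ps))
print(logprobs)  # tensor([[0., -inf, -inf, -inf]]): masked tokens hit -inf

# The fix from this commit: clamp to the dtype's most negative finite value.
safe = torch.log(top_p_normalize(probs, top_ps)).clamp(
    min=torch.finfo(probs.dtype).min
)
print(safe)  # -inf replaced by roughly -3.4e38 for float32

Since torch.finfo(torch.float32).min is about -3.4e38, the clamped entries still behave as effectively impossible tokens while keeping subsequent arithmetic finite.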