Minor follow-up fixes for the logprob refactor (#2670)

This commit is contained in:
Lianmin Zheng
2024-12-30 05:42:08 -08:00
committed by GitHub
parent c5210dfa38
commit 21ec66e59e
5 changed files with 11 additions and 12 deletions

View File

@@ -56,7 +56,9 @@ class Sampler(nn.Module):
if global_server_args_dict["sampling_backend"] == "flashinfer":
if return_logprob:
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
# https://github.com/flashinfer-ai/flashinfer/issues/708
# so we use the torch implementation.
logprobs = torch.log(
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
)