Minor follow-up fixes for the logprob refactor (#2670)
This commit is contained in:
@@ -56,7 +56,9 @@ class Sampler(nn.Module):
     if global_server_args_dict["sampling_backend"] == "flashinfer":
         if return_logprob:
-            # NOTE: the top_p_renorm_prob from flashinfer has numerical problems
+            # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
+            # https://github.com/flashinfer-ai/flashinfer/issues/708
+            # so we use the torch implementation.
             logprobs = torch.log(
                 top_p_normalize_probs_torch(probs, sampling_info.top_ps)
             )
Reference in New Issue
Block a user