[fix] Clamp logprob with dtype min to prevent -inf (#3224)
This commit is contained in:
@@ -72,9 +72,11 @@ class Sampler(nn.Module):
|
|||||||
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
|
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
|
||||||
# https://github.com/flashinfer-ai/flashinfer/issues/708
|
# https://github.com/flashinfer-ai/flashinfer/issues/708
|
||||||
# so we use the torch implementation.
|
# so we use the torch implementation.
|
||||||
|
|
||||||
|
# clamp to avoid -inf
|
||||||
logprobs = torch.log(
|
logprobs = torch.log(
|
||||||
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
||||||
)
|
).clamp(min=torch.finfo(probs.dtype).min)
|
||||||
|
|
||||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
uniform_samples = torch.rand(
|
uniform_samples = torch.rand(
|
||||||
@@ -109,9 +111,10 @@ class Sampler(nn.Module):
|
|||||||
sampling_info.need_min_p_sampling,
|
sampling_info.need_min_p_sampling,
|
||||||
)
|
)
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
|
# clamp to avoid -inf
|
||||||
logprobs = torch.log(
|
logprobs = torch.log(
|
||||||
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
||||||
)
|
).clamp(min=torch.finfo(probs.dtype).min)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
|
|||||||
def run_decode(
|
def run_decode(
|
||||||
self,
|
self,
|
||||||
return_logprob=True,
|
return_logprob=True,
|
||||||
top_logprobs_num=3,
|
top_logprobs_num=5,
|
||||||
return_text=True,
|
return_text=True,
|
||||||
n=1,
|
n=1,
|
||||||
**sampling_params,
|
**sampling_params,
|
||||||
@@ -58,8 +58,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
|
|||||||
"logprob_start_len": 0,
|
"logprob_start_len": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print(json.dumps(response.json()))
|
assert response.status_code == 200, "Request failed: " + response.text
|
||||||
print("=" * 100)
|
|
||||||
|
|
||||||
def test_default_values(self):
|
def test_default_values(self):
|
||||||
self.run_decode()
|
self.run_decode()
|
||||||
@@ -112,4 +111,4 @@ class TestBatchPenalizerE2E(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main(verbosity=3)
|
||||||
|
|||||||
Reference in New Issue
Block a user