Fallback when sampling failed (#678)
This commit is contained in:
@@ -668,18 +668,17 @@ class Batch:
|
|||||||
|
|
||||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
|
uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
|
||||||
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
|
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||||
probs, uniform_samples, self.top_ks, self.top_ps
|
probs, uniform_samples, self.top_ks, self.top_ps
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME: this is a temporary fix for the illegal token ids
|
if torch.any(~success):
|
||||||
illegal_mask = torch.logical_or(
|
warnings.warn("Sampling failed, fallback to top_k=1 strategy")
|
||||||
batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
|
|
||||||
)
|
|
||||||
if torch.any(illegal_mask):
|
|
||||||
warnings.warn("Illegal sampled token ids")
|
|
||||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||||
batch_next_token_ids = torch.argmax(probs, dim=-1)
|
argmax_ids = torch.argmax(probs, dim=-1)
|
||||||
|
batch_next_token_ids = torch.where(
|
||||||
|
success, batch_next_token_ids, argmax_ids
|
||||||
|
)
|
||||||
|
|
||||||
if has_regex:
|
if has_regex:
|
||||||
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
|
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
|
||||||
|
|||||||
Reference in New Issue
Block a user