Fallback when sampling failed (#678)
This commit is contained in:
@@ -668,18 +668,17 @@ class Batch:
|
||||
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
|
||||
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
probs, uniform_samples, self.top_ks, self.top_ps
|
||||
)
|
||||
|
||||
# FIXME: this is a temporary fix for the illegal token ids
|
||||
illegal_mask = torch.logical_or(
|
||||
batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
|
||||
)
|
||||
if torch.any(illegal_mask):
|
||||
warnings.warn("Illegal sampled token ids")
|
||||
if torch.any(~success):
|
||||
warnings.warn("Sampling failed, fallback to top_k=1 strategy")
|
||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||
batch_next_token_ids = torch.argmax(probs, dim=-1)
|
||||
argmax_ids = torch.argmax(probs, dim=-1)
|
||||
batch_next_token_ids = torch.where(
|
||||
success, batch_next_token_ids, argmax_ids
|
||||
)
|
||||
|
||||
if has_regex:
|
||||
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
|
||||
|
||||
Reference in New Issue
Block a user