Temporary fix invalid sample results (#668)

This commit is contained in:
Liangsheng Yin
2024-07-20 00:51:05 -07:00
committed by GitHub
parent e3046ea3a8
commit 8f4b1559e7

View File

@@ -673,6 +673,16 @@ class Batch:
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
probs, uniform_samples, self.top_ks, self.top_ps
)
# FIXME: This is a temporary fix for the illegal token ids in sampling.
illegal_mask = (
batch_next_token_ids < 0 or batch_next_token_ids >= probs.shape[-1]
)
if torch.any(illegal_mask):
warnings.warn("Illegal token ids in sampling.")
batch_next_token_ids = torch.where(
illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids
)
except RuntimeError as e:
warnings.warn(f"Ignore errors in sampling: {e}")
batch_next_token_ids = torch.argmax(probs, dim=-1)