Temporary fix invalid sample results (#668)
This commit is contained in:
@@ -673,6 +673,16 @@ class Batch:
|
||||
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
|
||||
probs, uniform_samples, self.top_ks, self.top_ps
|
||||
)
|
||||
|
||||
# FIXME: This is a temporary fix for the illegal token ids in sampling.
|
||||
illegal_mask = (
|
||||
batch_next_token_ids < 0 or batch_next_token_ids >= probs.shape[-1]
|
||||
)
|
||||
if torch.any(illegal_mask):
|
||||
warnings.warn("Illegal token ids in sampling.")
|
||||
batch_next_token_ids = torch.where(
|
||||
illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids
|
||||
)
|
||||
except RuntimeError as e:
|
||||
warnings.warn(f"Ignore errors in sampling: {e}")
|
||||
batch_next_token_ids = torch.argmax(probs, dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user