Revert "Temporary fix invalid sample results" (#673)
This commit is contained in:
@@ -673,16 +673,6 @@ class Batch:
|
|||||||
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
|
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
|
||||||
probs, uniform_samples, self.top_ks, self.top_ps
|
probs, uniform_samples, self.top_ks, self.top_ps
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME: This is a temporary fix for the illegal token ids in sampling.
|
|
||||||
illegal_mask = (
|
|
||||||
batch_next_token_ids < 0 or batch_next_token_ids >= probs.shape[-1]
|
|
||||||
)
|
|
||||||
if torch.any(illegal_mask):
|
|
||||||
warnings.warn("Illegal token ids in sampling.")
|
|
||||||
batch_next_token_ids = torch.where(
|
|
||||||
illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids
|
|
||||||
)
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
warnings.warn(f"Ignore errors in sampling: {e}")
|
warnings.warn(f"Ignore errors in sampling: {e}")
|
||||||
batch_next_token_ids = torch.argmax(probs, dim=-1)
|
batch_next_token_ids = torch.argmax(probs, dim=-1)
|
||||||
|
|||||||
Reference in New Issue
Block a user