diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index 5eed985d3..3a9096170 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -673,6 +673,16 @@ class Batch: batch_next_token_ids, _ = top_k_top_p_sampling_from_probs( probs, uniform_samples, self.top_ks, self.top_ps ) + + # FIXME: This is a temporary fix for the illegal token ids in sampling. + illegal_mask = ( + batch_next_token_ids < 0 or batch_next_token_ids >= probs.shape[-1] + ) + if torch.any(illegal_mask): + warnings.warn("Illegal token ids in sampling.") + batch_next_token_ids = torch.where( + illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids + ) except RuntimeError as e: warnings.warn(f"Ignore errors in sampling: {e}") batch_next_token_ids = torch.argmax(probs, dim=-1)