From 0ac94c36cbc89c6b4b31a61779cb86982999211e Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sun, 21 Jul 2024 01:44:54 +0800 Subject: [PATCH] Fallback when sampling failed (#678) --- .../sglang/srt/managers/controller/infer_batch.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index 0d50276cd..eda68ec46 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -668,18 +668,17 @@ class Batch: max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device) - batch_next_token_ids, _ = top_k_top_p_sampling_from_probs( + batch_next_token_ids, success = top_k_top_p_sampling_from_probs( probs, uniform_samples, self.top_ks, self.top_ps ) - # FIXME: this is a temporary fix for the illegal token ids - illegal_mask = torch.logical_or( - batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1] - ) - if torch.any(illegal_mask): - warnings.warn("Illegal sampled token ids") + if torch.any(~success): + warnings.warn("Sampling failed, fallback to top_k=1 strategy") probs = probs.masked_fill(torch.isnan(probs), 0.0) - batch_next_token_ids = torch.argmax(probs, dim=-1) + argmax_ids = torch.argmax(probs, dim=-1) + batch_next_token_ids = torch.where( + success, batch_next_token_ids, argmax_ids + ) if has_regex: batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()