From f424e76d96e9cdc580cf648d7fdc75853a8530e1 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sat, 20 Jul 2024 03:11:15 -0700 Subject: [PATCH] Fix illegal tokens during sampling (#676) --- docs/sampling_params.md | 2 +- .../srt/managers/controller/infer_batch.py | 24 +++++++++++-------- .../srt/managers/controller/model_runner.py | 5 ++-- python/sglang/srt/managers/io_struct.py | 2 +- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/docs/sampling_params.md b/docs/sampling_params.md index 979815ffb..6299c5953 100644 --- a/docs/sampling_params.md +++ b/docs/sampling_params.md @@ -7,7 +7,7 @@ The `/generate` endpoint accepts the following arguments in the JSON format. @dataclass class GenerateReqInput: # The input prompt. It can be a single prompt or a batch of prompts. - text: Union[List[str], str] + text: Optional[Union[List[str], str]] = None # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None # The image input. It can be a file name, a url, or base64 encoded string. diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index 5eed985d3..0d50276cd 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -665,16 +665,20 @@ class Batch: # TODO(lmzheng): apply penalty probs = torch.softmax(logits, dim=-1) - try: - max_top_k_round, batch_size = 32, probs.shape[0] - uniform_samples = torch.rand( - (max_top_k_round, batch_size), device=probs.device - ) - batch_next_token_ids, _ = top_k_top_p_sampling_from_probs( - probs, uniform_samples, self.top_ks, self.top_ps - ) - except RuntimeError as e: - warnings.warn(f"Ignore errors in sampling: {e}") + + max_top_k_round, batch_size = 32, probs.shape[0] + uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device) + batch_next_token_ids, _ = top_k_top_p_sampling_from_probs( + probs, uniform_samples, self.top_ks, self.top_ps + ) + + # FIXME: this is a temporary fix for the illegal token ids + illegal_mask = torch.logical_or( + batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1] + ) + if torch.any(illegal_mask): + warnings.warn("Illegal sampled token ids") + probs = probs.masked_fill(torch.isnan(probs), 0.0) batch_next_token_ids = torch.argmax(probs, dim=-1) if has_regex: diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index 34beebb3b..13ccc6041 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -246,12 +246,11 @@ class ModelRunner: self.cuda_graph_runner = CudaGraphRunner( self, max_batch_size_to_capture=max(batch_size_list) ) - logger.info(f"Capture for batch sizes {batch_size_list}") try: self.cuda_graph_runner.capture(batch_size_list) - except: + except RuntimeError as e: raise Exception( - f"Capture cuda graph failed. Possible solutions:\n" + f"Capture cuda graph failed {e}. Possible solutions:\n" f"1. disable cuda graph by --disable-cuda-graph\n" f"2. set --mem-fraction-static to a smaller value\n" f"Open an issue on GitHub with reproducible scripts if you need help.\n" diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 9638e12ca..8875994f1 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -14,7 +14,7 @@ from sglang.srt.sampling_params import SamplingParams @dataclass class GenerateReqInput: # The input prompt. It can be a single prompt or a batch of prompts. - text: Union[List[str], str] + text: Optional[Union[List[str], str]] = None # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None # The image input. It can be a file name, a url, or base64 encoded string.