diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index d22f4edb9..e19ec5897 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -376,7 +376,7 @@ class Batch:
                     logit_bias = torch.zeros(
                         (bs, vocab_size), dtype=torch.float32, device=device
                     )
-                logit_bias[i] = int_token_logit_bias
+                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
 
         # Set fields
         self.input_ids = torch.tensor(
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index f6cc8677c..34890d699 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -133,24 +133,10 @@ class TokenizerManager:
             async for response in self._handle_batch_request(obj, request):
                 yield response
 
-    async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
-        if is_prefill:
-            if isinstance(obj.text, list):
-                input_text = obj.text[index]
-                rid = obj.rid[index]
-            else:
-                input_text = obj.text
-                rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
-            sampling_params = SamplingParams(**obj.sampling_params[0])
-            sampling_params.max_new_tokens = 0
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data[0]
-            )
-            return_logprob = obj.return_logprob[0]
-            logprob_start_len = obj.logprob_start_len[0]
-            top_logprobs_num = obj.top_logprobs_num[0]
-        else:
+    async def _handle_single_request(
+        self, obj, request, index=None, is_cache_for_prefill=False
+    ):
+        if not is_cache_for_prefill:
             rid = obj.rid if index is None else obj.rid[index]
             input_text = obj.text if index is None else obj.text[index]
             input_ids = (
@@ -177,6 +163,22 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
             )
+        else:
+            if isinstance(obj.text, list):
+                input_text = obj.text[index]
+                rid = obj.rid[index]
+            else:
+                input_text = obj.text
+                rid = obj.rid[0]
+            input_ids = self.tokenizer.encode(input_text)
+            sampling_params = SamplingParams(**obj.sampling_params[0])
+            sampling_params.max_new_tokens = 0
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data[0]
+            )
+            return_logprob = obj.return_logprob[0]
+            logprob_start_len = obj.logprob_start_len[0]
+            top_logprobs_num = obj.top_logprobs_num[0]
 
         tokenized_obj = TokenizedGenerateReqInput(
             rid,
@@ -196,26 +198,26 @@ class TokenizerManager:
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
-        if is_prefill:
-            await self._wait_for_prefill_response(event, state, obj, request, rid)
-            yield input_ids
-        else:
+        if not is_cache_for_prefill:
             async for response in self._wait_for_response(
                 event, state, obj, rid, request
             ):
                 yield response
+        else:
+            await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+            yield input_ids
 
-    async def _handle_batch_request(self, obj, request):
+    async def _handle_batch_request(self, obj: GenerateReqInput, request):
         batch_size = obj.batch_size
         parallel_sample_num = obj.sampling_params[0].get("n", 1)
 
         if parallel_sample_num != 1:
-            ## send prefill requests
+            # Send prefill requests to cache the common input
            parallel_sample_num += 1
            input_id_result = [] if obj.input_ids is None else None
            for i in range(batch_size):
                async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_prefill=True
+                    obj, request, index=i, is_cache_for_prefill=True
                ):
                    if input_id_result is not None:
                        input_id_result.append(input_id)
@@ -224,6 +226,7 @@ class TokenizerManager:
                 obj.input_ids = input_id_result
             elif input_id_result is not None:
                 obj.input_ids = input_id_result[0]
 
+        # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
@@ -308,17 +311,15 @@ class TokenizerManager:
 
             yield output_list
 
-    def _validate_input_length(self, input_ids):
+    def _validate_input_length(self, input_ids: List[int]):
         if len(input_ids) >= self.context_len:
             raise ValueError(
                 f"The input ({len(input_ids)} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )
 
-    def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
+    def _get_sampling_params(self, sampling_params_data: dict):
         sampling_params = SamplingParams(**sampling_params_data)
-        if max_new_tokens is not None:
-            sampling_params.max_new_tokens = max_new_tokens
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)
             sampling_params.verify()
@@ -332,7 +333,14 @@ class TokenizerManager:
         else:
             return None, None, None
 
-    async def _wait_for_response(self, event, state, obj, rid, request):
+    async def _wait_for_response(
+        self,
+        event: asyncio.Event,
+        state: ReqState,
+        obj: GenerateReqInput,
+        rid: str,
+        request,
+    ):
         while True:
             try:
                 await asyncio.wait_for(event.wait(), timeout=4)
@@ -361,7 +369,14 @@ class TokenizerManager:
                 event.clear()
                 yield out
 
-    async def _wait_for_prefill_response(self, event, state, obj, request, rid):
+    async def _wait_for_cache_prefill_response(
+        self,
+        event: asyncio.Event,
+        state: ReqState,
+        obj: GenerateReqInput,
+        rid: str,
+        request,
+    ):
         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
@@ -380,7 +395,7 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)
 
-    def abort_request(self, rid):
+    def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
@@ -426,7 +441,11 @@ class TokenizerManager:
             state.event.set()
 
     def convert_logprob_style(
-        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+        self,
+        ret: dict,
+        return_logprob: bool,
+        top_logprobs_num: int,
+        return_text_in_logprobs: bool,
     ):
         if return_logprob:
             ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
@@ -450,7 +469,7 @@ class TokenizerManager:
             )
         return ret
 
-    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
 
@@ -461,7 +480,7 @@ class TokenizerManager:
             for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
         ]
 
-    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
         for i, t in enumerate(top_logprobs):
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py
index c9e8139df..9ba794ac9 100644
--- a/python/sglang/test/test_programs.py
+++ b/python/sglang/test/test_programs.py
@@ -118,7 +118,11 @@ def test_decode_json_regex():
         s += "}"
 
     ret = decode_json.run()
-    js_obj = json.loads(ret["json_output"])
+    try:
+        js_obj = json.loads(ret["json_output"])
+    except json.decoder.JSONDecodeError:
+        print(ret["json_output"])
+        raise
     assert isinstance(js_obj["name"], str)
     assert isinstance(js_obj["population"], int)
 
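
Note on the infer_batch.py change: a minimal standalone sketch of the shape mismatch the slice assignment avoids. The concrete sizes below are illustrative assumptions (e.g. a model whose embedding vocab is padded beyond the tokenizer's entries), not values from this repo:

    import torch

    bs, vocab_size = 2, 32064                          # padded model vocab (assumed)
    int_token_logit_bias = torch.full((32000,), -1e5)  # bias covers tokenizer vocab only

    logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32)

    # Old: logit_bias[0] = int_token_logit_bias
    #   -> RuntimeError when the two sizes differ (32000 vs. 32064).
    # New: write only the covered prefix; the padded tail keeps bias 0.
    logit_bias[0][: len(int_token_logit_bias)] = int_token_logit_bias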