diff --git a/benchmark/latency_throughput/test_latency.py b/benchmark/latency_throughput/test_latency.py
index a58c98851..bdbf33925 100644
--- a/benchmark/latency_throughput/test_latency.py
+++ b/benchmark/latency_throughput/test_latency.py
@@ -30,7 +30,7 @@ if __name__ == "__main__":
     response = requests.post(
         url + "/generate",
         json={
-            "text": f"{a}, ",
+            "input_ids": [[1,2,3], [1,2,3]],
             "sampling_params": {
                 "temperature": 0,
                 "max_new_tokens": max_new_tokens,
diff --git a/docs/sampling_params.md b/docs/sampling_params.md
index add849d21..065bbc2d5 100644
--- a/docs/sampling_params.md
+++ b/docs/sampling_params.md
@@ -8,6 +8,8 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
 class GenerateReqInput:
     # The input prompt
     text: Union[List[str], str]
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 6e64380c9..db9655f64 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -8,7 +8,9 @@ from sglang.srt.sampling_params import SamplingParams
 @dataclass
 class GenerateReqInput:
     # The input prompt
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
@@ -28,7 +30,17 @@ class GenerateReqInput:
     # TODO: make all parameters a Union[List[T], T] to allow for batched requests
 
     def post_init(self):
-        is_single = isinstance(self.text, str)
+
+        if self.text is None:
+            assert self.input_ids is not None, "Either text or input_ids should be provided"
+        else:
+            assert self.input_ids is None, "Either text or input_ids should be provided"
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+
         self.is_single = is_single
         if is_single:
             if self.sampling_params is None:
@@ -42,7 +54,7 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            num = len(self.text)
+            num = len(self.text) if self.text is not None else len(self.input_ids)
             if self.image_data is None:
                 self.image_data = [None] * num
 
diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index 1f655513a..a46e1e9db 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -85,6 +85,9 @@ class Req:
             )
             if first_token.startswith("▁"):
                 old_output_str = " " + old_output_str
+            if self.input_text is None:
+                # TODO(lmzheng): This can be wrong. Check with Liangsheng.
+                self.input_text = self.tokenizer.decode(self.input_ids)
             new_input_string = (
                 self.input_text
                 + self.output_and_jump_forward_str
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 44e3c99e0..e8aa2d77c 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -147,11 +147,15 @@ class TokenizerManager:
         if self.to_create_loop:
             await self.create_handle_loop()
 
-        is_single = isinstance(obj.text, str)
-
+        is_single = obj.is_single
         if is_single:
             rid = obj.rid
-            input_ids = self.tokenizer.encode(obj.text)
+
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(obj.text)
+            else:
+                input_ids = obj.input_ids
+
             sampling_params = SamplingParams(**obj.sampling_params)
             if sampling_params.max_new_tokens != 0:
                 sampling_params.normalize(self.tokenizer)
@@ -204,10 +208,22 @@ class TokenizerManager:
                 event.clear()
         else:
             assert obj.stream is False
-            bs = len(obj.text)
+
+            if obj.input_ids is None:
+                bs = len(obj.text)
+            else:
+                bs = len(obj.input_ids)
+
             for i in range(bs):
                 rid = obj.rid[i]
-                input_ids = self.tokenizer.encode(obj.text[i])
+
+                if obj.input_ids is None:
+                    input_text = obj.text[i]
+                    input_ids = self.tokenizer.encode(obj.text[i])
+                else:
+                    input_text = None
+                    input_ids = obj.input_ids[i]
+
                 sampling_params = SamplingParams(**obj.sampling_params[i])
                 if sampling_params.max_new_tokens != 0:
                     sampling_params.normalize(self.tokenizer)
@@ -220,7 +236,7 @@ class TokenizerManager:
                 )
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid=rid,
-                    input_text=obj.text[i],
+                    input_text=input_text,
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,