2024-07-28 23:07:12 +10:00
|
|
|
"""
|
|
|
|
|
Copyright 2023-2024 SGLang Team
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License.
|
|
|
|
|
"""
|
|
|
|
|
|
2024-06-08 02:06:52 -07:00
|
|
|
"""
|
|
|
|
|
The definition of objects transferred between different
|
|
|
|
|
processes (TokenizerManager, DetokenizerManager, Controller).
|
|
|
|
|
"""
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
import uuid
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
|
|
2024-06-08 04:20:40 +08:00
|
|
|
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
|
2024-06-12 21:48:40 -07:00
|
|
|
from sglang.srt.sampling_params import SamplingParams
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class GenerateReqInput:
    """The input of a /generate request, before tokenization.

    Most fields accept either a single value (one request) or a list of
    values (a batch of requests). Call ``post_init`` once after construction
    to validate the request and normalize scalar fields into batch lists.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # The start location of the prompt for return_logprob.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # The number of top logprobs to return.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False

    def post_init(self):
        """Validate the request and expand scalar fields to per-request lists.

        Sets ``self.is_single`` and, for batch requests, ``self.batch_size``
        and ``self.parallel_sample_num``.

        Raises:
            ValueError: if neither or both of ``text`` / ``input_ids`` are
                provided, if ``rid`` is not a list for a batch request, or if
                parallel sample counts (``n``) differ across a batch.
        """
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        # A request with n > 1 (parallel sampling) is treated as a batch
        # internally, even when only one prompt was given.
        if (
            isinstance(self.sampling_params, dict)
            and self.sampling_params.get("n", 1) != 1
        ):
            is_single = False
        else:
            if self.text is not None:
                is_single = isinstance(self.text, str)
            else:
                is_single = isinstance(self.input_ids[0], int)
        self.is_single = is_single

        if is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = 0
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            # Determine the parallel sample count (n) for the batch.
            parallel_sample_num_list = []
            if isinstance(self.sampling_params, dict):
                parallel_sample_num = self.sampling_params.get("n", 1)
            elif isinstance(self.sampling_params, list):
                for sp in self.sampling_params:
                    parallel_sample_num = sp.get("n", 1)
                    parallel_sample_num_list.append(parallel_sample_num)
                parallel_sample_num = max(parallel_sample_num_list)
                all_equal = all(
                    element == parallel_sample_num
                    for element in parallel_sample_num_list
                )
                if parallel_sample_num > 1 and (not all_equal):
                    # TODO: cope with the case that the parallel_sample_num is
                    # different for different samples.
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )
            else:
                parallel_sample_num = 1
            self.parallel_sample_num = parallel_sample_num

            if parallel_sample_num != 1:
                # Parallel sampling: +1 represents the original prefill stage.
                num = parallel_sample_num + 1
                if isinstance(self.text, list):
                    # Support batch operation.
                    self.batch_size = len(self.text)
                    num = num * len(self.text)
                elif isinstance(self.input_ids, list) and isinstance(
                    self.input_ids[0], list
                ):
                    # Fix: a batch of input_ids (list of lists) must be
                    # expanded the same way as a batch of text prompts;
                    # previously only `text` was checked, so batch_size was
                    # wrongly left at 1 for tokenized batches.
                    self.batch_size = len(self.input_ids)
                    num = num * len(self.input_ids)
                else:
                    self.batch_size = 1
            else:
                # Support select operation (plain batch, no parallel sampling).
                num = len(self.text) if self.text is not None else len(self.input_ids)
                self.batch_size = num

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.sampling_params is None:
                self.sampling_params = [{}] * num
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                if not isinstance(self.rid, list):
                    raise ValueError("The rid should be a list.")

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num

            if self.logprob_start_len is None:
                self.logprob_start_len = [0] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
@dataclass
class TokenizedGenerateReqInput:
    """The tokenized counterpart of GenerateReqInput, sent between processes."""

    # The request id.
    rid: str
    # The original input text.
    input_text: str
    # The tokenized prompt.
    input_ids: List[int]
    # The processed image pixel values.
    pixel_values: List[float]
    # A hash of the image input — presumably used as a cache/identity key; confirm with the producer.
    image_hash: int
    # The image size — presumably [width, height]; verify against the image processor.
    image_size: List[int]
    # The sampling parameters for this request.
    sampling_params: SamplingParams
    # Whether to return logprobs.
    return_logprob: bool
    # The start location of the prompt for return_logprob.
    logprob_start_len: int
    # The number of top logprobs to return.
    top_logprobs_num: int
    # Whether to stream output.
    stream: bool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BatchTokenIDOut:
    """A batch of token-id outputs, sent from the scheduler side to the detokenizer."""

    # The request ids.
    rids: List[str]
    # Per-request version ids — semantics not visible in this file; verify against the producer.
    vids: List[int]
    # The text decoded so far for each request.
    decoded_texts: List[str]
    # Token ids to be (incrementally) detokenized.
    decode_ids: List[int]
    # Read offsets into the token streams — presumably marking where incremental decoding resumes; confirm with DetokenizerManager.
    read_offsets: List[int]
    # Per-request flag: skip special tokens during detokenization.
    skip_special_tokens: List[bool]
    # Per-request flag: insert spaces between special tokens during detokenization.
    spaces_between_special_tokens: List[bool]
    # Per-request metadata dictionaries.
    meta_info: List[Dict]
    # Per-request finish reason (None-like when not finished — confirm with producer).
    finished_reason: List[BaseFinishReason]
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-06-12 14:39:12 +08:00
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
@dataclass
class BatchStrOut:
    """A batch of detokenized string outputs, sent back toward the tokenizer/front end."""

    # The request ids.
    rids: List[str]
    # The output decoded strings, one per request.
    output_strs: List[str]
    # Per-request metadata dictionaries.
    meta_info: List[Dict]
    # Per-request finish reason.
    finished_reason: List[BaseFinishReason]
|
2024-01-26 13:32:59 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class FlushCacheReq:
    """A payload-free message asking the receiving process to flush its cache."""

    pass
|
2024-02-06 12:24:55 -08:00
|
|
|
|
2024-02-06 13:27:46 -08:00
|
|
|
|
2024-05-17 05:49:31 -07:00
|
|
|
@dataclass
class AbortReq:
    """A message asking the receiving process to abort one request."""

    # The id of the request to abort.
    rid: str
|
|
|
|
|
|
|
|
|
|
|
2024-02-06 12:24:55 -08:00
|
|
|
@dataclass
class DetokenizeReqInput:
    """A request to detokenize a sequence of token ids."""

    # The token ids to detokenize.
    input_ids: List[int]
|