Files
sglang/python/sglang/srt/managers/io_struct.py

757 lines
23 KiB
Python
Raw Normal View History

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""
import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union
2024-07-29 23:04:48 -07:00
from sglang.srt.managers.schedule_batch import BaseFinishReason
2024-08-21 16:48:24 -07:00
from sglang.srt.sampling.sampling_params import SamplingParams
@dataclass
class SessionParams:
    """Session control info attached to a request for continual prompting."""

    # The session id. None presumably lets the server assign one — TODO confirm.
    id: Optional[str] = None
    # The request id within the session to continue from.
    rid: Optional[str] = None
    # Token offset into the previous request's output — assumption, verify against session manager.
    offset: Optional[int] = None
    # Whether to replace (instead of branch from) the previous request — assumption, verify.
    replace: Optional[bool] = None
@dataclass
class GenerateReqInput:
    """The input of a generation request, for a single prompt or a batch of prompts.

    After construction, callers must invoke `normalize_batch_and_arguments()`
    once; it derives `is_single`, `batch_size`, and `parallel_sample_num`, and
    broadcasts scalar arguments to per-request lists for the batched case.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # LoRA related
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None
    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
    # Whether to return hidden states
    return_hidden_states: bool = False

    def normalize_batch_and_arguments(self):
        """Validate the inputs and normalize scalar arguments to per-request lists.

        Raises:
            ValueError: if no input is provided, all three input kinds are
                provided, or ``input_ids`` is an empty list.
        """
        # NOTE(review): this only rejects "none provided" and "all three
        # provided"; passing exactly two (e.g. text + input_ids) slips through
        # and text wins below — confirm whether that leniency is intended.
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                # BUG FIX: is_single was never set on this path, causing an
                # AttributeError below for batched input_embeds.
                self.is_single = False
                self.batch_size = len(self.input_embeds)

        # Handle parallel sampling
        # When parallel sampling is used, we always treat the input as a batch.
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            assert all(
                self.parallel_sample_num == sampling_params.get("n", 1)
                for sampling_params in self.sampling_params
            ), "The parallel_sample_num should be the same for all samples in sample params."

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]

        # Fill in default arguments
        if self.is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = None
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
                # Expand parallel_sample_num
                num = self.batch_size * self.parallel_sample_num

            if not self.image_data:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.audio_data is None:
                self.audio_data = [None] * num
            elif not isinstance(self.audio_data, list):
                self.audio_data = [self.audio_data] * num

            if self.sampling_params is None:
                # BUG FIX: was `[{}] * num`, which aliases a single shared
                # dict across all requests; a mutation through one slot would
                # silently affect every request in the batch.
                self.sampling_params = [{} for _ in range(num)]
            elif not isinstance(self.sampling_params, list):
                # NOTE(review): broadcasting still aliases the one
                # user-provided dict across slots (original behavior kept).
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num
            else:
                assert self.parallel_sample_num == 1

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num
            else:
                assert self.parallel_sample_num == 1

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
            else:
                assert self.parallel_sample_num == 1

            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = [None] * num
            elif not isinstance(self.token_ids_logprob, list):
                self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
            elif not isinstance(self.token_ids_logprob[0], list):
                self.token_ids_logprob = [
                    copy.deepcopy(self.token_ids_logprob) for _ in range(num)
                ]
            else:
                assert self.parallel_sample_num == 1

            if self.custom_logit_processor is None:
                self.custom_logit_processor = [None] * num
            elif not isinstance(self.custom_logit_processor, list):
                self.custom_logit_processor = [self.custom_logit_processor] * num
            else:
                assert self.parallel_sample_num == 1

        # Other checks
        if self.session_params is not None:
            assert isinstance(self.session_params, dict) or isinstance(
                self.session_params[0], dict
            )

    def regenerate_rid(self):
        """Assign a fresh request id and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Slice out the i-th request of a normalized batch as a single request."""
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=self.return_hidden_states,
        )
@dataclass
class TokenizedGenerateReqInput:
    """A generation request after tokenization, with per-request (non-batched) fields."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token id to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Session info for continual prompting
    session_params: Optional[SessionParams] = None
    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None
    # Whether to return hidden states
    return_hidden_states: bool = False
@dataclass
class EmbeddingReqInput:
    """The input of an embedding request, for a single prompt or a batch of prompts.

    Callers must invoke `normalize_batch_and_arguments()` once after
    construction; it derives `is_single` / `batch_size` and fills defaults.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    image_data: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None

    def normalize_batch_and_arguments(self):
        """Validate the inputs, derive the batch size, and fill default arguments.

        Raises:
            ValueError: if no input is given, or if both text and input_ids are given.
        """
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True
        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
            else:
                self.batch_size += 1
        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
            else:
                self.batch_size += 1
        if self.batch_size > 1:
            self.is_single = False

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            # Embedding requests never generate tokens.
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."
            if self.sampling_params is None:
                # BUG FIX: was `[{}] * self.batch_size`, which aliases a
                # single shared dict across every request in the batch.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif not isinstance(self.sampling_params, list):
                # BUG FIX: a single dict for a batch used to crash on the
                # int-indexed assignment below; broadcast it per request.
                self.sampling_params = [
                    dict(self.sampling_params) for _ in range(self.batch_size)
                ]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        """Assign a fresh request id and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Slice out the i-th request of a normalized batch as a single request."""
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
@dataclass
class TokenizedEmbeddingReqInput:
    """An embedding request after tokenization, with per-request fields."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
@dataclass
class BatchTokenIDOut:
    """A batch of generation outputs in token-id form, to be detokenized incrementally."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]
    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Number of speculative-decoding verification rounds — assumption from name, verify.
    spec_verify_ct: List[int]
    # Logprobs (parallel val/idx lists, one entry per request)
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]
    # Hidden states
    output_hidden_states: List[List[float]]
2024-06-12 14:39:12 +08:00
@dataclass
class BatchMultimodalDecodeReq:
    """A batch of multimodal outputs forwarded for decoding."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
@dataclass
class BatchStrOut:
    """A batch of generation outputs in decoded-string form."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]
    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Number of speculative-decoding verification rounds — assumption from name, verify.
    spec_verify_ct: List[int]
    # Logprobs (parallel val/idx lists, one entry per request)
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]
    # Hidden states
    output_hidden_states: List[List[float]]
2024-01-26 13:32:59 +08:00
@dataclass
class BatchMultimodalOut:
    """A batch of multimodal generation outputs."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]
    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
@dataclass
class BatchEmbeddingOut:
    """A batch of embedding-request outputs."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]
2024-01-26 13:32:59 +08:00
@dataclass
class FlushCacheReq:
    """Request to flush the server-side cache; carries no payload."""

    pass
2024-02-06 12:24:55 -08:00
2024-02-06 13:27:46 -08:00
@dataclass
class UpdateWeightFromDiskReqInput:
    """Request to reload model weights from a checkpoint on disk."""

    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
@dataclass
class UpdateWeightFromDiskReqOutput:
    """Result of reloading model weights from disk."""

    # Whether the reload succeeded
    success: bool
    # Human-readable status message
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0
@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request to receive one named weight tensor via the distributed weight-update group."""

    # Parameter name
    name: str
    # Tensor dtype, as a string
    dtype: str
    # Tensor shape
    shape: List[int]
@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Result of a distributed weight update."""

    success: bool
    message: str
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Request to load weights from serialized in-memory tensors."""

    # List containing one serialized Dict[str, torch.Tensor] per TP worker
    serialized_named_tensors: List[bytes]
    # The format to load the weights
    load_format: Optional[str]
    # Whether to flush the cache after loading
    flush_cache: bool
@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result of loading weights from in-memory tensors."""

    success: bool
    message: str
@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request to initialize the distributed process group used for weight updates."""

    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"
@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Result of initializing the weight-update process group."""

    success: bool
    message: str
@dataclass
class GetWeightsByNameReqInput:
    """Request to fetch a named model parameter (for inspection/debugging)."""

    # Parameter name
    name: str
    # Truncate the returned tensor to this many elements
    truncate_size: int = 100
@dataclass
class GetWeightsByNameReqOutput:
    """The (possibly truncated) values of the requested parameter."""

    parameter: list
@dataclass
class ReleaseMemoryOccupationReqInput:
    """Request to release GPU memory occupation; carries no payload."""

    pass
@dataclass
class ReleaseMemoryOccupationReqOutput:
    """Acknowledgement of a memory release; carries no payload."""

    pass
@dataclass
class ResumeMemoryOccupationReqInput:
    """Request to resume GPU memory occupation; carries no payload."""

    pass
@dataclass
class ResumeMemoryOccupationReqOutput:
    """Acknowledgement of a memory resume; carries no payload."""

    pass
@dataclass
class AbortReq:
    """Request to abort an in-flight request by its request id."""

    # The request id
    rid: str
@dataclass
class GetInternalStateReq:
    """Request to read the server's internal state; carries no payload."""

    pass
@dataclass
class GetInternalStateReqOutput:
    """The server's internal state as a free-form dict."""

    internal_state: Dict[Any, Any]
@dataclass
class SetInternalStateReq:
    """Request to update server arguments at runtime."""

    server_args: Dict[str, Any]
@dataclass
class SetInternalStateReqOutput:
    """Result of a runtime server-args update; echoes the effective args."""

    updated: bool
    server_args: Dict[str, Any]
@dataclass
class ProfileReqInput:
    """User-facing request to start profiling."""

    # The output directory
    output_dir: Optional[str] = None
    # If set, it profile as many as this number of steps.
    # If it is set, profiling is automatically stopped after this step, and
    # the caller doesn't need to run stop_profile.
    num_steps: Optional[int] = None
    # Profiler activities to record — presumably torch.profiler activity names; verify.
    activities: Optional[List[str]] = None
class ProfileReqType(Enum):
    """Whether a ProfileReq starts or stops the profiler."""

    START_PROFILE = 1
    STOP_PROFILE = 2
class ExpertDistributionReq(Enum):
    """Control commands for expert-distribution recording (MoE models — assumption, verify)."""

    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3
@dataclass
class ProfileReq:
    """Internal profiler control message (start/stop plus options)."""

    # Start or stop
    type: ProfileReqType
    # The output directory
    output_dir: Optional[str] = None
    # Automatically stop after this many steps, if set
    num_steps: Optional[int] = None
    # Profiler activities to record
    activities: Optional[List[str]] = None
@dataclass
class ProfileReqOutput:
    """Result of a profiler control request."""

    success: bool
    message: str
@dataclass
class ConfigureLoggingReq:
    """Request to reconfigure request logging/dumping at runtime."""

    # Whether to log incoming requests
    log_requests: Optional[bool] = None
    # Verbosity level for request logging
    log_requests_level: Optional[int] = None
    # Folder to dump requests into, if enabled
    dump_requests_folder: Optional[str] = None
    # Dump threshold — presumably number of buffered requests before flushing; verify.
    dump_requests_threshold: Optional[int] = None
2024-11-20 00:36:53 -08:00
@dataclass
class OpenSessionReqInput:
    """Request to open a session for continual prompting."""

    # Capacity reserved for the session's string content — assumption from name, verify.
    capacity_of_str_len: int
    # Desired session id; None lets the server pick one — assumption, verify.
    session_id: Optional[str] = None
2024-11-20 00:36:53 -08:00
@dataclass
class CloseSessionReqInput:
    """Request to close an open session."""

    session_id: str
@dataclass
class OpenSessionReqOutput:
    """Result of opening a session."""

    # The assigned session id; presumably None when opening failed — verify.
    session_id: Optional[str]
    success: bool
@dataclass
class HealthCheckOutput:
    """Marker output for a health check; carries no payload."""

    pass
@dataclass
class Function:
    """A callable function description used for tool-call parsing."""

    # Human-readable description of the function
    description: Optional[str] = None
    # Function name
    name: Optional[str] = None
    # Parameter schema — presumably a JSON-schema-like object; verify against parser.
    parameters: Optional[object] = None
@dataclass
class Tool:
    """A tool wrapper around a Function, OpenAI-tools style."""

    function: Function
    # Tool type; only "function" is used here
    type: Optional[str] = "function"
@dataclass
class ParseFunctionCallReq:
    """Request to parse function/tool calls out of generated text."""

    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
@dataclass
class SeparateReasoningReqInput:
    """Request to split reasoning content from generated text."""

    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".
@dataclass
class VertexGenerateReqInput:
    """A generation request in Vertex AI prediction format (instances + parameters)."""

    # One dict per request instance
    instances: List[dict]
    # Shared generation parameters applied to all instances — assumption, verify.
    parameters: Optional[dict] = None
@dataclass
class RpcReqInput:
    """A generic RPC request: method name plus optional parameters."""

    method: str
    parameters: Optional[Dict] = None
@dataclass
class RpcReqOutput:
    """A generic RPC response."""

    success: bool
    message: str