|
|
|
|
@@ -121,6 +121,7 @@ class GenerateReqInput:
|
|
|
|
|
bootstrap_host: Optional[Union[List[str], str]] = None
|
|
|
|
|
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
|
|
|
|
bootstrap_room: Optional[Union[List[int], int]] = None
|
|
|
|
|
bootstrap_pair_key: Optional[Union[List[str], str]] = None
|
|
|
|
|
|
|
|
|
|
# For data parallel rank routing
|
|
|
|
|
data_parallel_rank: Optional[int] = None
|
|
|
|
|
@@ -128,6 +129,15 @@ class GenerateReqInput:
|
|
|
|
|
# For background responses (OpenAI responses API)
|
|
|
|
|
background: bool = False
|
|
|
|
|
|
|
|
|
|
# Conversation id used for tracking requests
|
|
|
|
|
conversation_id: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
# Label for the request
|
|
|
|
|
label: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
# Image gen grpc migration
|
|
|
|
|
return_bytes: bool = False
|
|
|
|
|
|
|
|
|
|
def contains_mm_input(self) -> bool:
|
|
|
|
|
return (
|
|
|
|
|
has_valid_data(self.image_data)
|
|
|
|
|
@@ -258,6 +268,7 @@ class GenerateReqInput:
|
|
|
|
|
self._normalize_sampling_params(num)
|
|
|
|
|
self._normalize_logprob_params(num)
|
|
|
|
|
self._normalize_custom_logit_processor(num)
|
|
|
|
|
self._normalize_bootstrap_params(num)
|
|
|
|
|
|
|
|
|
|
def _expand_inputs(self, num):
|
|
|
|
|
"""Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
|
|
|
|
|
@@ -297,6 +308,11 @@ class GenerateReqInput:
|
|
|
|
|
self.image_data = [[self.image_data]] * num
|
|
|
|
|
self.modalities = ["image"] * num
|
|
|
|
|
elif isinstance(self.image_data, list):
|
|
|
|
|
# Handle empty list case - treat as no images
|
|
|
|
|
if len(self.image_data) == 0:
|
|
|
|
|
self.image_data = [None] * num
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if len(self.image_data) != self.batch_size:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The length of image_data should be equal to the batch size."
|
|
|
|
|
@@ -421,6 +437,40 @@ class GenerateReqInput:
|
|
|
|
|
"Cannot use list custom_logit_processor with parallel_sample_num > 1"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _normalize_bootstrap_params(self, num):
|
|
|
|
|
"""Normalize bootstrap parameters for batch processing."""
|
|
|
|
|
# Normalize bootstrap_host
|
|
|
|
|
if self.bootstrap_host is None:
|
|
|
|
|
self.bootstrap_host = [None] * num
|
|
|
|
|
elif not isinstance(self.bootstrap_host, list):
|
|
|
|
|
self.bootstrap_host = [self.bootstrap_host] * num
|
|
|
|
|
elif isinstance(self.bootstrap_host, list):
|
|
|
|
|
self.bootstrap_host = self.bootstrap_host * self.parallel_sample_num
|
|
|
|
|
|
|
|
|
|
# Normalize bootstrap_port
|
|
|
|
|
if self.bootstrap_port is None:
|
|
|
|
|
self.bootstrap_port = [None] * num
|
|
|
|
|
elif not isinstance(self.bootstrap_port, list):
|
|
|
|
|
self.bootstrap_port = [self.bootstrap_port] * num
|
|
|
|
|
elif isinstance(self.bootstrap_port, list):
|
|
|
|
|
self.bootstrap_port = self.bootstrap_port * self.parallel_sample_num
|
|
|
|
|
|
|
|
|
|
# Normalize bootstrap_room
|
|
|
|
|
if self.bootstrap_room is None:
|
|
|
|
|
self.bootstrap_room = [None] * num
|
|
|
|
|
elif not isinstance(self.bootstrap_room, list):
|
|
|
|
|
self.bootstrap_room = [self.bootstrap_room + i for i in range(num)]
|
|
|
|
|
elif isinstance(self.bootstrap_room, list):
|
|
|
|
|
self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num
|
|
|
|
|
|
|
|
|
|
# Normalize bootstrap_pair_key
|
|
|
|
|
if self.bootstrap_pair_key is None:
|
|
|
|
|
self.bootstrap_pair_key = [None] * num
|
|
|
|
|
elif not isinstance(self.bootstrap_pair_key, list):
|
|
|
|
|
self.bootstrap_pair_key = [self.bootstrap_pair_key] * num
|
|
|
|
|
elif isinstance(self.bootstrap_pair_key, list):
|
|
|
|
|
self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num
|
|
|
|
|
|
|
|
|
|
def _validate_session_params(self):
|
|
|
|
|
"""Validate that session parameters are properly formatted."""
|
|
|
|
|
if self.session_params is not None:
|
|
|
|
|
@@ -453,7 +503,13 @@ class GenerateReqInput:
|
|
|
|
|
return_text_in_logprobs=self.return_text_in_logprobs,
|
|
|
|
|
stream=self.stream,
|
|
|
|
|
log_metrics=self.log_metrics,
|
|
|
|
|
return_hidden_states=(
|
|
|
|
|
self.return_hidden_states[i]
|
|
|
|
|
if isinstance(self.return_hidden_states, list)
|
|
|
|
|
else self.return_hidden_states
|
|
|
|
|
),
|
|
|
|
|
modalities=self.modalities[i] if self.modalities else None,
|
|
|
|
|
session_params=self.session_params,
|
|
|
|
|
lora_path=self.lora_path[i] if self.lora_path is not None else None,
|
|
|
|
|
lora_id=self.lora_id[i] if self.lora_id is not None else None,
|
|
|
|
|
custom_logit_processor=(
|
|
|
|
|
@@ -461,11 +517,6 @@ class GenerateReqInput:
|
|
|
|
|
if self.custom_logit_processor is not None
|
|
|
|
|
else None
|
|
|
|
|
),
|
|
|
|
|
return_hidden_states=(
|
|
|
|
|
self.return_hidden_states[i]
|
|
|
|
|
if isinstance(self.return_hidden_states, list)
|
|
|
|
|
else self.return_hidden_states
|
|
|
|
|
),
|
|
|
|
|
# if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
|
|
|
|
|
bootstrap_host=(
|
|
|
|
|
self.bootstrap_host[i] if self.bootstrap_host is not None else None
|
|
|
|
|
@@ -476,9 +527,17 @@ class GenerateReqInput:
|
|
|
|
|
bootstrap_room=(
|
|
|
|
|
self.bootstrap_room[i] if self.bootstrap_room is not None else None
|
|
|
|
|
),
|
|
|
|
|
bootstrap_pair_key=(
|
|
|
|
|
self.bootstrap_pair_key[i]
|
|
|
|
|
if self.bootstrap_pair_key is not None
|
|
|
|
|
else None
|
|
|
|
|
),
|
|
|
|
|
data_parallel_rank=(
|
|
|
|
|
self.data_parallel_rank if self.data_parallel_rank is not None else None
|
|
|
|
|
),
|
|
|
|
|
conversation_id=self.conversation_id,
|
|
|
|
|
label=self.label,
|
|
|
|
|
return_bytes=self.return_bytes,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -504,27 +563,28 @@ class TokenizedGenerateReqInput:
|
|
|
|
|
token_ids_logprob: List[int]
|
|
|
|
|
# Whether to stream output
|
|
|
|
|
stream: bool
|
|
|
|
|
# Whether to return hidden states
|
|
|
|
|
return_hidden_states: bool = False
|
|
|
|
|
|
|
|
|
|
# LoRA related
|
|
|
|
|
lora_id: Optional[str] = None # None means just use the base model
|
|
|
|
|
# The input embeds
|
|
|
|
|
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
|
|
|
|
|
|
|
|
|
# Session info for continual prompting
|
|
|
|
|
session_params: Optional[SessionParams] = None
|
|
|
|
|
|
|
|
|
|
# LoRA related
|
|
|
|
|
lora_id: Optional[str] = None # None means just use the base model
|
|
|
|
|
|
|
|
|
|
# Custom logit processor for advanced sampling control. Must be a serialized instance
|
|
|
|
|
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
|
|
|
|
|
# Use the processor's `to_str()` method to generate the serialized string.
|
|
|
|
|
custom_logit_processor: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
# Whether to return hidden states
|
|
|
|
|
return_hidden_states: bool = False
|
|
|
|
|
|
|
|
|
|
# For disaggregated inference
|
|
|
|
|
bootstrap_host: Optional[str] = None
|
|
|
|
|
bootstrap_port: Optional[int] = None
|
|
|
|
|
bootstrap_room: Optional[int] = None
|
|
|
|
|
bootstrap_pair_key: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
# For data parallel rank routing
|
|
|
|
|
data_parallel_rank: Optional[int] = None
|
|
|
|
|
@@ -532,6 +592,12 @@ class TokenizedGenerateReqInput:
|
|
|
|
|
# For dp balance
|
|
|
|
|
dp_balance_id: int = -1
|
|
|
|
|
|
|
|
|
|
# Label for the request
|
|
|
|
|
label: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
# Image gen grpc migration
|
|
|
|
|
return_bytes: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class BatchTokenizedGenerateReqInput:
|
|
|
|
|
@@ -738,9 +804,26 @@ class BatchTokenIDOut:
|
|
|
|
|
# Hidden states
|
|
|
|
|
output_hidden_states: List[List[float]]
|
|
|
|
|
|
|
|
|
|
# The information of placeholder tokens (e.g., image token)
|
|
|
|
|
# idx is the index of the token in the prompt after expansion.
|
|
|
|
|
# val is the length of padded tokens after expansion.
|
|
|
|
|
placeholder_tokens_idx: List[Optional[List[int]]]
|
|
|
|
|
placeholder_tokens_val: List[Optional[List[int]]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class BatchMultimodalDecodeReq:
|
|
|
|
|
decoded_ids: List[int]
|
|
|
|
|
input_token_logprobs_val: List[float]
|
|
|
|
|
input_token_logprobs_idx: List[int]
|
|
|
|
|
output_token_logprobs_val: List[float]
|
|
|
|
|
output_token_logprobs_idx: List[int]
|
|
|
|
|
read_offsets: List[int]
|
|
|
|
|
skip_special_tokens: List[bool]
|
|
|
|
|
spaces_between_special_tokens: List[bool]
|
|
|
|
|
image_resolutions: List[List[int]]
|
|
|
|
|
resize_image_resolutions: List[List[int]]
|
|
|
|
|
|
|
|
|
|
# The request id
|
|
|
|
|
rids: List[str]
|
|
|
|
|
finished_reasons: List[BaseFinishReason]
|
|
|
|
|
@@ -750,6 +833,12 @@ class BatchMultimodalDecodeReq:
|
|
|
|
|
completion_tokens: List[int]
|
|
|
|
|
cached_tokens: List[int]
|
|
|
|
|
|
|
|
|
|
# Placeholder token info
|
|
|
|
|
placeholder_tokens_idx: List[Optional[List[int]]]
|
|
|
|
|
placeholder_tokens_val: List[Optional[List[int]]]
|
|
|
|
|
|
|
|
|
|
return_bytes: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class BatchStrOut:
|
|
|
|
|
@@ -785,6 +874,9 @@ class BatchStrOut:
|
|
|
|
|
# Hidden states
|
|
|
|
|
output_hidden_states: List[List[float]]
|
|
|
|
|
|
|
|
|
|
placeholder_tokens_idx: List[Optional[List[int]]]
|
|
|
|
|
placeholder_tokens_val: List[Optional[List[int]]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class BatchMultimodalOut:
|
|
|
|
|
@@ -792,14 +884,26 @@ class BatchMultimodalOut:
|
|
|
|
|
rids: List[str]
|
|
|
|
|
# The finish reason
|
|
|
|
|
finished_reasons: List[dict]
|
|
|
|
|
decoded_ids: List[List[int]]
|
|
|
|
|
# The outputs
|
|
|
|
|
outputs: List[List[Dict]]
|
|
|
|
|
outputs: Union[List[str | bytes], List[List[Dict]]]
|
|
|
|
|
|
|
|
|
|
# probability values for input tokens and output tokens
|
|
|
|
|
input_token_logprobs_val: List[List[float]]
|
|
|
|
|
input_token_logprobs_idx: List[List[int]]
|
|
|
|
|
output_token_logprobs_val: List[List[float]]
|
|
|
|
|
output_token_logprobs_idx: List[List[int]]
|
|
|
|
|
|
|
|
|
|
# Token counts
|
|
|
|
|
prompt_tokens: List[int]
|
|
|
|
|
completion_tokens: List[int]
|
|
|
|
|
cached_tokens: List[int]
|
|
|
|
|
|
|
|
|
|
placeholder_tokens_idx: List[Optional[List[int]]]
|
|
|
|
|
placeholder_tokens_val: List[Optional[List[int]]]
|
|
|
|
|
|
|
|
|
|
return_bytes: List[bool]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class BatchEmbeddingOut:
|
|
|
|
|
@@ -812,6 +916,9 @@ class BatchEmbeddingOut:
|
|
|
|
|
# Token counts
|
|
|
|
|
prompt_tokens: List[int]
|
|
|
|
|
cached_tokens: List[int]
|
|
|
|
|
# Placeholder token info
|
|
|
|
|
placeholder_tokens_idx: List[Optional[List[int]]]
|
|
|
|
|
placeholder_tokens_val: List[Optional[List[int]]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
@@ -844,6 +951,12 @@ class UpdateWeightFromDiskReqInput:
|
|
|
|
|
abort_all_requests: bool = False
|
|
|
|
|
# Optional: Update weight version along with weights
|
|
|
|
|
weight_version: Optional[str] = None
|
|
|
|
|
# Whether to update weights asynchronously
|
|
|
|
|
is_async: bool = False
|
|
|
|
|
# Whether to empty torch cache
|
|
|
|
|
torch_empty_cache: bool = False
|
|
|
|
|
# Whether to keep the scheduler paused after weight update
|
|
|
|
|
keep_pause: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
@@ -983,6 +1096,7 @@ class AbortReq:
|
|
|
|
|
abort_all: bool = False
|
|
|
|
|
# The finished reason data
|
|
|
|
|
finished_reason: Optional[Dict[str, Any]] = None
|
|
|
|
|
abort_reason: Optional[str] = None
|
|
|
|
|
# used in MultiTokenzierManager mode
|
|
|
|
|
rids: Optional[Union[List[str], str]] = None
|
|
|
|
|
|
|
|
|
|
@@ -1061,6 +1175,7 @@ class ConfigureLoggingReq:
|
|
|
|
|
log_requests_level: Optional[int] = None
|
|
|
|
|
dump_requests_folder: Optional[str] = None
|
|
|
|
|
dump_requests_threshold: Optional[int] = None
|
|
|
|
|
crash_dump_folder: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
|