# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""
|
|
The definition of objects transferred between different
|
|
processes (TokenizerManager, DetokenizerManager, Scheduler).
|
|
"""
|
|
|
|
import copy
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|
|
|
from sglang.srt.lora.lora_registry import LoRARef
|
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
|
from sglang.srt.multimodal.mm_utils import has_valid_data
|
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
|
from sglang.srt.utils import ImageData
|
|
|
|
# Handle serialization of Image for pydantic:
# at type-check time use the real PIL type; at runtime fall back to `Any`
# so PIL is not a hard import dependency of this module.
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any
|
|
|
|
|
|
@dataclass
class SessionParams:
    """Session info for continual prompting.

    All fields are optional; their interpretation is defined by the session
    handling in the scheduler (not visible in this file).
    """

    # The session id.
    id: Optional[str] = None
    # A request id (presumably of a previous turn in this session — verify in scheduler).
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
    # NOTE(review): presumably discards the previous turn's output when True — confirm.
    drop_previous_output: Optional[bool] = None
|
|
|
|
|
|
# Type definitions for multimodal input data
# Individual data item types for each modality
ImageDataInputItem = Union[Image, str, ImageData, Dict]
AudioDataInputItem = Union[str, Dict]
VideoDataInputItem = Union[str, Dict]
# Union type for any multimodal data item
MultimodalDataInputItem = Union[
    ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
]
# Format types supporting single items, lists, or nested lists for batch
# processing (single item = one request; list = one item per request;
# list of lists = multiple items per request).
MultimodalDataInputFormat = Union[
    List[List[MultimodalDataInputItem]],
    List[MultimodalDataInputItem],
    MultimodalDataInputItem,
]
|
|
|
|
|
|
@dataclass
class GenerateReqInput:
    """The input of a generation request.

    Most fields accept either a single value (one request) or a list of
    values (a batch of requests); `normalize_batch_and_arguments` resolves
    the two forms and expands parallel sampling (`n` > 1) into a batch.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # Whether to return hidden states
    return_hidden_states: Union[List[bool], bool] = False

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # The path to the LoRA adaptors
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # The uid of LoRA adaptors, should be initialized by tokenizer manager
    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For background responses (OpenAI responses API)
    background: bool = False

    def contains_mm_input(self) -> bool:
        """Return True if any multimodal (image/video/audio) input is present."""
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()

    def _validate_inputs(self):
        """Validate that the input configuration is valid.

        NOTE(review): this only rejects "none provided" and "all three
        provided"; supplying exactly two of text/input_ids/input_embeds is
        accepted and resolved by precedence in `_determine_batch_size`.
        """
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size."""
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            # text takes precedence over input_embeds
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            # input_ids takes precedence over input_embeds
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed."""
        # Determine parallel sample count
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )

        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
            if self.input_embeds is not None:
                self.input_embeds = [self.input_embeds]

    def _normalize_single_inputs(self):
        """Normalize inputs for a single example (fill in scalar defaults)."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Calculate expanded batch size
        if self.parallel_sample_num == 1:
            num = self.batch_size
        else:
            # Expand parallel_sample_num
            num = self.batch_size * self.parallel_sample_num

        # Expand input based on type
        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_video_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Normalize LoRA paths for batch processing."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data for batch processing."""
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image, convert to list of single-image lists
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        elif isinstance(self.image_data, list):
            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is
                for i in range(len(self.image_data)):
                    if self.image_data[i] is None or self.image_data[i] == [None]:
                        self.modalities.append(None)
                    elif len(self.image_data[i]) == 1:
                        self.modalities.append("image")
                    elif len(self.image_data[i]) > 1:
                        self.modalities.append("multi-images")
                    else:
                        # Ensure len(self.modalities) == len(self.image_data)
                        self.modalities.append(None)
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

    def _normalize_video_data(self, num):
        """Normalize video data for batch processing."""
        if self.video_data is None:
            self.video_data = [None] * num
        elif not isinstance(self.video_data, list):
            self.video_data = [self.video_data] * num
        elif isinstance(self.video_data, list):
            self.video_data = self.video_data * self.parallel_sample_num

    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        elif isinstance(self.audio_data, list):
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters for batch processing.

        Each batch entry gets its own dict: `[{}] * num` would make every
        entry alias ONE shared dict, and `[user_dict] * num` would alias the
        caller's dict, so any later per-request mutation would silently leak
        into all sibling requests (and the caller).
        """
        if self.sampling_params is None:
            self.sampling_params = [{} for _ in range(num)]
        elif isinstance(self.sampling_params, dict):
            # Shallow copy per request; nested values are still shared but
            # top-level per-request edits no longer propagate.
            self.sampling_params = [dict(self.sampling_params) for _ in range(num)]
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            new_rids = [f"{self.rid}_{i}" for i in range(num)]
            self.rid = new_rids
        elif isinstance(self.rid, list):
            # Note: the length of rid shall be the same as the batch_size,
            # as the rid would be expanded for parallel sampling in tokenizer_manager
            if len(self.rid) != self.batch_size:
                raise ValueError(
                    "The specified rids length mismatch with the batch_size for batch processing."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing."""

        # Helper function to normalize a parameter
        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        # Normalize each logprob parameter
        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing."""
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _validate_session_params(self):
        """Validate that session parameters are properly formatted."""
        if self.session_params is not None:
            if not isinstance(self.session_params, dict) and not isinstance(
                self.session_params[0], dict
            ):
                raise ValueError("Session params must be a dict or a list of dicts.")

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return the i-th request of a normalized batch as a single-request object."""
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            input_embeds=(
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
            video_data=self.video_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            lora_id=self.lora_id[i] if self.lora_id is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_port=(
                self.bootstrap_port[i] if self.bootstrap_port is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            data_parallel_rank=(
                self.data_parallel_rank if self.data_parallel_rank is not None else None
            ),
        )
|
|
|
|
|
|
@dataclass
class TokenizedGenerateReqInput:
    """The tokenized form of a single generation request (sent to the scheduler)."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token id to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_id: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For dp balance
    dp_balance_id: int = -1
|
|
|
|
|
|
@dataclass
class EmbeddingReqInput:
    """The input of an embedding request (single request or a batch)."""

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False

    # For background responses (OpenAI responses API)
    background: bool = False

    def normalize_batch_and_arguments(self):
        """Resolve the single-vs-batch form and fill in default rid / sampling params.

        Sets `self.batch_size` and `self.is_single`, generates request ids if
        missing, and forces `max_new_tokens` to 0 (embedding requests generate
        no tokens).

        Raises:
            ValueError: if none of text, input_ids, image_data is provided, or
                if both text and input_ids are provided.
        """
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            else:
                # Copy so that forcing max_new_tokens below does not mutate
                # the caller-provided dict in place.
                self.sampling_params = dict(self.sampling_params)
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                # Independent dicts: `[{}] * n` would alias one shared dict.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif isinstance(self.sampling_params, dict):
                # Shallow copy per request; `[d] * n` would alias the caller's
                # dict across all entries and mutate it below.
                self.sampling_params = [
                    dict(self.sampling_params) for _ in range(self.batch_size)
                ]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def contains_mm_input(self) -> bool:
        """Return True if any multimodal (image/video/audio) input is present."""
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def __getitem__(self, i):
        """Return the i-th request of a normalized batch as a single-request object."""
        if self.is_cross_encoder_request:
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                is_cross_encoder_request=True,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            audio_data=self.audio_data[i] if self.audio_data is not None else None,
            video_data=self.video_data[i] if self.video_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
|
|
|
|
|
|
@dataclass
class TokenizedEmbeddingReqInput:
    """The tokenized form of a single embedding request (sent to the scheduler)."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The token type ids
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None
    # For dp balance
    dp_balance_id: int = -1
|
|
|
|
|
|
@dataclass
class BatchTokenIDOut:
    """Batched token-id output; all list fields are indexed per request in the batch."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # NOTE(review): presumably the speculative-decoding verify count — confirm.
    spec_verify_ct: List[int]

    # Logprobs (val/idx pairs: logprob values and their token ids/positions)
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]
|
|
|
|
|
|
@dataclass
class BatchMultimodalDecodeReq:
    """Batched decode request for multimodal outputs; list fields are per request."""

    # The request id
    rids: List[str]
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
|
|
|
|
|
|
@dataclass
class BatchStrOut:
    """Batched detokenized (string) output; all list fields are per request."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # NOTE(review): presumably the speculative-decoding verify count — confirm.
    spec_verify_ct: List[int]

    # Logprobs (val/idx pairs: logprob values and their token ids/positions)
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]
|
|
|
|
|
|
@dataclass
class BatchMultimodalOut:
    """Batched multimodal output; all list fields are per request."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
|
|
|
|
|
|
@dataclass
class BatchEmbeddingOut:
    """Batched embedding output; all list fields are per request."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]
|
|
|
|
|
|
@dataclass
class FlushCacheReqInput:
    """Request to flush the server cache; carries no payload."""

    pass
|
|
|
|
|
|
@dataclass
class FlushCacheReqOutput:
    """Result of a cache-flush request."""

    # Whether the flush succeeded.
    success: bool
|
|
|
|
|
|
@dataclass
class UpdateWeightFromDiskReqInput:
    """Request to reload model weights from a path on disk."""

    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None
|
|
|
|
|
|
@dataclass
class UpdateWeightFromDiskReqOutput:
    """Result of an on-disk weight update."""

    # Whether the update succeeded.
    success: bool
    # Human-readable status/error message.
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0
|
|
|
|
|
|
@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request to receive new weights over a distributed process group."""

    # Names of the parameters to update (parallel with dtypes/shapes).
    names: List[str]
    # Dtype (as string) of each named parameter.
    dtypes: List[str]
    # Shape of each named parameter.
    shapes: List[List[int]]
    # The group name
    group_name: str = "weight_update_group"
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None
|
|
|
|
|
|
@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Result of a distributed weight update."""

    # Whether the update succeeded.
    success: bool
    # Human-readable status/error message.
    message: str
|
|
|
|
|
|
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Update model weights from tensor input.

    - Tensors are serialized for transmission
    - Data is structured in JSON for easy transmission over HTTP
    """

    # Serialized named tensors, one entry per (name, tensor) payload.
    serialized_named_tensors: List[Union[str, bytes]]
    # Optional format specification for loading
    load_format: Optional[str] = None
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None
|
|
|
|
|
|
@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result of a tensor-based weight update."""

    # Whether the update succeeded.
    success: bool
    # Human-readable status/error message.
    message: str
|
|
|
|
|
|
@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request to initialize the process group used for weight updates."""

    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"
|
|
|
|
|
|
@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Result of the weight-update group initialization."""

    # Whether initialization succeeded.
    success: bool
    # Human-readable status/error message.
    message: str
|
|
|
|
|
|
@dataclass
class UpdateWeightVersionReqInput:
    """Request to update the weight version string only (no weight transfer)."""

    # The new weight version
    new_version: str
    # Whether to abort all running requests before updating
    abort_all_requests: bool = True
|
|
|
|
|
|
@dataclass
class GetWeightsByNameReqInput:
    """Request to fetch a named weight (truncated for transmission)."""

    # Name of the parameter to fetch.
    name: str
    # Maximum number of elements returned (truncation limit).
    truncate_size: int = 100
|
|
|
|
|
|
@dataclass
class GetWeightsByNameReqOutput:
    """Result carrying the (possibly truncated) parameter values."""

    parameter: list
|
|
|
|
|
|
@dataclass
class ReleaseMemoryOccupationReqInput:
    """Request to release GPU memory occupation (e.g. between RL rollouts)."""

    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None
|
|
|
|
|
|
@dataclass
class ReleaseMemoryOccupationReqOutput:
    """Empty acknowledgement for a release-memory request."""

    pass
|
|
|
|
|
|
@dataclass
class ResumeMemoryOccupationReqInput:
    """Request to re-acquire GPU memory previously released."""

    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None
|
|
|
|
|
|
@dataclass
class ResumeMemoryOccupationReqOutput:
    """Empty acknowledgement for a resume-memory request."""

    pass
|
|
|
|
|
|
@dataclass
class SlowDownReqInput:
    """Request to throttle forward passes by sleeping between them."""

    # Seconds to sleep per forward; None disables the slowdown — TODO confirm semantics with the scheduler.
    forward_sleep_time: Optional[float]
|
|
|
|
|
|
@dataclass
class SlowDownReqOutput:
    """Empty acknowledgement for a slow-down request."""

    pass
|
|
|
|
|
|
@dataclass
class AbortReq:
    """Request to abort a single request (by rid) or all in-flight requests."""

    # The request id
    rid: str = ""
    # Whether to abort all requests
    abort_all: bool = False
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
@dataclass
class GetInternalStateReq:
    """Empty request asking the scheduler for its internal state."""

    pass
|
|
|
|
|
|
@dataclass
class GetInternalStateReqOutput:
    """Response carrying a snapshot of the scheduler's internal state."""

    # Arbitrary key/value state dictionary.
    internal_state: Dict[Any, Any]
|
|
|
|
|
|
@dataclass
class SetInternalStateReq:
    """Request to overwrite server arguments in the scheduler's internal state."""

    # Server-argument name -> new value.
    server_args: Dict[str, Any]
|
|
|
|
|
|
@dataclass
class SetInternalStateReqOutput:
    """Response to a set-internal-state request."""

    # Whether the update was applied.
    updated: bool
    # The server arguments after the update attempt.
    server_args: Dict[str, Any]
|
|
|
|
|
|
@dataclass
class ProfileReqInput:
    """HTTP-facing request to start profiling."""

    # The output directory
    output_dir: Optional[str] = None
    # The step at which profiling starts; None starts immediately — TODO confirm with the scheduler.
    start_step: Optional[int] = None
    # If set, profiling runs for at most this many steps.
    # Profiling is then stopped automatically, and the caller doesn't
    # need to run stop_profile.
    num_steps: Optional[int] = None
    # Profiler activities to record (e.g. CPU/GPU) — exact values are consumer-defined.
    activities: Optional[List[str]] = None
    # Whether to produce per-stage profiles instead of one combined profile.
    profile_by_stage: bool = False
    # Whether to record Python stacks (profiler option passthrough).
    with_stack: Optional[bool] = None
    # Whether to record tensor shapes (profiler option passthrough).
    record_shapes: Optional[bool] = None
|
|
|
|
|
|
class ProfileReqType(Enum):
    """Discriminator for profiling control messages."""

    START_PROFILE = 1
    STOP_PROFILE = 2
|
|
|
|
|
|
@dataclass
class ProfileReq:
    """Internal profiling control message sent to the scheduler.

    Mirrors ProfileReqInput plus the message type and a generated profile id.
    """

    # Whether to start or stop profiling.
    type: ProfileReqType
    # The output directory
    output_dir: Optional[str] = None
    # The step at which profiling starts.
    start_step: Optional[int] = None
    # Stop automatically after this many steps.
    num_steps: Optional[int] = None
    # Profiler activities to record.
    activities: Optional[List[str]] = None
    # Whether to produce per-stage profiles.
    profile_by_stage: bool = False
    # Whether to record Python stacks.
    with_stack: Optional[bool] = None
    # Whether to record tensor shapes.
    record_shapes: Optional[bool] = None
    # Identifier for this profiling session.
    profile_id: Optional[str] = None
|
|
|
|
|
|
@dataclass
class ProfileReqOutput:
    """Response to a profiling control message."""

    # Whether the profiling action succeeded.
    success: bool
    # Human-readable status or error detail.
    message: str
|
|
|
|
|
|
@dataclass
class ConfigureLoggingReq:
    """Request to reconfigure request logging/dumping at runtime.

    All fields are optional; None leaves the corresponding setting unchanged.
    """

    # Enable/disable request logging.
    log_requests: Optional[bool] = None
    # Verbosity level for request logging.
    log_requests_level: Optional[int] = None
    # Folder to dump request records into.
    dump_requests_folder: Optional[str] = None
    # Threshold controlling when requests are dumped — TODO confirm unit (count?) with the consumer.
    dump_requests_threshold: Optional[int] = None
|
|
|
|
|
|
@dataclass
class OpenSessionReqInput:
    """Request to open a new session."""

    # Capacity reserved for the session's string data.
    capacity_of_str_len: int
    # Desired session id; None lets the server assign one.
    session_id: Optional[str] = None
|
|
|
|
|
|
@dataclass
class CloseSessionReqInput:
    """Request to close an existing session."""

    # The id of the session to close.
    session_id: str
|
|
|
|
|
|
@dataclass
class OpenSessionReqOutput:
    """Response to an open-session request."""

    # The assigned session id; may be None when opening failed.
    session_id: Optional[str]
    # Whether the session was opened successfully.
    success: bool
|
|
|
|
|
|
@dataclass
class HealthCheckOutput:
    """Empty acknowledgement for a health-check request."""

    pass
|
|
|
|
|
|
class ExpertDistributionReq(Enum):
    """Control commands for expert-distribution recording (MoE models)."""

    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3
|
|
|
|
|
|
@dataclass
class ExpertDistributionReqOutput:
    """Empty acknowledgement for an expert-distribution command."""

    pass
|
|
|
|
|
|
@dataclass
class Function:
    """Function signature description used for tool/function calling."""

    # Natural-language description of what the function does.
    description: Optional[str] = None
    # The function name.
    name: Optional[str] = None
    # Parameter schema (typically a JSON-schema-like object) — TODO confirm shape with the parser.
    parameters: Optional[object] = None
|
|
|
|
|
|
@dataclass
class Tool:
    """A tool exposed to the model; wraps a Function definition."""

    # The function this tool exposes.
    function: Function
    # Tool category; only "function" appears to be used here.
    type: Optional[str] = "function"
|
|
|
|
|
|
@dataclass
class ParseFunctionCallReq:
    """Request to parse function/tool calls out of generated text."""

    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
|
|
|
|
|
|
@dataclass
class SeparateReasoningReqInput:
    """Request to split reasoning content out of generated text."""

    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".
|
|
|
|
|
|
@dataclass
class VertexGenerateReqInput:
    """Generate request in Google Vertex AI prediction format."""

    # One dict per prediction instance.
    instances: List[dict]
    # Generation parameters shared by all instances.
    parameters: Optional[dict] = None
|
|
|
|
|
|
@dataclass
class RpcReqInput:
    """Generic RPC request: a method name plus optional parameters."""

    # The method to invoke on the receiving process.
    method: str
    # Keyword parameters for the method.
    parameters: Optional[Dict] = None
|
|
|
|
|
|
@dataclass
class RpcReqOutput:
    """Generic RPC response."""

    # Whether the RPC succeeded.
    success: bool
    # Human-readable status or error detail.
    message: str
|
|
|
|
|
|
@dataclass
class LoadLoRAAdapterReqInput:
    """Request to load a new LoRA adapter."""

    # The name of the lora module to be newly loaded.
    lora_name: str
    # The path of loading.
    lora_path: str
    # Whether to pin the LoRA adapter in memory.
    pinned: bool = False
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        """Convert this request into a LoRARef registry record."""
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
            lora_path=self.lora_path,
            pinned=self.pinned,
        )
|
|
|
|
|
|
@dataclass
class UnloadLoRAAdapterReqInput:
    """Request to unload a previously loaded LoRA adapter."""

    # The name of the lora module to unload.
    lora_name: str
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        """Convert this request into a LoRARef registry record (no path/pinned info)."""
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
        )
|
|
|
|
|
|
@dataclass
class LoRAUpdateResult:
    """Result of a LoRA adapter load/unload operation."""

    # Whether the operation succeeded.
    success: bool
    # Error detail when `success` is False.
    error_message: Optional[str] = None
    # Snapshot of currently loaded adapters, keyed by adapter id/name — TODO confirm key with the registry.
    loaded_adapters: Optional[Dict[str, LoRARef]] = None
|
|
|
|
|
|
# Load and unload share the same result shape; both outputs alias LoRAUpdateResult.
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
|
|
|
|
|
class BlockReqType(Enum):
    """Discriminator for block/unblock control messages."""

    BLOCK = 1
    UNBLOCK = 2
|
|
|
|
|
|
@dataclass
class BlockReqInput:
    """Request to block or unblock request processing."""

    # Whether to block or unblock.
    type: BlockReqType
|