diff --git a/python/sglang/srt/entrypoints/grpc_request_manager.py b/python/sglang/srt/entrypoints/grpc_request_manager.py index 1296e810a..e2eb541a4 100644 --- a/python/sglang/srt/entrypoints/grpc_request_manager.py +++ b/python/sglang/srt/entrypoints/grpc_request_manager.py @@ -22,8 +22,8 @@ import zmq.asyncio from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( AbortReq, - BatchEmbeddingOut, - BatchTokenIDOut, + BatchEmbeddingOutput, + BatchTokenIDOutput, HealthCheckOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -467,9 +467,9 @@ class GrpcRequestManager: await self.is_pause_cond.wait() # Handle different output types - if isinstance(recv_obj, BatchTokenIDOut): + if isinstance(recv_obj, BatchTokenIDOutput): await self._handle_batch_output(recv_obj) - elif isinstance(recv_obj, BatchEmbeddingOut): + elif isinstance(recv_obj, BatchEmbeddingOutput): await self._handle_embedding_output(recv_obj) elif isinstance(recv_obj, HealthCheckOutput): await self._handle_health_check_output(recv_obj) @@ -498,7 +498,7 @@ class GrpcRequestManager: def _convert_logprob_style( self, state: GrpcReqState, - batch_out: BatchTokenIDOut, + batch_out: BatchTokenIDOutput, batch_index: int, ): """ @@ -545,7 +545,7 @@ class GrpcRequestManager: batch_out.output_top_logprobs_idx[batch_index] ) - async def _handle_batch_output(self, batch_out: BatchTokenIDOut): + async def _handle_batch_output(self, batch_out: BatchTokenIDOutput): """Handle batch generation output from scheduler.""" # Process each request in the batch for i, rid in enumerate(batch_out.rids): @@ -666,7 +666,7 @@ class GrpcRequestManager: asyncio.create_task(cleanup()) - async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut): + async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput): """Handle batch embedding output from scheduler.""" for i, rid in enumerate(batch_out.rids): if rid not in self.rid_to_state: diff --git 
a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index ea0d9799b..52e2fc547 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -94,8 +94,8 @@ from sglang.srt.managers.io_struct import ( VertexGenerateReqInput, ) from sglang.srt.managers.multi_tokenizer_mixin import ( - MultiTokenizerManager, MultiTokenizerRouter, + TokenizerWorker, get_main_process_id, monkey_patch_uvicorn_multiprocessing, read_from_shared_memory, @@ -127,9 +127,7 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20)) # Store global states @dataclasses.dataclass class _GlobalState: - tokenizer_manager: Union[ - TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager - ] + tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker] template_manager: TemplateManager scheduler_info: Dict @@ -164,7 +162,7 @@ async def init_multi_tokenizer() -> ServerArgs: ) # Launch multi-tokenizer manager process - tokenizer_manager = MultiTokenizerManager(server_args, port_args) + tokenizer_manager = TokenizerWorker(server_args, port_args) template_manager = TemplateManager() template_manager.initialize_templates( tokenizer_manager=tokenizer_manager, diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index cabc8cb3b..1af4bea4f 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -35,7 +35,7 @@ from sglang.srt.lora.utils import ( get_normalized_target_modules, get_target_module_name, ) -from sglang.srt.managers.io_struct import LoRAUpdateResult +from sglang.srt.managers.io_struct import LoRAUpdateOutput from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import ServerArgs from sglang.srt.utils import replace_submodule @@ -107,8 +107,8 @@ class LoRAManager: def create_lora_update_result( self, success: bool, error_message: str = "" - 
) -> LoRAUpdateResult: - return LoRAUpdateResult( + ) -> LoRAUpdateOutput: + return LoRAUpdateOutput( success=success, error_message=error_message, loaded_adapters={ @@ -117,7 +117,7 @@ class LoRAManager: }, ) - def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult: + def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: """ Load a single LoRA adapter from the specified path. @@ -174,7 +174,7 @@ class LoRAManager: "`--max-loras-per-batch` or load it as unpinned LoRA adapters." ) - def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult: + def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: """ Unload LoRA adapters by their names. This will remove the adapters from the memory pool and delete the corresponding LoRA modules. diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 3c5fd4420..0169bd99a 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -26,11 +26,11 @@ import zmq from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.io_struct import ( - BatchEmbeddingOut, + BatchEmbeddingOutput, BatchMultimodalDecodeReq, - BatchMultimodalOut, - BatchStrOut, - BatchTokenIDOut, + BatchMultimodalOutput, + BatchStrOutput, + BatchTokenIDOutput, FreezeGCReq, MultiTokenizerRegisterReq, ) @@ -101,8 +101,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): self._request_dispatcher = TypeBasedDispatcher( [ - (BatchEmbeddingOut, self.handle_batch_embedding_out), - (BatchTokenIDOut, self.handle_batch_token_id_out), + (BatchEmbeddingOutput, self.handle_batch_embedding_out), + (BatchTokenIDOutput, self.handle_batch_token_id_out), (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req), (MultiTokenizerRegisterReq, lambda x: x), (FreezeGCReq, self.handle_freeze_gc_req), @@ -145,11 +145,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): 
return output[:-1] return output - def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut): + def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOutput): # If it is embedding model, no detokenization is needed. return recv_obj - def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut): + def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): bs = len(recv_obj.rids) # Initialize decode status @@ -224,7 +224,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): s.sent_offset = len(output_str) output_strs.append(incremental_output) - return BatchStrOut( + return BatchStrOutput( rids=recv_obj.rids, finished_reasons=recv_obj.finished_reasons, output_strs=output_strs, @@ -252,7 +252,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): outputs = self.tokenizer.detokenize(recv_obj) - return BatchMultimodalOut( + return BatchMultimodalOutput( rids=recv_obj.rids, finished_reasons=recv_obj.finished_reasons, outputs=outputs, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 791c39399..f9c1c87cb 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler). 
import copy import uuid +from abc import ABC from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -36,10 +37,32 @@ else: # Parameters for a session +@dataclass +class BaseReq(ABC): + rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True) + + def regenerate_rid(self): + """Generate a new request ID and return it.""" + if isinstance(self.rid, list): + self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))] + else: + self.rid = uuid.uuid4().hex + return self.rid + + +@dataclass +class BaseBatchReq(ABC): + rids: Optional[List[str]] = field(default=None, kw_only=True) + + def regenerate_rids(self): + """Generate new request IDs and return them.""" + self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))] + return self.rids + + @dataclass class SessionParams: id: Optional[str] = None - rid: Optional[str] = None offset: Optional[int] = None replace: Optional[bool] = None drop_previous_output: Optional[bool] = None @@ -63,7 +86,7 @@ MultimodalDataInputFormat = Union[ @dataclass -class GenerateReqInput: +class GenerateReqInput(BaseReq): # The input prompt. It can be a single prompt or a batch of prompts. text: Optional[Union[List[str], str]] = None # The token ids for text; one can specify either text or input_ids @@ -83,8 +106,6 @@ class GenerateReqInput: audio_data: Optional[MultimodalDataInputFormat] = None # The sampling_params. See descriptions below. sampling_params: Optional[Union[List[Dict], Dict]] = None - # The request id. - rid: Optional[Union[List[str], str]] = None # Whether to return logprobs. return_logprob: Optional[Union[List[bool], bool]] = None # If return logprobs, the start location in the prompt for returning logprobs. 
@@ -491,11 +512,6 @@ class GenerateReqInput: ): raise ValueError("Session params must be a dict or a list of dicts.") - def regenerate_rid(self): - """Generate a new request ID and return it.""" - self.rid = uuid.uuid4().hex - return self.rid - def __getitem__(self, i): return GenerateReqInput( text=self.text[i] if self.text is not None else None, @@ -558,9 +574,7 @@ class GenerateReqInput: @dataclass -class TokenizedGenerateReqInput: - # The request id - rid: str +class TokenizedGenerateReqInput(BaseReq): # The input text input_text: str # The input token ids @@ -625,7 +639,7 @@ class TokenizedGenerateReqInput: @dataclass -class BatchTokenizedGenerateReqInput: +class BatchTokenizedGenerateReqInput(BaseBatchReq): # The batch of tokenized requests batch: List[TokenizedGenerateReqInput] @@ -640,7 +654,7 @@ class BatchTokenizedGenerateReqInput: @dataclass -class EmbeddingReqInput: +class EmbeddingReqInput(BaseReq): # The input prompt. It can be a single prompt or a batch of prompts. text: Optional[Union[List[List[str]], List[str], str]] = None # The image input. It can be an image instance, file name, URL, or base64 encoded string. @@ -656,8 +670,6 @@ class EmbeddingReqInput: audio_data: Optional[MultimodalDataInputFormat] = None # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None - # The request id. 
- rid: Optional[Union[List[str], str]] = None # Dummy sampling params for compatibility sampling_params: Optional[Union[List[Dict], Dict]] = None # Dummy input embeds for compatibility @@ -728,10 +740,6 @@ class EmbeddingReqInput: for i in range(self.batch_size): self.sampling_params[i]["max_new_tokens"] = 0 - def regenerate_rid(self): - self.rid = uuid.uuid4().hex - return self.rid - def contains_mm_input(self) -> bool: return ( has_valid_data(self.image_data) @@ -760,9 +768,7 @@ class EmbeddingReqInput: @dataclass -class TokenizedEmbeddingReqInput: - # The request id - rid: str +class TokenizedEmbeddingReqInput(BaseReq): # The input text input_text: str # The input token ids @@ -780,7 +786,7 @@ class TokenizedEmbeddingReqInput: @dataclass -class BatchTokenizedEmbeddingReqInput: +class BatchTokenizedEmbeddingReqInput(BaseBatchReq): # The batch of tokenized embedding requests batch: List[TokenizedEmbeddingReqInput] @@ -795,9 +801,7 @@ class BatchTokenizedEmbeddingReqInput: @dataclass -class BatchTokenIDOut: - # The request id - rids: List[str] +class BatchTokenIDOutput(BaseBatchReq): # The finish reason finished_reasons: List[BaseFinishReason] # For incremental decoding @@ -842,7 +846,7 @@ class BatchTokenIDOut: @dataclass -class BatchMultimodalDecodeReq: +class BatchMultimodalDecodeReq(BaseBatchReq): decoded_ids: List[int] input_token_logprobs_val: List[float] input_token_logprobs_idx: List[int] @@ -854,8 +858,6 @@ class BatchMultimodalDecodeReq: image_resolutions: List[List[int]] resize_image_resolutions: List[List[int]] - # The request id - rids: List[str] finished_reasons: List[BaseFinishReason] # Token counts @@ -871,9 +873,7 @@ class BatchMultimodalDecodeReq: @dataclass -class BatchStrOut: - # The request id - rids: List[str] +class BatchStrOutput(BaseBatchReq): # The finish reason finished_reasons: List[dict] # The output decoded strings @@ -909,9 +909,7 @@ class BatchStrOut: @dataclass -class BatchMultimodalOut: - # The request id - rids: List[str] +class 
BatchMultimodalOutput(BaseBatchReq): # The finish reason finished_reasons: List[dict] decoded_ids: List[List[int]] @@ -936,9 +934,7 @@ class BatchMultimodalOut: @dataclass -class BatchEmbeddingOut: - # The request id - rids: List[str] +class BatchEmbeddingOutput(BaseBatchReq): # The finish reason finished_reasons: List[BaseFinishReason] # The output embedding @@ -952,27 +948,27 @@ class BatchEmbeddingOut: @dataclass -class ClearHiCacheReqInput: +class ClearHiCacheReqInput(BaseReq): pass @dataclass -class ClearHiCacheReqOutput: +class ClearHiCacheReqOutput(BaseReq): success: bool @dataclass -class FlushCacheReqInput: +class FlushCacheReqInput(BaseReq): pass @dataclass -class FlushCacheReqOutput: +class FlushCacheReqOutput(BaseReq): success: bool @dataclass -class UpdateWeightFromDiskReqInput: +class UpdateWeightFromDiskReqInput(BaseReq): # The model path with the new weights model_path: str # The format to load the weights @@ -990,7 +986,7 @@ class UpdateWeightFromDiskReqInput: @dataclass -class UpdateWeightFromDiskReqOutput: +class UpdateWeightFromDiskReqOutput(BaseReq): success: bool message: str # Number of paused requests during weight sync. @@ -998,7 +994,7 @@ class UpdateWeightFromDiskReqOutput: @dataclass -class UpdateWeightsFromDistributedReqInput: +class UpdateWeightsFromDistributedReqInput(BaseReq): names: List[str] dtypes: List[str] shapes: List[List[int]] @@ -1013,13 +1009,13 @@ class UpdateWeightsFromDistributedReqInput: @dataclass -class UpdateWeightsFromDistributedReqOutput: +class UpdateWeightsFromDistributedReqOutput(BaseReq): success: bool message: str @dataclass -class UpdateWeightsFromTensorReqInput: +class UpdateWeightsFromTensorReqInput(BaseReq): """Update model weights from tensor input. 
- Tensors are serialized for transmission @@ -1038,13 +1034,13 @@ class UpdateWeightsFromTensorReqInput: @dataclass -class UpdateWeightsFromTensorReqOutput: +class UpdateWeightsFromTensorReqOutput(BaseReq): success: bool message: str @dataclass -class InitWeightsSendGroupForRemoteInstanceReqInput: +class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq): # The master address master_address: str # The ports for each rank's communication group @@ -1060,13 +1056,13 @@ class InitWeightsSendGroupForRemoteInstanceReqInput: @dataclass -class InitWeightsSendGroupForRemoteInstanceReqOutput: +class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): success: bool message: str @dataclass -class SendWeightsToRemoteInstanceReqInput: +class SendWeightsToRemoteInstanceReqInput(BaseReq): # The master address master_address: str # The ports for each rank's communication group @@ -1076,13 +1072,13 @@ class SendWeightsToRemoteInstanceReqInput: @dataclass -class SendWeightsToRemoteInstanceReqOutput: +class SendWeightsToRemoteInstanceReqOutput(BaseReq): success: bool message: str @dataclass -class InitWeightsUpdateGroupReqInput: +class InitWeightsUpdateGroupReqInput(BaseReq): # The master address master_address: str # The master port @@ -1098,24 +1094,24 @@ class InitWeightsUpdateGroupReqInput: @dataclass -class InitWeightsUpdateGroupReqOutput: +class InitWeightsUpdateGroupReqOutput(BaseReq): success: bool message: str @dataclass -class DestroyWeightsUpdateGroupReqInput: +class DestroyWeightsUpdateGroupReqInput(BaseReq): group_name: str = "weight_update_group" @dataclass -class DestroyWeightsUpdateGroupReqOutput: +class DestroyWeightsUpdateGroupReqOutput(BaseReq): success: bool message: str @dataclass -class UpdateWeightVersionReqInput: +class UpdateWeightVersionReqInput(BaseReq): # The new weight version new_version: str # Whether to abort all running requests before updating @@ -1123,89 +1119,87 @@ class UpdateWeightVersionReqInput: @dataclass -class 
GetWeightsByNameReqInput: +class GetWeightsByNameReqInput(BaseReq): name: str truncate_size: int = 100 @dataclass -class GetWeightsByNameReqOutput: +class GetWeightsByNameReqOutput(BaseReq): parameter: list @dataclass -class ReleaseMemoryOccupationReqInput: +class ReleaseMemoryOccupationReqInput(BaseReq): # Optional tags to identify the memory region, which is primarily used for RL # Currently we only support `weights` and `kv_cache` tags: Optional[List[str]] = None @dataclass -class ReleaseMemoryOccupationReqOutput: +class ReleaseMemoryOccupationReqOutput(BaseReq): pass @dataclass -class ResumeMemoryOccupationReqInput: +class ResumeMemoryOccupationReqInput(BaseReq): # Optional tags to identify the memory region, which is primarily used for RL # Currently we only support `weights` and `kv_cache` tags: Optional[List[str]] = None @dataclass -class ResumeMemoryOccupationReqOutput: +class ResumeMemoryOccupationReqOutput(BaseReq): pass @dataclass -class SlowDownReqInput: +class SlowDownReqInput(BaseReq): forward_sleep_time: Optional[float] @dataclass -class SlowDownReqOutput: +class SlowDownReqOutput(BaseReq): pass @dataclass -class AbortReq: - # The request id - rid: str = "" +class AbortReq(BaseReq): # Whether to abort all requests abort_all: bool = False # The finished reason data finished_reason: Optional[Dict[str, Any]] = None abort_reason: Optional[str] = None - # used in MultiTokenzierManager mode - rids: Optional[Union[List[str], str]] = None def __post_init__(self): - self.rids = self.rid + # FIXME: This is a hack to keep the same with the old code + if self.rid is None: + self.rid = "" @dataclass -class GetInternalStateReq: +class GetInternalStateReq(BaseReq): pass @dataclass -class GetInternalStateReqOutput: +class GetInternalStateReqOutput(BaseReq): internal_state: Dict[Any, Any] @dataclass -class SetInternalStateReq: +class SetInternalStateReq(BaseReq): server_args: Dict[str, Any] @dataclass -class SetInternalStateReqOutput: +class 
SetInternalStateReqOutput(BaseReq): updated: bool server_args: Dict[str, Any] @dataclass -class ProfileReqInput: +class ProfileReqInput(BaseReq): # The output directory output_dir: Optional[str] = None # If set, it profile as many as this number of steps. @@ -1225,7 +1219,7 @@ class ProfileReqType(Enum): @dataclass -class ProfileReq: +class ProfileReq(BaseReq): type: ProfileReqType output_dir: Optional[str] = None start_step: Optional[int] = None @@ -1238,18 +1232,18 @@ class ProfileReq: @dataclass -class ProfileReqOutput: +class ProfileReqOutput(BaseReq): success: bool message: str @dataclass -class FreezeGCReq: +class FreezeGCReq(BaseReq): pass @dataclass -class ConfigureLoggingReq: +class ConfigureLoggingReq(BaseReq): log_requests: Optional[bool] = None log_requests_level: Optional[int] = None dump_requests_folder: Optional[str] = None @@ -1258,35 +1252,41 @@ class ConfigureLoggingReq: @dataclass -class OpenSessionReqInput: +class OpenSessionReqInput(BaseReq): capacity_of_str_len: int session_id: Optional[str] = None @dataclass -class CloseSessionReqInput: +class CloseSessionReqInput(BaseReq): session_id: str @dataclass -class OpenSessionReqOutput: +class OpenSessionReqOutput(BaseReq): session_id: Optional[str] success: bool @dataclass -class HealthCheckOutput: +class HealthCheckOutput(BaseReq): pass -class ExpertDistributionReq(Enum): +class ExpertDistributionReqType(Enum): START_RECORD = 1 STOP_RECORD = 2 DUMP_RECORD = 3 +@dataclass +class ExpertDistributionReq(BaseReq): + """Request carrying the expert-distribution recording action to perform.""" + action: ExpertDistributionReqType + + @dataclass -class ExpertDistributionReqOutput: +class ExpertDistributionReqOutput(BaseReq): pass @@ -1304,7 +1304,7 @@ class Tool: @dataclass -class ParseFunctionCallReq: +class ParseFunctionCallReq(BaseReq): text: str # The text to parse. tools: List[Tool] = field( default_factory=list @@ -1315,31 +1315,31 @@ class ParseFunctionCallReq: @dataclass -class SeparateReasoningReqInput: +class SeparateReasoningReqInput(BaseReq): text: str # The text to parse. 
reasoning_parser: str # Specify the parser type, e.g., "deepseek-r1". @dataclass -class VertexGenerateReqInput: +class VertexGenerateReqInput(BaseReq): instances: List[dict] parameters: Optional[dict] = None @dataclass -class RpcReqInput: +class RpcReqInput(BaseReq): method: str parameters: Optional[Dict] = None @dataclass -class RpcReqOutput: +class RpcReqOutput(BaseReq): success: bool message: str @dataclass -class LoadLoRAAdapterReqInput: +class LoadLoRAAdapterReqInput(BaseReq): # The name of the lora module to newly loaded. lora_name: str # The path of loading. @@ -1359,7 +1357,7 @@ class LoadLoRAAdapterReqInput: @dataclass -class UnloadLoRAAdapterReqInput: +class UnloadLoRAAdapterReqInput(BaseReq): # The name of lora module to unload. lora_name: str # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`. @@ -1373,23 +1371,23 @@ class UnloadLoRAAdapterReqInput: @dataclass -class LoRAUpdateResult: +class LoRAUpdateOutput(BaseReq): success: bool error_message: Optional[str] = None loaded_adapters: Optional[Dict[str, LoRARef]] = None -LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult +LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput @dataclass -class MultiTokenizerRegisterReq: - rids: Optional[Union[List[str], str]] = None +class MultiTokenizerRegisterReq(BaseBatchReq): ipc_name: Optional[str] = None @dataclass class MultiTokenizerWrapper: + # FIXME(lsyin): remove this worker_id: int obj: Optional[Any] = None @@ -1400,17 +1398,17 @@ class BlockReqType(Enum): @dataclass -class BlockReqInput: +class BlockReqInput(BaseReq): type: BlockReqType @dataclass -class GetLoadReqInput: +class GetLoadReqInput(BaseReq): pass @dataclass -class GetLoadReqOutput: +class GetLoadReqOutput(BaseReq): dp_rank: int num_reqs: int num_waiting_reqs: int @@ -1418,5 +1416,31 @@ class GetLoadReqOutput: @dataclass -class WatchLoadUpdateReq: +class WatchLoadUpdateReq(BaseReq): loads: 
List[GetLoadReqOutput] + + +def _check_all_req_types(): + """A helper function to check all request types are defined in this file.""" + import inspect + import sys + + all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass) + for class_type in all_classes: + # check its name + name = class_type[0] + is_io_struct = ( + name.endswith("Req") or name.endswith("Input") or name.endswith("Output") + ) + is_base_req = issubclass(class_type[1], BaseReq) or issubclass( + class_type[1], BaseBatchReq + ) + if is_io_struct and not is_base_req: + raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.") + if is_base_req and not is_io_struct: + raise ValueError( + f"{name} is a subclass of BaseReq but not follow the naming convention." + ) + + +_check_all_req_types() diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 2d734ab2b..83c966ec6 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager.""" +"""Mixin class and utils for multi-http-worker mode""" import asyncio import logging import multiprocessing as multiprocessing @@ -30,10 +30,10 @@ import zmq.asyncio from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( - BatchEmbeddingOut, - BatchMultimodalOut, - BatchStrOut, - BatchTokenIDOut, + BatchEmbeddingOutput, + BatchMultimodalOutput, + BatchStrOutput, + BatchTokenIDOutput, MultiTokenizerRegisterReq, MultiTokenizerWrapper, ) @@ -83,8 +83,8 @@ class SocketMapping: def _handle_output_by_index(output, i): """NOTE: A maintainable method is better here.""" - if isinstance(output, BatchTokenIDOut): - new_output = BatchTokenIDOut( + if isinstance(output, BatchTokenIDOutput): + new_output = BatchTokenIDOutput( rids=[output.rids[i]], finished_reasons=( [output.finished_reasons[i]] @@ -198,8 +198,8 @@ def _handle_output_by_index(output, i): placeholder_tokens_idx=None, placeholder_tokens_val=None, ) - elif isinstance(output, BatchEmbeddingOut): - new_output = BatchEmbeddingOut( + elif isinstance(output, BatchEmbeddingOutput): + new_output = BatchEmbeddingOutput( rids=[output.rids[i]], finished_reasons=( [output.finished_reasons[i]] @@ -216,8 +216,8 @@ def _handle_output_by_index(output, i): placeholder_tokens_idx=None, placeholder_tokens_val=None, ) - elif isinstance(output, BatchStrOut): - new_output = BatchStrOut( + elif isinstance(output, BatchStrOutput): + new_output = BatchStrOutput( rids=[output.rids[i]], finished_reasons=( [output.finished_reasons[i]] @@ -314,8 +314,8 @@ def _handle_output_by_index(output, i): placeholder_tokens_idx=None, placeholder_tokens_val=None, ) - elif isinstance(output, BatchMultimodalOut): - 
new_output = BatchMultimodalOut( + elif isinstance(output, BatchMultimodalOutput): + new_output = BatchMultimodalOutput( rids=[output.rids[i]], finished_reasons=( [output.finished_reasons[i]] @@ -343,7 +343,7 @@ def _handle_output_by_index(output, i): class MultiHttpWorkerDetokenizerMixin: - """Mixin class for MultiTokenizerManager and DetokenizerManager""" + """Mixin class for DetokenizerManager""" def get_worker_ids_from_req_rids(self, rids): if isinstance(rids, list): @@ -386,7 +386,7 @@ class MultiHttpWorkerDetokenizerMixin: class MultiTokenizerRouter: - """A router to receive requests from MultiTokenizerManager""" + """A router to receive requests from TokenizerWorker""" def __init__( self, @@ -454,8 +454,8 @@ class MultiTokenizerRouter: self.socket_mapping.send_output(worker_id, new_recv_obj) -class MultiTokenizerManager(TokenizerManager): - """Multi Process Tokenizer Manager that tokenizes the text.""" +class TokenizerWorker(TokenizerManager): + """Tokenizer Worker in multi-http-worker mode""" def __init__( self, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 36867abf3..2450cd46a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import ( DestroyWeightsUpdateGroupReqInput, ExpertDistributionReq, ExpertDistributionReqOutput, + ExpertDistributionReqType, FlushCacheReqInput, FlushCacheReqOutput, FreezeGCReq, @@ -1487,12 +1488,12 @@ class Scheduler( req.priority = -sys.maxsize - 1 elif not self.enable_priority_scheduling and req.priority is not None: abort_req = AbortReq( - req.rid, finished_reason={ "type": "abort", "status_code": HTTPStatus.SERVICE_UNAVAILABLE, "message": "Using priority is disabled for this server. 
Please send a new request without a priority.", }, + rid=req.rid, ) self.send_to_tokenizer.send_pyobj(abort_req) @@ -1528,12 +1529,12 @@ class Scheduler( self.send_to_tokenizer.send_pyobj( AbortReq( - req_to_abort.rid, finished_reason={ "type": "abort", "status_code": HTTPStatus.SERVICE_UNAVAILABLE, "message": message, }, + rid=req_to_abort.rid, ) ) return req_to_abort.rid == recv_req.rid @@ -2005,7 +2006,7 @@ class Scheduler( self.new_token_ratio = new_token_ratio for req in reqs_to_abort: self.send_to_tokenizer.send_pyobj( - AbortReq(req.rid, abort_reason=req.to_abort_message) + AbortReq(abort_reason=req.to_abort_message, rid=req.rid) ) logger.info( @@ -2575,7 +2576,7 @@ class Scheduler( if self.enable_hicache_storage: # to release prefetch events associated with the request self.tree_cache.release_aborted_request(req.rid) - self.send_to_tokenizer.send_pyobj(AbortReq(req.rid)) + self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid)) # For disaggregation decode mode, the request in the waiting queue has KV cache allocated. 
if self.disaggregation_mode == DisaggregationMode.DECODE: self.tree_cache.cache_finished_req(req) @@ -2687,11 +2688,12 @@ class Scheduler( return SlowDownReqOutput() def expert_distribution_handle(self, recv_req: ExpertDistributionReq): - if recv_req == ExpertDistributionReq.START_RECORD: + action = recv_req.action + if action == ExpertDistributionReqType.START_RECORD: get_global_expert_distribution_recorder().start_record() - elif recv_req == ExpertDistributionReq.STOP_RECORD: + elif action == ExpertDistributionReqType.STOP_RECORD: get_global_expert_distribution_recorder().stop_record() - elif recv_req == ExpertDistributionReq.DUMP_RECORD: + elif action == ExpertDistributionReqType.DUMP_RECORD: get_global_expert_distribution_recorder().dump_record() else: raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}") @@ -2774,7 +2776,9 @@ class IdleSleeper: def is_health_check_generate_req(recv_req): - return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK") + rid = getattr(recv_req, "rid", None) + # rid can be None or a list (see BaseReq.rid); only a str can carry the prefix. + return isinstance(rid, str) and rid.startswith("HEALTH_CHECK") def is_work_request(recv_req): diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 537dedc95..e307a6899 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -9,7 +9,11 @@ import torch from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut +from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOutput, + BatchTokenIDOutput, +) from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch if TYPE_CHECKING: @@ -140,7 +144,7 @@ class SchedulerOutputProcessorMixin: logger.error( f"Grammar accept_token failed for 
req {req.rid} with token {next_token_id}: {e}" ) - self.abort_request(AbortReq(req.rid)) + self.abort_request(AbortReq(rid=req.rid)) req.grammar.finished = req.finished() else: # being chunked reqs' prefill is not finished @@ -292,7 +296,7 @@ class SchedulerOutputProcessorMixin: logger.error( f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}" ) - self.abort_request(AbortReq(req.rid)) + self.abort_request(AbortReq(rid=req.rid)) req.grammar.finished = req.finished() self.set_next_batch_sampling_info_done(batch) @@ -714,8 +718,7 @@ class SchedulerOutputProcessorMixin: return self.send_to_detokenizer.send_pyobj( - BatchTokenIDOut( - rids, + BatchTokenIDOutput( finished_reasons, decoded_texts, decode_ids_list, @@ -741,6 +744,7 @@ class SchedulerOutputProcessorMixin: output_token_ids_logprobs_val, output_token_ids_logprobs_idx, output_hidden_states, + rids=rids, placeholder_tokens_idx=None, placeholder_tokens_val=None, ) @@ -761,12 +765,12 @@ class SchedulerOutputProcessorMixin: prompt_tokens.append(len(req.origin_input_ids)) cached_tokens.append(req.cached_tokens) self.send_to_detokenizer.send_pyobj( - BatchEmbeddingOut( - rids, + BatchEmbeddingOutput( finished_reasons, embeddings, prompt_tokens, cached_tokens, + rids=rids, placeholder_tokens_idx=None, placeholder_tokens_val=None, ) diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index c8df235cb..cc929e5a7 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import ( DestroyWeightsUpdateGroupReqOutput, ExpertDistributionReq, ExpertDistributionReqOutput, + ExpertDistributionReqType, FlushCacheReqInput, FlushCacheReqOutput, GetInternalStateReq, @@ -44,7 +45,7 @@ from sglang.srt.managers.io_struct import ( InitWeightsUpdateGroupReqOutput, LoadLoRAAdapterReqInput, 
LoadLoRAAdapterReqOutput, - LoRAUpdateResult, + LoRAUpdateOutput, MultiTokenizerWrapper, OpenSessionReqInput, ProfileReq, @@ -276,7 +277,7 @@ class TokenizerCommunicatorMixin: self.expert_distribution_communicator.handle_recv, ), ( - LoRAUpdateResult, + LoRAUpdateOutput, self.update_lora_adapter_communicator.handle_recv, ), ( @@ -335,15 +336,18 @@ class TokenizerCommunicatorMixin: async def start_expert_distribution_record(self: TokenizerManager): self.auto_create_handle_loop() - await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD) + req = ExpertDistributionReq(action=ExpertDistributionReqType.START_RECORD) + await self.expert_distribution_communicator(req) async def stop_expert_distribution_record(self: TokenizerManager): self.auto_create_handle_loop() - await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD) + req = ExpertDistributionReq(action=ExpertDistributionReqType.STOP_RECORD) + await self.expert_distribution_communicator(req) async def dump_expert_distribution_record(self: TokenizerManager): self.auto_create_handle_loop() - await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD) + req = ExpertDistributionReq(action=ExpertDistributionReqType.DUMP_RECORD) + await self.expert_distribution_communicator(req) async def init_weights_update_group( self: TokenizerManager, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 65fccb1dc..a003e8ae4 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -48,18 +48,17 @@ from sglang.srt.hf_transformers_utils import ( get_tokenizer, get_tokenizer_from_processor, ) -from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry +from sglang.srt.lora.lora_registry import LoRARegistry from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer from sglang.srt.managers.disagg_service import 
start_disagg_service from sglang.srt.managers.io_struct import ( AbortReq, - BatchEmbeddingOut, - BatchMultimodalOut, - BatchStrOut, - BatchTokenIDOut, + BatchEmbeddingOutput, + BatchMultimodalOutput, + BatchStrOutput, + BatchTokenIDOutput, BatchTokenizedEmbeddingReqInput, BatchTokenizedGenerateReqInput, - CloseSessionReqInput, ConfigureLoggingReq, EmbeddingReqInput, FreezeGCReq, @@ -67,7 +66,6 @@ from sglang.srt.managers.io_struct import ( GetLoadReqInput, HealthCheckOutput, MultiTokenizerWrapper, - OpenSessionReqInput, OpenSessionReqOutput, SessionParams, TokenizedEmbeddingReqInput, @@ -341,10 +339,10 @@ class TokenizerManager(TokenizerCommunicatorMixin): [ ( ( - BatchStrOut, - BatchEmbeddingOut, - BatchTokenIDOut, - BatchMultimodalOut, + BatchStrOutput, + BatchEmbeddingOutput, + BatchTokenIDOutput, + BatchMultimodalOutput, ), self._handle_batch_output, ), @@ -716,7 +714,6 @@ class TokenizerManager(TokenizerCommunicatorMixin): ) tokenized_obj = TokenizedGenerateReqInput( - obj.rid, input_text, input_ids, mm_inputs, @@ -726,6 +723,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): obj.top_logprobs_num, obj.token_ids_logprob, obj.stream, + rid=obj.rid, bootstrap_host=obj.bootstrap_host, bootstrap_port=obj.bootstrap_port, bootstrap_room=obj.bootstrap_room, @@ -740,12 +738,12 @@ class TokenizerManager(TokenizerCommunicatorMixin): ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( - obj.rid, input_text, input_ids, mm_inputs, token_type_ids, sampling_params, + rid=obj.rid, priority=obj.priority, ) @@ -1038,7 +1036,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): def abort_request(self, rid: str = "", abort_all: bool = False): if not abort_all and rid not in self.rid_to_state: return - req = AbortReq(rid, abort_all) + req = AbortReq(rid=rid, abort_all=abort_all) self.send_to_scheduler.send_pyobj(req) if self.enable_metrics: # TODO: also use custom_labels from the request @@ -1303,7 +1301,10 @@ class 
TokenizerManager(TokenizerCommunicatorMixin): def _handle_batch_output( self, recv_obj: Union[ - BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut + BatchStrOutput, + BatchEmbeddingOutput, + BatchMultimodalOutput, + BatchTokenIDOutput, ], ): for i, rid in enumerate(recv_obj.rids): @@ -1337,7 +1338,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): i, ) - if not isinstance(recv_obj, BatchEmbeddingOut): + if not isinstance(recv_obj, BatchEmbeddingOutput): meta_info.update( { "completion_tokens": recv_obj.completion_tokens[i], @@ -1348,7 +1349,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): if getattr(recv_obj, "output_hidden_states", None): meta_info["hidden_states"] = recv_obj.output_hidden_states[i] - if isinstance(recv_obj, BatchStrOut): + if isinstance(recv_obj, BatchStrOutput): state.text += recv_obj.output_strs[i] if state.obj.stream: state.output_ids.extend(recv_obj.output_ids[i]) @@ -1363,7 +1364,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): "output_ids": output_token_ids, "meta_info": meta_info, } - elif isinstance(recv_obj, BatchTokenIDOut): + elif isinstance(recv_obj, BatchTokenIDOutput): if self.server_args.stream_output and state.obj.stream: state.output_ids.extend(recv_obj.output_ids[i]) output_token_ids = state.output_ids[state.last_output_offset :] @@ -1376,10 +1377,10 @@ class TokenizerManager(TokenizerCommunicatorMixin): "output_ids": output_token_ids, "meta_info": meta_info, } - elif isinstance(recv_obj, BatchMultimodalOut): + elif isinstance(recv_obj, BatchMultimodalOutput): raise NotImplementedError("BatchMultimodalOut not implemented") else: - assert isinstance(recv_obj, BatchEmbeddingOut) + assert isinstance(recv_obj, BatchEmbeddingOutput) out_dict = { "embedding": recv_obj.embeddings[i], "meta_info": meta_info, @@ -1418,7 +1419,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): top_logprobs_num: int, token_ids_logprob: List[int], return_text_in_logprobs: bool, - recv_obj: BatchStrOut, + 
recv_obj: BatchStrOutput, recv_obj_index: int, ): if recv_obj.input_token_logprobs_val is None: @@ -1536,7 +1537,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): ret.append(None) return ret - def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int): + def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int): completion_tokens = ( recv_obj.completion_tokens[i] if getattr(recv_obj, "completion_tokens", None) @@ -1632,7 +1633,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): asyncio.create_task(asyncio.to_thread(background_task)) - def _handle_abort_req(self, recv_obj): + def _handle_abort_req(self, recv_obj: AbortReq): if is_health_check_generate_req(recv_obj): return state = self.rid_to_state[recv_obj.rid]