Introduce naming convention in io_struct and base sglang io classes. (#10133)
This commit is contained in:
@@ -22,8 +22,8 @@ import zmq.asyncio
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
BatchEmbeddingOutput,
|
||||
BatchTokenIDOutput,
|
||||
HealthCheckOutput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
@@ -467,9 +467,9 @@ class GrpcRequestManager:
|
||||
await self.is_pause_cond.wait()
|
||||
|
||||
# Handle different output types
|
||||
if isinstance(recv_obj, BatchTokenIDOut):
|
||||
if isinstance(recv_obj, BatchTokenIDOutput):
|
||||
await self._handle_batch_output(recv_obj)
|
||||
elif isinstance(recv_obj, BatchEmbeddingOut):
|
||||
elif isinstance(recv_obj, BatchEmbeddingOutput):
|
||||
await self._handle_embedding_output(recv_obj)
|
||||
elif isinstance(recv_obj, HealthCheckOutput):
|
||||
await self._handle_health_check_output(recv_obj)
|
||||
@@ -498,7 +498,7 @@ class GrpcRequestManager:
|
||||
def _convert_logprob_style(
|
||||
self,
|
||||
state: GrpcReqState,
|
||||
batch_out: BatchTokenIDOut,
|
||||
batch_out: BatchTokenIDOutput,
|
||||
batch_index: int,
|
||||
):
|
||||
"""
|
||||
@@ -545,7 +545,7 @@ class GrpcRequestManager:
|
||||
batch_out.output_top_logprobs_idx[batch_index]
|
||||
)
|
||||
|
||||
async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
|
||||
async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
|
||||
"""Handle batch generation output from scheduler."""
|
||||
# Process each request in the batch
|
||||
for i, rid in enumerate(batch_out.rids):
|
||||
@@ -666,7 +666,7 @@ class GrpcRequestManager:
|
||||
|
||||
asyncio.create_task(cleanup())
|
||||
|
||||
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
|
||||
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
|
||||
"""Handle batch embedding output from scheduler."""
|
||||
for i, rid in enumerate(batch_out.rids):
|
||||
if rid not in self.rid_to_state:
|
||||
|
||||
@@ -94,8 +94,8 @@ from sglang.srt.managers.io_struct import (
|
||||
VertexGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.multi_tokenizer_mixin import (
|
||||
MultiTokenizerManager,
|
||||
MultiTokenizerRouter,
|
||||
TokenizerWorker,
|
||||
get_main_process_id,
|
||||
monkey_patch_uvicorn_multiprocessing,
|
||||
read_from_shared_memory,
|
||||
@@ -127,9 +127,7 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
||||
# Store global states
|
||||
@dataclasses.dataclass
|
||||
class _GlobalState:
|
||||
tokenizer_manager: Union[
|
||||
TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
|
||||
]
|
||||
tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker]
|
||||
template_manager: TemplateManager
|
||||
scheduler_info: Dict
|
||||
|
||||
@@ -164,7 +162,7 @@ async def init_multi_tokenizer() -> ServerArgs:
|
||||
)
|
||||
|
||||
# Launch multi-tokenizer manager process
|
||||
tokenizer_manager = MultiTokenizerManager(server_args, port_args)
|
||||
tokenizer_manager = TokenizerWorker(server_args, port_args)
|
||||
template_manager = TemplateManager()
|
||||
template_manager.initialize_templates(
|
||||
tokenizer_manager=tokenizer_manager,
|
||||
|
||||
@@ -35,7 +35,7 @@ from sglang.srt.lora.utils import (
|
||||
get_normalized_target_modules,
|
||||
get_target_module_name,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import LoRAUpdateResult
|
||||
from sglang.srt.managers.io_struct import LoRAUpdateOutput
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import replace_submodule
|
||||
@@ -107,8 +107,8 @@ class LoRAManager:
|
||||
|
||||
def create_lora_update_result(
|
||||
self, success: bool, error_message: str = ""
|
||||
) -> LoRAUpdateResult:
|
||||
return LoRAUpdateResult(
|
||||
) -> LoRAUpdateOutput:
|
||||
return LoRAUpdateOutput(
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
loaded_adapters={
|
||||
@@ -117,7 +117,7 @@ class LoRAManager:
|
||||
},
|
||||
)
|
||||
|
||||
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
|
||||
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
|
||||
"""
|
||||
Load a single LoRA adapter from the specified path.
|
||||
|
||||
@@ -174,7 +174,7 @@ class LoRAManager:
|
||||
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
|
||||
)
|
||||
|
||||
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
|
||||
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
|
||||
"""
|
||||
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
||||
delete the corresponding LoRA modules.
|
||||
|
||||
@@ -26,11 +26,11 @@ import zmq
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchEmbeddingOut,
|
||||
BatchEmbeddingOutput,
|
||||
BatchMultimodalDecodeReq,
|
||||
BatchMultimodalOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
BatchMultimodalOutput,
|
||||
BatchStrOutput,
|
||||
BatchTokenIDOutput,
|
||||
FreezeGCReq,
|
||||
MultiTokenizerRegisterReq,
|
||||
)
|
||||
@@ -101,8 +101,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
||||
|
||||
self._request_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
||||
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
||||
(BatchEmbeddingOutput, self.handle_batch_embedding_out),
|
||||
(BatchTokenIDOutput, self.handle_batch_token_id_out),
|
||||
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
||||
(MultiTokenizerRegisterReq, lambda x: x),
|
||||
(FreezeGCReq, self.handle_freeze_gc_req),
|
||||
@@ -145,11 +145,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
||||
return output[:-1]
|
||||
return output
|
||||
|
||||
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
|
||||
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOutput):
|
||||
# If it is embedding model, no detokenization is needed.
|
||||
return recv_obj
|
||||
|
||||
def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
|
||||
def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
|
||||
bs = len(recv_obj.rids)
|
||||
|
||||
# Initialize decode status
|
||||
@@ -224,7 +224,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
||||
s.sent_offset = len(output_str)
|
||||
output_strs.append(incremental_output)
|
||||
|
||||
return BatchStrOut(
|
||||
return BatchStrOutput(
|
||||
rids=recv_obj.rids,
|
||||
finished_reasons=recv_obj.finished_reasons,
|
||||
output_strs=output_strs,
|
||||
@@ -252,7 +252,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
||||
|
||||
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
||||
outputs = self.tokenizer.detokenize(recv_obj)
|
||||
return BatchMultimodalOut(
|
||||
return BatchMultimodalOutput(
|
||||
rids=recv_obj.rids,
|
||||
finished_reasons=recv_obj.finished_reasons,
|
||||
outputs=outputs,
|
||||
|
||||
@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler).
|
||||
|
||||
import copy
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
@@ -36,10 +37,32 @@ else:
|
||||
|
||||
|
||||
# Parameters for a session
|
||||
@dataclass
|
||||
class BaseReq(ABC):
|
||||
rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
|
||||
|
||||
def regenerate_rid(self):
|
||||
"""Generate a new request ID and return it."""
|
||||
if isinstance(self.rid, list):
|
||||
self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
|
||||
else:
|
||||
self.rid = uuid.uuid4().hex
|
||||
return self.rid
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseBatchReq(ABC):
|
||||
rids: Optional[List[str]] = field(default=None, kw_only=True)
|
||||
|
||||
def regenerate_rids(self):
|
||||
"""Generate new request IDs and return them."""
|
||||
self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
|
||||
return self.rids
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionParams:
|
||||
id: Optional[str] = None
|
||||
rid: Optional[str] = None
|
||||
offset: Optional[int] = None
|
||||
replace: Optional[bool] = None
|
||||
drop_previous_output: Optional[bool] = None
|
||||
@@ -63,7 +86,7 @@ MultimodalDataInputFormat = Union[
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerateReqInput:
|
||||
class GenerateReqInput(BaseReq):
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
text: Optional[Union[List[str], str]] = None
|
||||
# The token ids for text; one can specify either text or input_ids
|
||||
@@ -83,8 +106,6 @@ class GenerateReqInput:
|
||||
audio_data: Optional[MultimodalDataInputFormat] = None
|
||||
# The sampling_params. See descriptions below.
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||
# The request id.
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Whether to return logprobs.
|
||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||
# If return logprobs, the start location in the prompt for returning logprobs.
|
||||
@@ -491,11 +512,6 @@ class GenerateReqInput:
|
||||
):
|
||||
raise ValueError("Session params must be a dict or a list of dicts.")
|
||||
|
||||
def regenerate_rid(self):
|
||||
"""Generate a new request ID and return it."""
|
||||
self.rid = uuid.uuid4().hex
|
||||
return self.rid
|
||||
|
||||
def __getitem__(self, i):
|
||||
return GenerateReqInput(
|
||||
text=self.text[i] if self.text is not None else None,
|
||||
@@ -558,9 +574,7 @@ class GenerateReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizedGenerateReqInput:
|
||||
# The request id
|
||||
rid: str
|
||||
class TokenizedGenerateReqInput(BaseReq):
|
||||
# The input text
|
||||
input_text: str
|
||||
# The input token ids
|
||||
@@ -625,7 +639,7 @@ class TokenizedGenerateReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchTokenizedGenerateReqInput:
|
||||
class BatchTokenizedGenerateReqInput(BaseBatchReq):
|
||||
# The batch of tokenized requests
|
||||
batch: List[TokenizedGenerateReqInput]
|
||||
|
||||
@@ -640,7 +654,7 @@ class BatchTokenizedGenerateReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingReqInput:
|
||||
class EmbeddingReqInput(BaseReq):
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
text: Optional[Union[List[List[str]], List[str], str]] = None
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
@@ -656,8 +670,6 @@ class EmbeddingReqInput:
|
||||
audio_data: Optional[MultimodalDataInputFormat] = None
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
# The request id.
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Dummy sampling params for compatibility
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||
# Dummy input embeds for compatibility
|
||||
@@ -728,10 +740,6 @@ class EmbeddingReqInput:
|
||||
for i in range(self.batch_size):
|
||||
self.sampling_params[i]["max_new_tokens"] = 0
|
||||
|
||||
def regenerate_rid(self):
|
||||
self.rid = uuid.uuid4().hex
|
||||
return self.rid
|
||||
|
||||
def contains_mm_input(self) -> bool:
|
||||
return (
|
||||
has_valid_data(self.image_data)
|
||||
@@ -760,9 +768,7 @@ class EmbeddingReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizedEmbeddingReqInput:
|
||||
# The request id
|
||||
rid: str
|
||||
class TokenizedEmbeddingReqInput(BaseReq):
|
||||
# The input text
|
||||
input_text: str
|
||||
# The input token ids
|
||||
@@ -780,7 +786,7 @@ class TokenizedEmbeddingReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchTokenizedEmbeddingReqInput:
|
||||
class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
|
||||
# The batch of tokenized embedding requests
|
||||
batch: List[TokenizedEmbeddingReqInput]
|
||||
|
||||
@@ -795,9 +801,7 @@ class BatchTokenizedEmbeddingReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchTokenIDOut:
|
||||
# The request id
|
||||
rids: List[str]
|
||||
class BatchTokenIDOutput(BaseBatchReq):
|
||||
# The finish reason
|
||||
finished_reasons: List[BaseFinishReason]
|
||||
# For incremental decoding
|
||||
@@ -842,7 +846,7 @@ class BatchTokenIDOut:
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchMultimodalDecodeReq:
|
||||
class BatchMultimodalDecodeReq(BaseBatchReq):
|
||||
decoded_ids: List[int]
|
||||
input_token_logprobs_val: List[float]
|
||||
input_token_logprobs_idx: List[int]
|
||||
@@ -854,8 +858,6 @@ class BatchMultimodalDecodeReq:
|
||||
image_resolutions: List[List[int]]
|
||||
resize_image_resolutions: List[List[int]]
|
||||
|
||||
# The request id
|
||||
rids: List[str]
|
||||
finished_reasons: List[BaseFinishReason]
|
||||
|
||||
# Token counts
|
||||
@@ -871,9 +873,7 @@ class BatchMultimodalDecodeReq:
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchStrOut:
|
||||
# The request id
|
||||
rids: List[str]
|
||||
class BatchStrOutput(BaseBatchReq):
|
||||
# The finish reason
|
||||
finished_reasons: List[dict]
|
||||
# The output decoded strings
|
||||
@@ -909,9 +909,7 @@ class BatchStrOut:
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchMultimodalOut:
|
||||
# The request id
|
||||
rids: List[str]
|
||||
class BatchMultimodalOutput(BaseBatchReq):
|
||||
# The finish reason
|
||||
finished_reasons: List[dict]
|
||||
decoded_ids: List[List[int]]
|
||||
@@ -936,9 +934,7 @@ class BatchMultimodalOut:
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchEmbeddingOut:
|
||||
# The request id
|
||||
rids: List[str]
|
||||
class BatchEmbeddingOutput(BaseBatchReq):
|
||||
# The finish reason
|
||||
finished_reasons: List[BaseFinishReason]
|
||||
# The output embedding
|
||||
@@ -952,27 +948,27 @@ class BatchEmbeddingOut:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClearHiCacheReqInput:
|
||||
class ClearHiCacheReqInput(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClearHiCacheReqOutput:
|
||||
class ClearHiCacheReqOutput(BaseReq):
|
||||
success: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlushCacheReqInput:
|
||||
class FlushCacheReqInput(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlushCacheReqOutput:
|
||||
class FlushCacheReqOutput(BaseReq):
|
||||
success: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightFromDiskReqInput:
|
||||
class UpdateWeightFromDiskReqInput(BaseReq):
|
||||
# The model path with the new weights
|
||||
model_path: str
|
||||
# The format to load the weights
|
||||
@@ -990,7 +986,7 @@ class UpdateWeightFromDiskReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightFromDiskReqOutput:
|
||||
class UpdateWeightFromDiskReqOutput(BaseReq):
|
||||
success: bool
|
||||
message: str
|
||||
# Number of paused requests during weight sync.
|
||||
@@ -998,7 +994,7 @@ class UpdateWeightFromDiskReqOutput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromDistributedReqInput:
|
||||
class UpdateWeightsFromDistributedReqInput(BaseReq):
|
||||
names: List[str]
|
||||
dtypes: List[str]
|
||||
shapes: List[List[int]]
|
||||
@@ -1013,13 +1009,13 @@ class UpdateWeightsFromDistributedReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromDistributedReqOutput:
|
||||
class UpdateWeightsFromDistributedReqOutput(BaseReq):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromTensorReqInput:
|
||||
class UpdateWeightsFromTensorReqInput(BaseReq):
|
||||
"""Update model weights from tensor input.
|
||||
|
||||
- Tensors are serialized for transmission
|
||||
@@ -1038,13 +1034,13 @@ class UpdateWeightsFromTensorReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromTensorReqOutput:
|
||||
class UpdateWeightsFromTensorReqOutput(BaseReq):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitWeightsSendGroupForRemoteInstanceReqInput:
|
||||
class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
|
||||
# The master address
|
||||
master_address: str
|
||||
# The ports for each rank's communication group
|
||||
@@ -1060,13 +1056,13 @@ class InitWeightsSendGroupForRemoteInstanceReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitWeightsSendGroupForRemoteInstanceReqOutput:
|
||||
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendWeightsToRemoteInstanceReqInput:
|
||||
class SendWeightsToRemoteInstanceReqInput(BaseReq):
|
||||
# The master address
|
||||
master_address: str
|
||||
# The ports for each rank's communication group
|
||||
@@ -1076,13 +1072,13 @@ class SendWeightsToRemoteInstanceReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendWeightsToRemoteInstanceReqOutput:
|
||||
class SendWeightsToRemoteInstanceReqOutput(BaseReq):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitWeightsUpdateGroupReqInput:
|
||||
class InitWeightsUpdateGroupReqInput(BaseReq):
|
||||
# The master address
|
||||
master_address: str
|
||||
# The master port
|
||||
@@ -1098,24 +1094,24 @@ class InitWeightsUpdateGroupReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitWeightsUpdateGroupReqOutput:
|
||||
class InitWeightsUpdateGroupReqOutput(BaseReq):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DestroyWeightsUpdateGroupReqInput:
|
||||
class DestroyWeightsUpdateGroupReqInput(BaseReq):
|
||||
group_name: str = "weight_update_group"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DestroyWeightsUpdateGroupReqOutput:
|
||||
class DestroyWeightsUpdateGroupReqOutput(BaseReq):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightVersionReqInput:
|
||||
class UpdateWeightVersionReqInput(BaseReq):
|
||||
# The new weight version
|
||||
new_version: str
|
||||
# Whether to abort all running requests before updating
|
||||
@@ -1123,89 +1119,87 @@ class UpdateWeightVersionReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetWeightsByNameReqInput:
|
||||
class GetWeightsByNameReqInput(BaseReq):
|
||||
name: str
|
||||
truncate_size: int = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetWeightsByNameReqOutput:
|
||||
class GetWeightsByNameReqOutput(BaseReq):
|
||||
parameter: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReleaseMemoryOccupationReqInput:
|
||||
class ReleaseMemoryOccupationReqInput(BaseReq):
|
||||
# Optional tags to identify the memory region, which is primarily used for RL
|
||||
# Currently we only support `weights` and `kv_cache`
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReleaseMemoryOccupationReqOutput:
|
||||
class ReleaseMemoryOccupationReqOutput(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResumeMemoryOccupationReqInput:
|
||||
class ResumeMemoryOccupationReqInput(BaseReq):
|
||||
# Optional tags to identify the memory region, which is primarily used for RL
|
||||
# Currently we only support `weights` and `kv_cache`
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResumeMemoryOccupationReqOutput:
|
||||
class ResumeMemoryOccupationReqOutput(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlowDownReqInput:
|
||||
class SlowDownReqInput(BaseReq):
|
||||
forward_sleep_time: Optional[float]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlowDownReqOutput:
|
||||
class SlowDownReqOutput(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbortReq:
|
||||
# The request id
|
||||
rid: str = ""
|
||||
class AbortReq(BaseReq):
|
||||
# Whether to abort all requests
|
||||
abort_all: bool = False
|
||||
# The finished reason data
|
||||
finished_reason: Optional[Dict[str, Any]] = None
|
||||
abort_reason: Optional[str] = None
|
||||
# used in MultiTokenzierManager mode
|
||||
rids: Optional[Union[List[str], str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.rids = self.rid
|
||||
# FIXME: This is a hack to keep the same with the old code
|
||||
if self.rid is None:
|
||||
self.rid = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetInternalStateReq:
|
||||
class GetInternalStateReq(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetInternalStateReqOutput:
|
||||
class GetInternalStateReqOutput(BaseReq):
|
||||
internal_state: Dict[Any, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetInternalStateReq:
|
||||
class SetInternalStateReq(BaseReq):
|
||||
server_args: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetInternalStateReqOutput:
|
||||
class SetInternalStateReqOutput(BaseReq):
|
||||
updated: bool
|
||||
server_args: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileReqInput:
|
||||
class ProfileReqInput(BaseReq):
|
||||
# The output directory
|
||||
output_dir: Optional[str] = None
|
||||
# If set, it profile as many as this number of steps.
|
||||
@@ -1225,7 +1219,7 @@ class ProfileReqType(Enum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileReq:
|
||||
class ProfileReq(BaseReq):
|
||||
type: ProfileReqType
|
||||
output_dir: Optional[str] = None
|
||||
start_step: Optional[int] = None
|
||||
@@ -1238,18 +1232,18 @@ class ProfileReq:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileReqOutput:
|
||||
class ProfileReqOutput(BaseReq):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FreezeGCReq:
|
||||
class FreezeGCReq(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigureLoggingReq:
|
||||
class ConfigureLoggingReq(BaseReq):
|
||||
log_requests: Optional[bool] = None
|
||||
log_requests_level: Optional[int] = None
|
||||
dump_requests_folder: Optional[str] = None
|
||||
@@ -1258,35 +1252,39 @@ class ConfigureLoggingReq:
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenSessionReqInput:
|
||||
class OpenSessionReqInput(BaseReq):
|
||||
capacity_of_str_len: int
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CloseSessionReqInput:
|
||||
class CloseSessionReqInput(BaseReq):
|
||||
session_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenSessionReqOutput:
|
||||
class OpenSessionReqOutput(BaseReq):
|
||||
session_id: Optional[str]
|
||||
success: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class HealthCheckOutput:
|
||||
class HealthCheckOutput(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
class ExpertDistributionReq(Enum):
|
||||
class ExpertDistributionReqType(Enum):
|
||||
START_RECORD = 1
|
||||
STOP_RECORD = 2
|
||||
DUMP_RECORD = 3
|
||||
|
||||
|
||||
# The decorator is required: without it ``action`` is only an annotation and
# the inherited ``__init__`` accepts nothing but the keyword-only ``rid``, so
# the request could not be constructed with an action.
@dataclass
class ExpertDistributionReq(BaseReq):
    """Request asking the expert-distribution recorder to perform an action."""

    # Which recorder operation to run (start, stop or dump the record).
    action: ExpertDistributionReqType
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpertDistributionReqOutput:
|
||||
class ExpertDistributionReqOutput(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
@@ -1304,7 +1302,7 @@ class Tool:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParseFunctionCallReq:
|
||||
class ParseFunctionCallReq(BaseReq):
|
||||
text: str # The text to parse.
|
||||
tools: List[Tool] = field(
|
||||
default_factory=list
|
||||
@@ -1315,31 +1313,31 @@ class ParseFunctionCallReq:
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeparateReasoningReqInput:
|
||||
class SeparateReasoningReqInput(BaseReq):
|
||||
text: str # The text to parse.
|
||||
reasoning_parser: str # Specify the parser type, e.g., "deepseek-r1".
|
||||
|
||||
|
||||
@dataclass
|
||||
class VertexGenerateReqInput:
|
||||
class VertexGenerateReqInput(BaseReq):
|
||||
instances: List[dict]
|
||||
parameters: Optional[dict] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RpcReqInput:
|
||||
class RpcReqInput(BaseReq):
|
||||
method: str
|
||||
parameters: Optional[Dict] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RpcReqOutput:
|
||||
class RpcReqOutput(BaseReq):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadLoRAAdapterReqInput:
|
||||
class LoadLoRAAdapterReqInput(BaseReq):
|
||||
# The name of the lora module to newly loaded.
|
||||
lora_name: str
|
||||
# The path of loading.
|
||||
@@ -1359,7 +1357,7 @@ class LoadLoRAAdapterReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnloadLoRAAdapterReqInput:
|
||||
class UnloadLoRAAdapterReqInput(BaseReq):
|
||||
# The name of lora module to unload.
|
||||
lora_name: str
|
||||
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
|
||||
@@ -1373,23 +1371,23 @@ class UnloadLoRAAdapterReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAUpdateResult:
|
||||
class LoRAUpdateOutput(BaseReq):
|
||||
success: bool
|
||||
error_message: Optional[str] = None
|
||||
loaded_adapters: Optional[Dict[str, LoRARef]] = None
|
||||
|
||||
|
||||
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
||||
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiTokenizerRegisterReq:
|
||||
rids: Optional[Union[List[str], str]] = None
|
||||
class MultiTokenizerRegisterReq(BaseBatchReq):
|
||||
ipc_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiTokenizerWrapper:
|
||||
# FIXME(lsyin): remove this
|
||||
worker_id: int
|
||||
obj: Optional[Any] = None
|
||||
|
||||
@@ -1400,17 +1398,17 @@ class BlockReqType(Enum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockReqInput:
|
||||
class BlockReqInput(BaseReq):
|
||||
type: BlockReqType
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetLoadReqInput:
|
||||
class GetLoadReqInput(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetLoadReqOutput:
|
||||
class GetLoadReqOutput(BaseReq):
|
||||
dp_rank: int
|
||||
num_reqs: int
|
||||
num_waiting_reqs: int
|
||||
@@ -1418,5 +1416,31 @@ class GetLoadReqOutput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class WatchLoadUpdateReq:
|
||||
class WatchLoadUpdateReq(BaseReq):
|
||||
loads: List[GetLoadReqOutput]
|
||||
|
||||
|
||||
def _check_all_req_types():
    """Validate that naming convention and base class agree for every class here.

    Walks every class visible in this module. A class whose name ends in
    ``Req``, ``Input`` or ``Output`` must derive from ``BaseReq`` or
    ``BaseBatchReq``, and every class deriving from one of those bases must
    carry one of those suffixes.

    Raises:
        ValueError: if a class violates either direction of the rule.
    """
    import inspect
    import sys

    io_suffixes = ("Req", "Input", "Output")
    base_types = (BaseReq, BaseBatchReq)

    for name, cls in inspect.getmembers(sys.modules[__name__], inspect.isclass):
        follows_naming = name.endswith(io_suffixes)
        derives_from_base = issubclass(cls, base_types)

        if follows_naming and not derives_from_base:
            raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
        if derives_from_base and not follows_naming:
            raise ValueError(
                f"{name} is a subclass of BaseReq but not follow the naming convention."
            )
|
||||
|
||||
|
||||
_check_all_req_types()
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
|
||||
"""Mixin class and utils for multi-http-worker mode"""
|
||||
import asyncio
|
||||
import logging
|
||||
import multiprocessing as multiprocessing
|
||||
@@ -30,10 +30,10 @@ import zmq.asyncio
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchEmbeddingOut,
|
||||
BatchMultimodalOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
BatchEmbeddingOutput,
|
||||
BatchMultimodalOutput,
|
||||
BatchStrOutput,
|
||||
BatchTokenIDOutput,
|
||||
MultiTokenizerRegisterReq,
|
||||
MultiTokenizerWrapper,
|
||||
)
|
||||
@@ -83,8 +83,8 @@ class SocketMapping:
|
||||
|
||||
def _handle_output_by_index(output, i):
|
||||
"""NOTE: A maintainable method is better here."""
|
||||
if isinstance(output, BatchTokenIDOut):
|
||||
new_output = BatchTokenIDOut(
|
||||
if isinstance(output, BatchTokenIDOutput):
|
||||
new_output = BatchTokenIDOutput(
|
||||
rids=[output.rids[i]],
|
||||
finished_reasons=(
|
||||
[output.finished_reasons[i]]
|
||||
@@ -198,8 +198,8 @@ def _handle_output_by_index(output, i):
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
)
|
||||
elif isinstance(output, BatchEmbeddingOut):
|
||||
new_output = BatchEmbeddingOut(
|
||||
elif isinstance(output, BatchEmbeddingOutput):
|
||||
new_output = BatchEmbeddingOutput(
|
||||
rids=[output.rids[i]],
|
||||
finished_reasons=(
|
||||
[output.finished_reasons[i]]
|
||||
@@ -216,8 +216,8 @@ def _handle_output_by_index(output, i):
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
)
|
||||
elif isinstance(output, BatchStrOut):
|
||||
new_output = BatchStrOut(
|
||||
elif isinstance(output, BatchStrOutput):
|
||||
new_output = BatchStrOutput(
|
||||
rids=[output.rids[i]],
|
||||
finished_reasons=(
|
||||
[output.finished_reasons[i]]
|
||||
@@ -314,8 +314,8 @@ def _handle_output_by_index(output, i):
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
)
|
||||
elif isinstance(output, BatchMultimodalOut):
|
||||
new_output = BatchMultimodalOut(
|
||||
elif isinstance(output, BatchMultimodalOutput):
|
||||
new_output = BatchMultimodalOutput(
|
||||
rids=[output.rids[i]],
|
||||
finished_reasons=(
|
||||
[output.finished_reasons[i]]
|
||||
@@ -343,7 +343,7 @@ def _handle_output_by_index(output, i):
|
||||
|
||||
|
||||
class MultiHttpWorkerDetokenizerMixin:
|
||||
"""Mixin class for MultiTokenizerManager and DetokenizerManager"""
|
||||
"""Mixin class for DetokenizerManager"""
|
||||
|
||||
def get_worker_ids_from_req_rids(self, rids):
|
||||
if isinstance(rids, list):
|
||||
@@ -386,7 +386,7 @@ class MultiHttpWorkerDetokenizerMixin:
|
||||
|
||||
|
||||
class MultiTokenizerRouter:
|
||||
"""A router to receive requests from MultiTokenizerManager"""
|
||||
"""A router to receive requests from TokenizerWorker"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -454,8 +454,8 @@ class MultiTokenizerRouter:
|
||||
self.socket_mapping.send_output(worker_id, new_recv_obj)
|
||||
|
||||
|
||||
class MultiTokenizerManager(TokenizerManager):
|
||||
"""Multi Process Tokenizer Manager that tokenizes the text."""
|
||||
class TokenizerWorker(TokenizerManager):
|
||||
"""Tokenizer Worker in multi-http-worker mode"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
ExpertDistributionReq,
|
||||
ExpertDistributionReqOutput,
|
||||
ExpertDistributionReqType,
|
||||
FlushCacheReqInput,
|
||||
FlushCacheReqOutput,
|
||||
FreezeGCReq,
|
||||
@@ -1487,12 +1488,12 @@ class Scheduler(
|
||||
req.priority = -sys.maxsize - 1
|
||||
elif not self.enable_priority_scheduling and req.priority is not None:
|
||||
abort_req = AbortReq(
|
||||
req.rid,
|
||||
finished_reason={
|
||||
"type": "abort",
|
||||
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
"message": "Using priority is disabled for this server. Please send a new request without a priority.",
|
||||
},
|
||||
rid=req.rid,
|
||||
)
|
||||
self.send_to_tokenizer.send_pyobj(abort_req)
|
||||
|
||||
@@ -1528,12 +1529,12 @@ class Scheduler(
|
||||
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
AbortReq(
|
||||
req_to_abort.rid,
|
||||
finished_reason={
|
||||
"type": "abort",
|
||||
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
"message": message,
|
||||
},
|
||||
rid=req_to_abort.rid,
|
||||
)
|
||||
)
|
||||
return req_to_abort.rid == recv_req.rid
|
||||
@@ -2005,7 +2006,7 @@ class Scheduler(
|
||||
self.new_token_ratio = new_token_ratio
|
||||
for req in reqs_to_abort:
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
AbortReq(req.rid, abort_reason=req.to_abort_message)
|
||||
AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -2575,7 +2576,7 @@ class Scheduler(
|
||||
if self.enable_hicache_storage:
|
||||
# to release prefetch events associated with the request
|
||||
self.tree_cache.release_aborted_request(req.rid)
|
||||
self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
|
||||
self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
|
||||
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
@@ -2687,11 +2688,12 @@ class Scheduler(
|
||||
return SlowDownReqOutput()
|
||||
|
||||
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
|
||||
if recv_req == ExpertDistributionReq.START_RECORD:
|
||||
action = recv_req.action
|
||||
if action == ExpertDistributionReqType.START_RECORD:
|
||||
get_global_expert_distribution_recorder().start_record()
|
||||
elif recv_req == ExpertDistributionReq.STOP_RECORD:
|
||||
elif action == ExpertDistributionReqType.STOP_RECORD:
|
||||
get_global_expert_distribution_recorder().stop_record()
|
||||
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
|
||||
elif action == ExpertDistributionReqType.DUMP_RECORD:
|
||||
get_global_expert_distribution_recorder().dump_record()
|
||||
else:
|
||||
raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
|
||||
@@ -2774,7 +2776,8 @@ class IdleSleeper:
|
||||
|
||||
|
||||
def is_health_check_generate_req(recv_req):
|
||||
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
|
||||
rid = getattr(recv_req, "rid", None)
|
||||
return rid is not None and rid.startswith("HEALTH_CHECK")
|
||||
|
||||
|
||||
def is_work_request(recv_req):
|
||||
|
||||
@@ -9,7 +9,11 @@ import torch
|
||||
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOutput,
|
||||
BatchTokenIDOutput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -140,7 +144,7 @@ class SchedulerOutputProcessorMixin:
|
||||
logger.error(
|
||||
f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
|
||||
)
|
||||
self.abort_request(AbortReq(req.rid))
|
||||
self.abort_request(AbortReq(rid=req.rid))
|
||||
req.grammar.finished = req.finished()
|
||||
else:
|
||||
# being chunked reqs' prefill is not finished
|
||||
@@ -292,7 +296,7 @@ class SchedulerOutputProcessorMixin:
|
||||
logger.error(
|
||||
f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
|
||||
)
|
||||
self.abort_request(AbortReq(req.rid))
|
||||
self.abort_request(AbortReq(rid=req.rid))
|
||||
req.grammar.finished = req.finished()
|
||||
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
@@ -714,8 +718,7 @@ class SchedulerOutputProcessorMixin:
|
||||
return
|
||||
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
BatchTokenIDOut(
|
||||
rids,
|
||||
BatchTokenIDOutput(
|
||||
finished_reasons,
|
||||
decoded_texts,
|
||||
decode_ids_list,
|
||||
@@ -741,6 +744,7 @@ class SchedulerOutputProcessorMixin:
|
||||
output_token_ids_logprobs_val,
|
||||
output_token_ids_logprobs_idx,
|
||||
output_hidden_states,
|
||||
rids=rids,
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
)
|
||||
@@ -761,12 +765,12 @@ class SchedulerOutputProcessorMixin:
|
||||
prompt_tokens.append(len(req.origin_input_ids))
|
||||
cached_tokens.append(req.cached_tokens)
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
BatchEmbeddingOut(
|
||||
rids,
|
||||
BatchEmbeddingOutput(
|
||||
finished_reasons,
|
||||
embeddings,
|
||||
prompt_tokens,
|
||||
cached_tokens,
|
||||
rids=rids,
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
)
|
||||
|
||||
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
|
||||
DestroyWeightsUpdateGroupReqOutput,
|
||||
ExpertDistributionReq,
|
||||
ExpertDistributionReqOutput,
|
||||
ExpertDistributionReqType,
|
||||
FlushCacheReqInput,
|
||||
FlushCacheReqOutput,
|
||||
GetInternalStateReq,
|
||||
@@ -44,7 +45,7 @@ from sglang.srt.managers.io_struct import (
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
LoadLoRAAdapterReqOutput,
|
||||
LoRAUpdateResult,
|
||||
LoRAUpdateOutput,
|
||||
MultiTokenizerWrapper,
|
||||
OpenSessionReqInput,
|
||||
ProfileReq,
|
||||
@@ -276,7 +277,7 @@ class TokenizerCommunicatorMixin:
|
||||
self.expert_distribution_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
LoRAUpdateResult,
|
||||
LoRAUpdateOutput,
|
||||
self.update_lora_adapter_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
@@ -335,15 +336,18 @@ class TokenizerCommunicatorMixin:
|
||||
|
||||
async def start_expert_distribution_record(self: TokenizerManager):
|
||||
self.auto_create_handle_loop()
|
||||
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
|
||||
req = ExpertDistributionReq(action=ExpertDistributionReqType.START_RECORD)
|
||||
await self.expert_distribution_communicator(req)
|
||||
|
||||
async def stop_expert_distribution_record(self: TokenizerManager):
|
||||
self.auto_create_handle_loop()
|
||||
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
|
||||
req = ExpertDistributionReq(action=ExpertDistributionReqType.STOP_RECORD)
|
||||
await self.expert_distribution_communicator(req)
|
||||
|
||||
async def dump_expert_distribution_record(self: TokenizerManager):
|
||||
self.auto_create_handle_loop()
|
||||
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
||||
req = ExpertDistributionReq(action=ExpertDistributionReqType.DUMP_RECORD)
|
||||
await self.expert_distribution_communicator(req)
|
||||
|
||||
async def init_weights_update_group(
|
||||
self: TokenizerManager,
|
||||
|
||||
@@ -48,18 +48,17 @@ from sglang.srt.hf_transformers_utils import (
|
||||
get_tokenizer,
|
||||
get_tokenizer_from_processor,
|
||||
)
|
||||
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
|
||||
from sglang.srt.lora.lora_registry import LoRARegistry
|
||||
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
BatchMultimodalOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
BatchEmbeddingOutput,
|
||||
BatchMultimodalOutput,
|
||||
BatchStrOutput,
|
||||
BatchTokenIDOutput,
|
||||
BatchTokenizedEmbeddingReqInput,
|
||||
BatchTokenizedGenerateReqInput,
|
||||
CloseSessionReqInput,
|
||||
ConfigureLoggingReq,
|
||||
EmbeddingReqInput,
|
||||
FreezeGCReq,
|
||||
@@ -67,7 +66,6 @@ from sglang.srt.managers.io_struct import (
|
||||
GetLoadReqInput,
|
||||
HealthCheckOutput,
|
||||
MultiTokenizerWrapper,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
SessionParams,
|
||||
TokenizedEmbeddingReqInput,
|
||||
@@ -341,10 +339,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
[
|
||||
(
|
||||
(
|
||||
BatchStrOut,
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
BatchMultimodalOut,
|
||||
BatchStrOutput,
|
||||
BatchEmbeddingOutput,
|
||||
BatchTokenIDOutput,
|
||||
BatchMultimodalOutput,
|
||||
),
|
||||
self._handle_batch_output,
|
||||
),
|
||||
@@ -716,7 +714,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
)
|
||||
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
obj.rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
mm_inputs,
|
||||
@@ -726,6 +723,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
obj.top_logprobs_num,
|
||||
obj.token_ids_logprob,
|
||||
obj.stream,
|
||||
rid=obj.rid,
|
||||
bootstrap_host=obj.bootstrap_host,
|
||||
bootstrap_port=obj.bootstrap_port,
|
||||
bootstrap_room=obj.bootstrap_room,
|
||||
@@ -740,12 +738,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
)
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
obj.rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
mm_inputs,
|
||||
token_type_ids,
|
||||
sampling_params,
|
||||
rid=obj.rid,
|
||||
priority=obj.priority,
|
||||
)
|
||||
|
||||
@@ -1038,7 +1036,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
def abort_request(self, rid: str = "", abort_all: bool = False):
|
||||
if not abort_all and rid not in self.rid_to_state:
|
||||
return
|
||||
req = AbortReq(rid, abort_all)
|
||||
req = AbortReq(rid=rid, abort_all=abort_all)
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
if self.enable_metrics:
|
||||
# TODO: also use custom_labels from the request
|
||||
@@ -1303,7 +1301,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
def _handle_batch_output(
|
||||
self,
|
||||
recv_obj: Union[
|
||||
BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
|
||||
BatchStrOutput,
|
||||
BatchEmbeddingOutput,
|
||||
BatchMultimodalOutput,
|
||||
BatchTokenIDOutput,
|
||||
],
|
||||
):
|
||||
for i, rid in enumerate(recv_obj.rids):
|
||||
@@ -1337,7 +1338,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
i,
|
||||
)
|
||||
|
||||
if not isinstance(recv_obj, BatchEmbeddingOut):
|
||||
if not isinstance(recv_obj, BatchEmbeddingOutput):
|
||||
meta_info.update(
|
||||
{
|
||||
"completion_tokens": recv_obj.completion_tokens[i],
|
||||
@@ -1348,7 +1349,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
if getattr(recv_obj, "output_hidden_states", None):
|
||||
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
|
||||
|
||||
if isinstance(recv_obj, BatchStrOut):
|
||||
if isinstance(recv_obj, BatchStrOutput):
|
||||
state.text += recv_obj.output_strs[i]
|
||||
if state.obj.stream:
|
||||
state.output_ids.extend(recv_obj.output_ids[i])
|
||||
@@ -1363,7 +1364,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
"output_ids": output_token_ids,
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
elif isinstance(recv_obj, BatchTokenIDOutput):
|
||||
if self.server_args.stream_output and state.obj.stream:
|
||||
state.output_ids.extend(recv_obj.output_ids[i])
|
||||
output_token_ids = state.output_ids[state.last_output_offset :]
|
||||
@@ -1376,10 +1377,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
"output_ids": output_token_ids,
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
elif isinstance(recv_obj, BatchMultimodalOut):
|
||||
elif isinstance(recv_obj, BatchMultimodalOutput):
|
||||
raise NotImplementedError("BatchMultimodalOut not implemented")
|
||||
else:
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
assert isinstance(recv_obj, BatchEmbeddingOutput)
|
||||
out_dict = {
|
||||
"embedding": recv_obj.embeddings[i],
|
||||
"meta_info": meta_info,
|
||||
@@ -1418,7 +1419,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
top_logprobs_num: int,
|
||||
token_ids_logprob: List[int],
|
||||
return_text_in_logprobs: bool,
|
||||
recv_obj: BatchStrOut,
|
||||
recv_obj: BatchStrOutput,
|
||||
recv_obj_index: int,
|
||||
):
|
||||
if recv_obj.input_token_logprobs_val is None:
|
||||
@@ -1536,7 +1537,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
ret.append(None)
|
||||
return ret
|
||||
|
||||
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
|
||||
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
|
||||
completion_tokens = (
|
||||
recv_obj.completion_tokens[i]
|
||||
if getattr(recv_obj, "completion_tokens", None)
|
||||
@@ -1632,7 +1633,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
|
||||
asyncio.create_task(asyncio.to_thread(background_task))
|
||||
|
||||
def _handle_abort_req(self, recv_obj):
|
||||
def _handle_abort_req(self, recv_obj: AbortReq):
|
||||
if is_health_check_generate_req(recv_obj):
|
||||
return
|
||||
state = self.rid_to_state[recv_obj.rid]
|
||||
|
||||
Reference in New Issue
Block a user