Introduce naming convention in io_struct and base sglang io classes. (#10133)
This commit is contained in:
@@ -22,8 +22,8 @@ import zmq.asyncio
|
|||||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOutput,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOutput,
|
||||||
HealthCheckOutput,
|
HealthCheckOutput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
@@ -467,9 +467,9 @@ class GrpcRequestManager:
|
|||||||
await self.is_pause_cond.wait()
|
await self.is_pause_cond.wait()
|
||||||
|
|
||||||
# Handle different output types
|
# Handle different output types
|
||||||
if isinstance(recv_obj, BatchTokenIDOut):
|
if isinstance(recv_obj, BatchTokenIDOutput):
|
||||||
await self._handle_batch_output(recv_obj)
|
await self._handle_batch_output(recv_obj)
|
||||||
elif isinstance(recv_obj, BatchEmbeddingOut):
|
elif isinstance(recv_obj, BatchEmbeddingOutput):
|
||||||
await self._handle_embedding_output(recv_obj)
|
await self._handle_embedding_output(recv_obj)
|
||||||
elif isinstance(recv_obj, HealthCheckOutput):
|
elif isinstance(recv_obj, HealthCheckOutput):
|
||||||
await self._handle_health_check_output(recv_obj)
|
await self._handle_health_check_output(recv_obj)
|
||||||
@@ -498,7 +498,7 @@ class GrpcRequestManager:
|
|||||||
def _convert_logprob_style(
|
def _convert_logprob_style(
|
||||||
self,
|
self,
|
||||||
state: GrpcReqState,
|
state: GrpcReqState,
|
||||||
batch_out: BatchTokenIDOut,
|
batch_out: BatchTokenIDOutput,
|
||||||
batch_index: int,
|
batch_index: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -545,7 +545,7 @@ class GrpcRequestManager:
|
|||||||
batch_out.output_top_logprobs_idx[batch_index]
|
batch_out.output_top_logprobs_idx[batch_index]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
|
async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
|
||||||
"""Handle batch generation output from scheduler."""
|
"""Handle batch generation output from scheduler."""
|
||||||
# Process each request in the batch
|
# Process each request in the batch
|
||||||
for i, rid in enumerate(batch_out.rids):
|
for i, rid in enumerate(batch_out.rids):
|
||||||
@@ -666,7 +666,7 @@ class GrpcRequestManager:
|
|||||||
|
|
||||||
asyncio.create_task(cleanup())
|
asyncio.create_task(cleanup())
|
||||||
|
|
||||||
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
|
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
|
||||||
"""Handle batch embedding output from scheduler."""
|
"""Handle batch embedding output from scheduler."""
|
||||||
for i, rid in enumerate(batch_out.rids):
|
for i, rid in enumerate(batch_out.rids):
|
||||||
if rid not in self.rid_to_state:
|
if rid not in self.rid_to_state:
|
||||||
|
|||||||
@@ -94,8 +94,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
VertexGenerateReqInput,
|
VertexGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.multi_tokenizer_mixin import (
|
from sglang.srt.managers.multi_tokenizer_mixin import (
|
||||||
MultiTokenizerManager,
|
|
||||||
MultiTokenizerRouter,
|
MultiTokenizerRouter,
|
||||||
|
TokenizerWorker,
|
||||||
get_main_process_id,
|
get_main_process_id,
|
||||||
monkey_patch_uvicorn_multiprocessing,
|
monkey_patch_uvicorn_multiprocessing,
|
||||||
read_from_shared_memory,
|
read_from_shared_memory,
|
||||||
@@ -127,9 +127,7 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
|||||||
# Store global states
|
# Store global states
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class _GlobalState:
|
class _GlobalState:
|
||||||
tokenizer_manager: Union[
|
tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker]
|
||||||
TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
|
|
||||||
]
|
|
||||||
template_manager: TemplateManager
|
template_manager: TemplateManager
|
||||||
scheduler_info: Dict
|
scheduler_info: Dict
|
||||||
|
|
||||||
@@ -164,7 +162,7 @@ async def init_multi_tokenizer() -> ServerArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Launch multi-tokenizer manager process
|
# Launch multi-tokenizer manager process
|
||||||
tokenizer_manager = MultiTokenizerManager(server_args, port_args)
|
tokenizer_manager = TokenizerWorker(server_args, port_args)
|
||||||
template_manager = TemplateManager()
|
template_manager = TemplateManager()
|
||||||
template_manager.initialize_templates(
|
template_manager.initialize_templates(
|
||||||
tokenizer_manager=tokenizer_manager,
|
tokenizer_manager=tokenizer_manager,
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from sglang.srt.lora.utils import (
|
|||||||
get_normalized_target_modules,
|
get_normalized_target_modules,
|
||||||
get_target_module_name,
|
get_target_module_name,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import LoRAUpdateResult
|
from sglang.srt.managers.io_struct import LoRAUpdateOutput
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import replace_submodule
|
from sglang.srt.utils import replace_submodule
|
||||||
@@ -107,8 +107,8 @@ class LoRAManager:
|
|||||||
|
|
||||||
def create_lora_update_result(
|
def create_lora_update_result(
|
||||||
self, success: bool, error_message: str = ""
|
self, success: bool, error_message: str = ""
|
||||||
) -> LoRAUpdateResult:
|
) -> LoRAUpdateOutput:
|
||||||
return LoRAUpdateResult(
|
return LoRAUpdateOutput(
|
||||||
success=success,
|
success=success,
|
||||||
error_message=error_message,
|
error_message=error_message,
|
||||||
loaded_adapters={
|
loaded_adapters={
|
||||||
@@ -117,7 +117,7 @@ class LoRAManager:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
|
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
|
||||||
"""
|
"""
|
||||||
Load a single LoRA adapter from the specified path.
|
Load a single LoRA adapter from the specified path.
|
||||||
|
|
||||||
@@ -174,7 +174,7 @@ class LoRAManager:
|
|||||||
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
|
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
|
||||||
)
|
)
|
||||||
|
|
||||||
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
|
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
|
||||||
"""
|
"""
|
||||||
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
||||||
delete the corresponding LoRA modules.
|
delete the corresponding LoRA modules.
|
||||||
|
|||||||
@@ -26,11 +26,11 @@ import zmq
|
|||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOutput,
|
||||||
BatchMultimodalDecodeReq,
|
BatchMultimodalDecodeReq,
|
||||||
BatchMultimodalOut,
|
BatchMultimodalOutput,
|
||||||
BatchStrOut,
|
BatchStrOutput,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOutput,
|
||||||
FreezeGCReq,
|
FreezeGCReq,
|
||||||
MultiTokenizerRegisterReq,
|
MultiTokenizerRegisterReq,
|
||||||
)
|
)
|
||||||
@@ -101,8 +101,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|||||||
|
|
||||||
self._request_dispatcher = TypeBasedDispatcher(
|
self._request_dispatcher = TypeBasedDispatcher(
|
||||||
[
|
[
|
||||||
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
(BatchEmbeddingOutput, self.handle_batch_embedding_out),
|
||||||
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
(BatchTokenIDOutput, self.handle_batch_token_id_out),
|
||||||
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
||||||
(MultiTokenizerRegisterReq, lambda x: x),
|
(MultiTokenizerRegisterReq, lambda x: x),
|
||||||
(FreezeGCReq, self.handle_freeze_gc_req),
|
(FreezeGCReq, self.handle_freeze_gc_req),
|
||||||
@@ -145,11 +145,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|||||||
return output[:-1]
|
return output[:-1]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
|
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOutput):
|
||||||
# If it is embedding model, no detokenization is needed.
|
# If it is embedding model, no detokenization is needed.
|
||||||
return recv_obj
|
return recv_obj
|
||||||
|
|
||||||
def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
|
def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
|
||||||
bs = len(recv_obj.rids)
|
bs = len(recv_obj.rids)
|
||||||
|
|
||||||
# Initialize decode status
|
# Initialize decode status
|
||||||
@@ -224,7 +224,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|||||||
s.sent_offset = len(output_str)
|
s.sent_offset = len(output_str)
|
||||||
output_strs.append(incremental_output)
|
output_strs.append(incremental_output)
|
||||||
|
|
||||||
return BatchStrOut(
|
return BatchStrOutput(
|
||||||
rids=recv_obj.rids,
|
rids=recv_obj.rids,
|
||||||
finished_reasons=recv_obj.finished_reasons,
|
finished_reasons=recv_obj.finished_reasons,
|
||||||
output_strs=output_strs,
|
output_strs=output_strs,
|
||||||
@@ -252,7 +252,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|||||||
|
|
||||||
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
||||||
outputs = self.tokenizer.detokenize(recv_obj)
|
outputs = self.tokenizer.detokenize(recv_obj)
|
||||||
return BatchMultimodalOut(
|
return BatchMultimodalOutput(
|
||||||
rids=recv_obj.rids,
|
rids=recv_obj.rids,
|
||||||
finished_reasons=recv_obj.finished_reasons,
|
finished_reasons=recv_obj.finished_reasons,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler).
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import uuid
|
import uuid
|
||||||
|
from abc import ABC
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
@@ -36,10 +37,32 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# Parameters for a session
|
# Parameters for a session
|
||||||
|
@dataclass
|
||||||
|
class BaseReq(ABC):
|
||||||
|
rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
|
||||||
|
|
||||||
|
def regenerate_rid(self):
|
||||||
|
"""Generate a new request ID and return it."""
|
||||||
|
if isinstance(self.rid, list):
|
||||||
|
self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
|
||||||
|
else:
|
||||||
|
self.rid = uuid.uuid4().hex
|
||||||
|
return self.rid
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseBatchReq(ABC):
|
||||||
|
rids: Optional[List[str]] = field(default=None, kw_only=True)
|
||||||
|
|
||||||
|
def regenerate_rids(self):
|
||||||
|
"""Generate new request IDs and return them."""
|
||||||
|
self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
|
||||||
|
return self.rids
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SessionParams:
|
class SessionParams:
|
||||||
id: Optional[str] = None
|
id: Optional[str] = None
|
||||||
rid: Optional[str] = None
|
|
||||||
offset: Optional[int] = None
|
offset: Optional[int] = None
|
||||||
replace: Optional[bool] = None
|
replace: Optional[bool] = None
|
||||||
drop_previous_output: Optional[bool] = None
|
drop_previous_output: Optional[bool] = None
|
||||||
@@ -63,7 +86,7 @@ MultimodalDataInputFormat = Union[
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenerateReqInput:
|
class GenerateReqInput(BaseReq):
|
||||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
text: Optional[Union[List[str], str]] = None
|
text: Optional[Union[List[str], str]] = None
|
||||||
# The token ids for text; one can specify either text or input_ids
|
# The token ids for text; one can specify either text or input_ids
|
||||||
@@ -83,8 +106,6 @@ class GenerateReqInput:
|
|||||||
audio_data: Optional[MultimodalDataInputFormat] = None
|
audio_data: Optional[MultimodalDataInputFormat] = None
|
||||||
# The sampling_params. See descriptions below.
|
# The sampling_params. See descriptions below.
|
||||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||||
# The request id.
|
|
||||||
rid: Optional[Union[List[str], str]] = None
|
|
||||||
# Whether to return logprobs.
|
# Whether to return logprobs.
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||||
# If return logprobs, the start location in the prompt for returning logprobs.
|
# If return logprobs, the start location in the prompt for returning logprobs.
|
||||||
@@ -491,11 +512,6 @@ class GenerateReqInput:
|
|||||||
):
|
):
|
||||||
raise ValueError("Session params must be a dict or a list of dicts.")
|
raise ValueError("Session params must be a dict or a list of dicts.")
|
||||||
|
|
||||||
def regenerate_rid(self):
|
|
||||||
"""Generate a new request ID and return it."""
|
|
||||||
self.rid = uuid.uuid4().hex
|
|
||||||
return self.rid
|
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
return GenerateReqInput(
|
return GenerateReqInput(
|
||||||
text=self.text[i] if self.text is not None else None,
|
text=self.text[i] if self.text is not None else None,
|
||||||
@@ -558,9 +574,7 @@ class GenerateReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TokenizedGenerateReqInput:
|
class TokenizedGenerateReqInput(BaseReq):
|
||||||
# The request id
|
|
||||||
rid: str
|
|
||||||
# The input text
|
# The input text
|
||||||
input_text: str
|
input_text: str
|
||||||
# The input token ids
|
# The input token ids
|
||||||
@@ -625,7 +639,7 @@ class TokenizedGenerateReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchTokenizedGenerateReqInput:
|
class BatchTokenizedGenerateReqInput(BaseBatchReq):
|
||||||
# The batch of tokenized requests
|
# The batch of tokenized requests
|
||||||
batch: List[TokenizedGenerateReqInput]
|
batch: List[TokenizedGenerateReqInput]
|
||||||
|
|
||||||
@@ -640,7 +654,7 @@ class BatchTokenizedGenerateReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingReqInput:
|
class EmbeddingReqInput(BaseReq):
|
||||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
text: Optional[Union[List[List[str]], List[str], str]] = None
|
text: Optional[Union[List[List[str]], List[str], str]] = None
|
||||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||||
@@ -656,8 +670,6 @@ class EmbeddingReqInput:
|
|||||||
audio_data: Optional[MultimodalDataInputFormat] = None
|
audio_data: Optional[MultimodalDataInputFormat] = None
|
||||||
# The token ids for text; one can either specify text or input_ids.
|
# The token ids for text; one can either specify text or input_ids.
|
||||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||||
# The request id.
|
|
||||||
rid: Optional[Union[List[str], str]] = None
|
|
||||||
# Dummy sampling params for compatibility
|
# Dummy sampling params for compatibility
|
||||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||||
# Dummy input embeds for compatibility
|
# Dummy input embeds for compatibility
|
||||||
@@ -728,10 +740,6 @@ class EmbeddingReqInput:
|
|||||||
for i in range(self.batch_size):
|
for i in range(self.batch_size):
|
||||||
self.sampling_params[i]["max_new_tokens"] = 0
|
self.sampling_params[i]["max_new_tokens"] = 0
|
||||||
|
|
||||||
def regenerate_rid(self):
|
|
||||||
self.rid = uuid.uuid4().hex
|
|
||||||
return self.rid
|
|
||||||
|
|
||||||
def contains_mm_input(self) -> bool:
|
def contains_mm_input(self) -> bool:
|
||||||
return (
|
return (
|
||||||
has_valid_data(self.image_data)
|
has_valid_data(self.image_data)
|
||||||
@@ -760,9 +768,7 @@ class EmbeddingReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TokenizedEmbeddingReqInput:
|
class TokenizedEmbeddingReqInput(BaseReq):
|
||||||
# The request id
|
|
||||||
rid: str
|
|
||||||
# The input text
|
# The input text
|
||||||
input_text: str
|
input_text: str
|
||||||
# The input token ids
|
# The input token ids
|
||||||
@@ -780,7 +786,7 @@ class TokenizedEmbeddingReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchTokenizedEmbeddingReqInput:
|
class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
|
||||||
# The batch of tokenized embedding requests
|
# The batch of tokenized embedding requests
|
||||||
batch: List[TokenizedEmbeddingReqInput]
|
batch: List[TokenizedEmbeddingReqInput]
|
||||||
|
|
||||||
@@ -795,9 +801,7 @@ class BatchTokenizedEmbeddingReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchTokenIDOut:
|
class BatchTokenIDOutput(BaseBatchReq):
|
||||||
# The request id
|
|
||||||
rids: List[str]
|
|
||||||
# The finish reason
|
# The finish reason
|
||||||
finished_reasons: List[BaseFinishReason]
|
finished_reasons: List[BaseFinishReason]
|
||||||
# For incremental decoding
|
# For incremental decoding
|
||||||
@@ -842,7 +846,7 @@ class BatchTokenIDOut:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchMultimodalDecodeReq:
|
class BatchMultimodalDecodeReq(BaseBatchReq):
|
||||||
decoded_ids: List[int]
|
decoded_ids: List[int]
|
||||||
input_token_logprobs_val: List[float]
|
input_token_logprobs_val: List[float]
|
||||||
input_token_logprobs_idx: List[int]
|
input_token_logprobs_idx: List[int]
|
||||||
@@ -854,8 +858,6 @@ class BatchMultimodalDecodeReq:
|
|||||||
image_resolutions: List[List[int]]
|
image_resolutions: List[List[int]]
|
||||||
resize_image_resolutions: List[List[int]]
|
resize_image_resolutions: List[List[int]]
|
||||||
|
|
||||||
# The request id
|
|
||||||
rids: List[str]
|
|
||||||
finished_reasons: List[BaseFinishReason]
|
finished_reasons: List[BaseFinishReason]
|
||||||
|
|
||||||
# Token counts
|
# Token counts
|
||||||
@@ -871,9 +873,7 @@ class BatchMultimodalDecodeReq:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchStrOut:
|
class BatchStrOutput(BaseBatchReq):
|
||||||
# The request id
|
|
||||||
rids: List[str]
|
|
||||||
# The finish reason
|
# The finish reason
|
||||||
finished_reasons: List[dict]
|
finished_reasons: List[dict]
|
||||||
# The output decoded strings
|
# The output decoded strings
|
||||||
@@ -909,9 +909,7 @@ class BatchStrOut:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchMultimodalOut:
|
class BatchMultimodalOutput(BaseBatchReq):
|
||||||
# The request id
|
|
||||||
rids: List[str]
|
|
||||||
# The finish reason
|
# The finish reason
|
||||||
finished_reasons: List[dict]
|
finished_reasons: List[dict]
|
||||||
decoded_ids: List[List[int]]
|
decoded_ids: List[List[int]]
|
||||||
@@ -936,9 +934,7 @@ class BatchMultimodalOut:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchEmbeddingOut:
|
class BatchEmbeddingOutput(BaseBatchReq):
|
||||||
# The request id
|
|
||||||
rids: List[str]
|
|
||||||
# The finish reason
|
# The finish reason
|
||||||
finished_reasons: List[BaseFinishReason]
|
finished_reasons: List[BaseFinishReason]
|
||||||
# The output embedding
|
# The output embedding
|
||||||
@@ -952,27 +948,27 @@ class BatchEmbeddingOut:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClearHiCacheReqInput:
|
class ClearHiCacheReqInput(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClearHiCacheReqOutput:
|
class ClearHiCacheReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlushCacheReqInput:
|
class FlushCacheReqInput(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlushCacheReqOutput:
|
class FlushCacheReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightFromDiskReqInput:
|
class UpdateWeightFromDiskReqInput(BaseReq):
|
||||||
# The model path with the new weights
|
# The model path with the new weights
|
||||||
model_path: str
|
model_path: str
|
||||||
# The format to load the weights
|
# The format to load the weights
|
||||||
@@ -990,7 +986,7 @@ class UpdateWeightFromDiskReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightFromDiskReqOutput:
|
class UpdateWeightFromDiskReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
# Number of paused requests during weight sync.
|
# Number of paused requests during weight sync.
|
||||||
@@ -998,7 +994,7 @@ class UpdateWeightFromDiskReqOutput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightsFromDistributedReqInput:
|
class UpdateWeightsFromDistributedReqInput(BaseReq):
|
||||||
names: List[str]
|
names: List[str]
|
||||||
dtypes: List[str]
|
dtypes: List[str]
|
||||||
shapes: List[List[int]]
|
shapes: List[List[int]]
|
||||||
@@ -1013,13 +1009,13 @@ class UpdateWeightsFromDistributedReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightsFromDistributedReqOutput:
|
class UpdateWeightsFromDistributedReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightsFromTensorReqInput:
|
class UpdateWeightsFromTensorReqInput(BaseReq):
|
||||||
"""Update model weights from tensor input.
|
"""Update model weights from tensor input.
|
||||||
|
|
||||||
- Tensors are serialized for transmission
|
- Tensors are serialized for transmission
|
||||||
@@ -1038,13 +1034,13 @@ class UpdateWeightsFromTensorReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightsFromTensorReqOutput:
|
class UpdateWeightsFromTensorReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InitWeightsSendGroupForRemoteInstanceReqInput:
|
class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
|
||||||
# The master address
|
# The master address
|
||||||
master_address: str
|
master_address: str
|
||||||
# The ports for each rank's communication group
|
# The ports for each rank's communication group
|
||||||
@@ -1060,13 +1056,13 @@ class InitWeightsSendGroupForRemoteInstanceReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InitWeightsSendGroupForRemoteInstanceReqOutput:
|
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SendWeightsToRemoteInstanceReqInput:
|
class SendWeightsToRemoteInstanceReqInput(BaseReq):
|
||||||
# The master address
|
# The master address
|
||||||
master_address: str
|
master_address: str
|
||||||
# The ports for each rank's communication group
|
# The ports for each rank's communication group
|
||||||
@@ -1076,13 +1072,13 @@ class SendWeightsToRemoteInstanceReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SendWeightsToRemoteInstanceReqOutput:
|
class SendWeightsToRemoteInstanceReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InitWeightsUpdateGroupReqInput:
|
class InitWeightsUpdateGroupReqInput(BaseReq):
|
||||||
# The master address
|
# The master address
|
||||||
master_address: str
|
master_address: str
|
||||||
# The master port
|
# The master port
|
||||||
@@ -1098,24 +1094,24 @@ class InitWeightsUpdateGroupReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InitWeightsUpdateGroupReqOutput:
|
class InitWeightsUpdateGroupReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DestroyWeightsUpdateGroupReqInput:
|
class DestroyWeightsUpdateGroupReqInput(BaseReq):
|
||||||
group_name: str = "weight_update_group"
|
group_name: str = "weight_update_group"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DestroyWeightsUpdateGroupReqOutput:
|
class DestroyWeightsUpdateGroupReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightVersionReqInput:
|
class UpdateWeightVersionReqInput(BaseReq):
|
||||||
# The new weight version
|
# The new weight version
|
||||||
new_version: str
|
new_version: str
|
||||||
# Whether to abort all running requests before updating
|
# Whether to abort all running requests before updating
|
||||||
@@ -1123,89 +1119,87 @@ class UpdateWeightVersionReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GetWeightsByNameReqInput:
|
class GetWeightsByNameReqInput(BaseReq):
|
||||||
name: str
|
name: str
|
||||||
truncate_size: int = 100
|
truncate_size: int = 100
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GetWeightsByNameReqOutput:
|
class GetWeightsByNameReqOutput(BaseReq):
|
||||||
parameter: list
|
parameter: list
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ReleaseMemoryOccupationReqInput:
|
class ReleaseMemoryOccupationReqInput(BaseReq):
|
||||||
# Optional tags to identify the memory region, which is primarily used for RL
|
# Optional tags to identify the memory region, which is primarily used for RL
|
||||||
# Currently we only support `weights` and `kv_cache`
|
# Currently we only support `weights` and `kv_cache`
|
||||||
tags: Optional[List[str]] = None
|
tags: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ReleaseMemoryOccupationReqOutput:
|
class ReleaseMemoryOccupationReqOutput(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ResumeMemoryOccupationReqInput:
|
class ResumeMemoryOccupationReqInput(BaseReq):
|
||||||
# Optional tags to identify the memory region, which is primarily used for RL
|
# Optional tags to identify the memory region, which is primarily used for RL
|
||||||
# Currently we only support `weights` and `kv_cache`
|
# Currently we only support `weights` and `kv_cache`
|
||||||
tags: Optional[List[str]] = None
|
tags: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ResumeMemoryOccupationReqOutput:
|
class ResumeMemoryOccupationReqOutput(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SlowDownReqInput:
|
class SlowDownReqInput(BaseReq):
|
||||||
forward_sleep_time: Optional[float]
|
forward_sleep_time: Optional[float]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SlowDownReqOutput:
|
class SlowDownReqOutput(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AbortReq:
|
class AbortReq(BaseReq):
|
||||||
# The request id
|
|
||||||
rid: str = ""
|
|
||||||
# Whether to abort all requests
|
# Whether to abort all requests
|
||||||
abort_all: bool = False
|
abort_all: bool = False
|
||||||
# The finished reason data
|
# The finished reason data
|
||||||
finished_reason: Optional[Dict[str, Any]] = None
|
finished_reason: Optional[Dict[str, Any]] = None
|
||||||
abort_reason: Optional[str] = None
|
abort_reason: Optional[str] = None
|
||||||
# used in MultiTokenzierManager mode
|
|
||||||
rids: Optional[Union[List[str], str]] = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.rids = self.rid
|
# FIXME: This is a hack to keep the same with the old code
|
||||||
|
if self.rid is None:
|
||||||
|
self.rid = ""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GetInternalStateReq:
|
class GetInternalStateReq(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GetInternalStateReqOutput:
|
class GetInternalStateReqOutput(BaseReq):
|
||||||
internal_state: Dict[Any, Any]
|
internal_state: Dict[Any, Any]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SetInternalStateReq:
|
class SetInternalStateReq(BaseReq):
|
||||||
server_args: Dict[str, Any]
|
server_args: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SetInternalStateReqOutput:
|
class SetInternalStateReqOutput(BaseReq):
|
||||||
updated: bool
|
updated: bool
|
||||||
server_args: Dict[str, Any]
|
server_args: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProfileReqInput:
|
class ProfileReqInput(BaseReq):
|
||||||
# The output directory
|
# The output directory
|
||||||
output_dir: Optional[str] = None
|
output_dir: Optional[str] = None
|
||||||
# If set, it profile as many as this number of steps.
|
# If set, it profile as many as this number of steps.
|
||||||
@@ -1225,7 +1219,7 @@ class ProfileReqType(Enum):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProfileReq:
|
class ProfileReq(BaseReq):
|
||||||
type: ProfileReqType
|
type: ProfileReqType
|
||||||
output_dir: Optional[str] = None
|
output_dir: Optional[str] = None
|
||||||
start_step: Optional[int] = None
|
start_step: Optional[int] = None
|
||||||
@@ -1238,18 +1232,18 @@ class ProfileReq:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProfileReqOutput:
|
class ProfileReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FreezeGCReq:
|
class FreezeGCReq(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConfigureLoggingReq:
|
class ConfigureLoggingReq(BaseReq):
|
||||||
log_requests: Optional[bool] = None
|
log_requests: Optional[bool] = None
|
||||||
log_requests_level: Optional[int] = None
|
log_requests_level: Optional[int] = None
|
||||||
dump_requests_folder: Optional[str] = None
|
dump_requests_folder: Optional[str] = None
|
||||||
@@ -1258,35 +1252,39 @@ class ConfigureLoggingReq:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OpenSessionReqInput:
|
class OpenSessionReqInput(BaseReq):
|
||||||
capacity_of_str_len: int
|
capacity_of_str_len: int
|
||||||
session_id: Optional[str] = None
|
session_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CloseSessionReqInput:
|
class CloseSessionReqInput(BaseReq):
|
||||||
session_id: str
|
session_id: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OpenSessionReqOutput:
|
class OpenSessionReqOutput(BaseReq):
|
||||||
session_id: Optional[str]
|
session_id: Optional[str]
|
||||||
success: bool
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HealthCheckOutput:
|
class HealthCheckOutput(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ExpertDistributionReq(Enum):
|
class ExpertDistributionReqType(Enum):
|
||||||
START_RECORD = 1
|
START_RECORD = 1
|
||||||
STOP_RECORD = 2
|
STOP_RECORD = 2
|
||||||
DUMP_RECORD = 3
|
DUMP_RECORD = 3
|
||||||
|
|
||||||
|
|
||||||
|
class ExpertDistributionReq(BaseReq):
|
||||||
|
action: ExpertDistributionReqType
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExpertDistributionReqOutput:
|
class ExpertDistributionReqOutput(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -1304,7 +1302,7 @@ class Tool:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParseFunctionCallReq:
|
class ParseFunctionCallReq(BaseReq):
|
||||||
text: str # The text to parse.
|
text: str # The text to parse.
|
||||||
tools: List[Tool] = field(
|
tools: List[Tool] = field(
|
||||||
default_factory=list
|
default_factory=list
|
||||||
@@ -1315,31 +1313,31 @@ class ParseFunctionCallReq:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SeparateReasoningReqInput:
|
class SeparateReasoningReqInput(BaseReq):
|
||||||
text: str # The text to parse.
|
text: str # The text to parse.
|
||||||
reasoning_parser: str # Specify the parser type, e.g., "deepseek-r1".
|
reasoning_parser: str # Specify the parser type, e.g., "deepseek-r1".
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VertexGenerateReqInput:
|
class VertexGenerateReqInput(BaseReq):
|
||||||
instances: List[dict]
|
instances: List[dict]
|
||||||
parameters: Optional[dict] = None
|
parameters: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RpcReqInput:
|
class RpcReqInput(BaseReq):
|
||||||
method: str
|
method: str
|
||||||
parameters: Optional[Dict] = None
|
parameters: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RpcReqOutput:
|
class RpcReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoadLoRAAdapterReqInput:
|
class LoadLoRAAdapterReqInput(BaseReq):
|
||||||
# The name of the lora module to newly loaded.
|
# The name of the lora module to newly loaded.
|
||||||
lora_name: str
|
lora_name: str
|
||||||
# The path of loading.
|
# The path of loading.
|
||||||
@@ -1359,7 +1357,7 @@ class LoadLoRAAdapterReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UnloadLoRAAdapterReqInput:
|
class UnloadLoRAAdapterReqInput(BaseReq):
|
||||||
# The name of lora module to unload.
|
# The name of lora module to unload.
|
||||||
lora_name: str
|
lora_name: str
|
||||||
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
|
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
|
||||||
@@ -1373,23 +1371,23 @@ class UnloadLoRAAdapterReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoRAUpdateResult:
|
class LoRAUpdateOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
error_message: Optional[str] = None
|
error_message: Optional[str] = None
|
||||||
loaded_adapters: Optional[Dict[str, LoRARef]] = None
|
loaded_adapters: Optional[Dict[str, LoRARef]] = None
|
||||||
|
|
||||||
|
|
||||||
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MultiTokenizerRegisterReq:
|
class MultiTokenizerRegisterReq(BaseBatchReq):
|
||||||
rids: Optional[Union[List[str], str]] = None
|
|
||||||
ipc_name: Optional[str] = None
|
ipc_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MultiTokenizerWrapper:
|
class MultiTokenizerWrapper:
|
||||||
|
# FIXME(lsyin): remove this
|
||||||
worker_id: int
|
worker_id: int
|
||||||
obj: Optional[Any] = None
|
obj: Optional[Any] = None
|
||||||
|
|
||||||
@@ -1400,17 +1398,17 @@ class BlockReqType(Enum):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BlockReqInput:
|
class BlockReqInput(BaseReq):
|
||||||
type: BlockReqType
|
type: BlockReqType
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GetLoadReqInput:
|
class GetLoadReqInput(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GetLoadReqOutput:
|
class GetLoadReqOutput(BaseReq):
|
||||||
dp_rank: int
|
dp_rank: int
|
||||||
num_reqs: int
|
num_reqs: int
|
||||||
num_waiting_reqs: int
|
num_waiting_reqs: int
|
||||||
@@ -1418,5 +1416,31 @@ class GetLoadReqOutput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WatchLoadUpdateReq:
|
class WatchLoadUpdateReq(BaseReq):
|
||||||
loads: List[GetLoadReqOutput]
|
loads: List[GetLoadReqOutput]
|
||||||
|
|
||||||
|
|
||||||
|
def _check_all_req_types():
|
||||||
|
"""A helper function to check all request types are defined in this file."""
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
|
||||||
|
all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
|
||||||
|
for class_type in all_classes:
|
||||||
|
# check its name
|
||||||
|
name = class_type[0]
|
||||||
|
is_io_struct = (
|
||||||
|
name.endswith("Req") or name.endswith("Input") or name.endswith("Output")
|
||||||
|
)
|
||||||
|
is_base_req = issubclass(class_type[1], BaseReq) or issubclass(
|
||||||
|
class_type[1], BaseBatchReq
|
||||||
|
)
|
||||||
|
if is_io_struct and not is_base_req:
|
||||||
|
raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
|
||||||
|
if is_base_req and not is_io_struct:
|
||||||
|
raise ValueError(
|
||||||
|
f"{name} is a subclass of BaseReq but not follow the naming convention."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_check_all_req_types()
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
|
"""Mixin class and utils for multi-http-worker mode"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as multiprocessing
|
import multiprocessing as multiprocessing
|
||||||
@@ -30,10 +30,10 @@ import zmq.asyncio
|
|||||||
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
|
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
|
||||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOutput,
|
||||||
BatchMultimodalOut,
|
BatchMultimodalOutput,
|
||||||
BatchStrOut,
|
BatchStrOutput,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOutput,
|
||||||
MultiTokenizerRegisterReq,
|
MultiTokenizerRegisterReq,
|
||||||
MultiTokenizerWrapper,
|
MultiTokenizerWrapper,
|
||||||
)
|
)
|
||||||
@@ -83,8 +83,8 @@ class SocketMapping:
|
|||||||
|
|
||||||
def _handle_output_by_index(output, i):
|
def _handle_output_by_index(output, i):
|
||||||
"""NOTE: A maintainable method is better here."""
|
"""NOTE: A maintainable method is better here."""
|
||||||
if isinstance(output, BatchTokenIDOut):
|
if isinstance(output, BatchTokenIDOutput):
|
||||||
new_output = BatchTokenIDOut(
|
new_output = BatchTokenIDOutput(
|
||||||
rids=[output.rids[i]],
|
rids=[output.rids[i]],
|
||||||
finished_reasons=(
|
finished_reasons=(
|
||||||
[output.finished_reasons[i]]
|
[output.finished_reasons[i]]
|
||||||
@@ -198,8 +198,8 @@ def _handle_output_by_index(output, i):
|
|||||||
placeholder_tokens_idx=None,
|
placeholder_tokens_idx=None,
|
||||||
placeholder_tokens_val=None,
|
placeholder_tokens_val=None,
|
||||||
)
|
)
|
||||||
elif isinstance(output, BatchEmbeddingOut):
|
elif isinstance(output, BatchEmbeddingOutput):
|
||||||
new_output = BatchEmbeddingOut(
|
new_output = BatchEmbeddingOutput(
|
||||||
rids=[output.rids[i]],
|
rids=[output.rids[i]],
|
||||||
finished_reasons=(
|
finished_reasons=(
|
||||||
[output.finished_reasons[i]]
|
[output.finished_reasons[i]]
|
||||||
@@ -216,8 +216,8 @@ def _handle_output_by_index(output, i):
|
|||||||
placeholder_tokens_idx=None,
|
placeholder_tokens_idx=None,
|
||||||
placeholder_tokens_val=None,
|
placeholder_tokens_val=None,
|
||||||
)
|
)
|
||||||
elif isinstance(output, BatchStrOut):
|
elif isinstance(output, BatchStrOutput):
|
||||||
new_output = BatchStrOut(
|
new_output = BatchStrOutput(
|
||||||
rids=[output.rids[i]],
|
rids=[output.rids[i]],
|
||||||
finished_reasons=(
|
finished_reasons=(
|
||||||
[output.finished_reasons[i]]
|
[output.finished_reasons[i]]
|
||||||
@@ -314,8 +314,8 @@ def _handle_output_by_index(output, i):
|
|||||||
placeholder_tokens_idx=None,
|
placeholder_tokens_idx=None,
|
||||||
placeholder_tokens_val=None,
|
placeholder_tokens_val=None,
|
||||||
)
|
)
|
||||||
elif isinstance(output, BatchMultimodalOut):
|
elif isinstance(output, BatchMultimodalOutput):
|
||||||
new_output = BatchMultimodalOut(
|
new_output = BatchMultimodalOutput(
|
||||||
rids=[output.rids[i]],
|
rids=[output.rids[i]],
|
||||||
finished_reasons=(
|
finished_reasons=(
|
||||||
[output.finished_reasons[i]]
|
[output.finished_reasons[i]]
|
||||||
@@ -343,7 +343,7 @@ def _handle_output_by_index(output, i):
|
|||||||
|
|
||||||
|
|
||||||
class MultiHttpWorkerDetokenizerMixin:
|
class MultiHttpWorkerDetokenizerMixin:
|
||||||
"""Mixin class for MultiTokenizerManager and DetokenizerManager"""
|
"""Mixin class for DetokenizerManager"""
|
||||||
|
|
||||||
def get_worker_ids_from_req_rids(self, rids):
|
def get_worker_ids_from_req_rids(self, rids):
|
||||||
if isinstance(rids, list):
|
if isinstance(rids, list):
|
||||||
@@ -386,7 +386,7 @@ class MultiHttpWorkerDetokenizerMixin:
|
|||||||
|
|
||||||
|
|
||||||
class MultiTokenizerRouter:
|
class MultiTokenizerRouter:
|
||||||
"""A router to receive requests from MultiTokenizerManager"""
|
"""A router to receive requests from TokenizerWorker"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -454,8 +454,8 @@ class MultiTokenizerRouter:
|
|||||||
self.socket_mapping.send_output(worker_id, new_recv_obj)
|
self.socket_mapping.send_output(worker_id, new_recv_obj)
|
||||||
|
|
||||||
|
|
||||||
class MultiTokenizerManager(TokenizerManager):
|
class TokenizerWorker(TokenizerManager):
|
||||||
"""Multi Process Tokenizer Manager that tokenizes the text."""
|
"""Tokenizer Worker in multi-http-worker mode"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
DestroyWeightsUpdateGroupReqInput,
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
ExpertDistributionReq,
|
ExpertDistributionReq,
|
||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
|
ExpertDistributionReqType,
|
||||||
FlushCacheReqInput,
|
FlushCacheReqInput,
|
||||||
FlushCacheReqOutput,
|
FlushCacheReqOutput,
|
||||||
FreezeGCReq,
|
FreezeGCReq,
|
||||||
@@ -1487,12 +1488,12 @@ class Scheduler(
|
|||||||
req.priority = -sys.maxsize - 1
|
req.priority = -sys.maxsize - 1
|
||||||
elif not self.enable_priority_scheduling and req.priority is not None:
|
elif not self.enable_priority_scheduling and req.priority is not None:
|
||||||
abort_req = AbortReq(
|
abort_req = AbortReq(
|
||||||
req.rid,
|
|
||||||
finished_reason={
|
finished_reason={
|
||||||
"type": "abort",
|
"type": "abort",
|
||||||
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
||||||
"message": "Using priority is disabled for this server. Please send a new request without a priority.",
|
"message": "Using priority is disabled for this server. Please send a new request without a priority.",
|
||||||
},
|
},
|
||||||
|
rid=req.rid,
|
||||||
)
|
)
|
||||||
self.send_to_tokenizer.send_pyobj(abort_req)
|
self.send_to_tokenizer.send_pyobj(abort_req)
|
||||||
|
|
||||||
@@ -1528,12 +1529,12 @@ class Scheduler(
|
|||||||
|
|
||||||
self.send_to_tokenizer.send_pyobj(
|
self.send_to_tokenizer.send_pyobj(
|
||||||
AbortReq(
|
AbortReq(
|
||||||
req_to_abort.rid,
|
|
||||||
finished_reason={
|
finished_reason={
|
||||||
"type": "abort",
|
"type": "abort",
|
||||||
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
||||||
"message": message,
|
"message": message,
|
||||||
},
|
},
|
||||||
|
rid=req_to_abort.rid,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return req_to_abort.rid == recv_req.rid
|
return req_to_abort.rid == recv_req.rid
|
||||||
@@ -2005,7 +2006,7 @@ class Scheduler(
|
|||||||
self.new_token_ratio = new_token_ratio
|
self.new_token_ratio = new_token_ratio
|
||||||
for req in reqs_to_abort:
|
for req in reqs_to_abort:
|
||||||
self.send_to_tokenizer.send_pyobj(
|
self.send_to_tokenizer.send_pyobj(
|
||||||
AbortReq(req.rid, abort_reason=req.to_abort_message)
|
AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -2575,7 +2576,7 @@ class Scheduler(
|
|||||||
if self.enable_hicache_storage:
|
if self.enable_hicache_storage:
|
||||||
# to release prefetch events associated with the request
|
# to release prefetch events associated with the request
|
||||||
self.tree_cache.release_aborted_request(req.rid)
|
self.tree_cache.release_aborted_request(req.rid)
|
||||||
self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
|
self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
|
||||||
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
|
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
self.tree_cache.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
@@ -2687,11 +2688,12 @@ class Scheduler(
|
|||||||
return SlowDownReqOutput()
|
return SlowDownReqOutput()
|
||||||
|
|
||||||
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
|
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
|
||||||
if recv_req == ExpertDistributionReq.START_RECORD:
|
action = recv_req.action
|
||||||
|
if action == ExpertDistributionReqType.START_RECORD:
|
||||||
get_global_expert_distribution_recorder().start_record()
|
get_global_expert_distribution_recorder().start_record()
|
||||||
elif recv_req == ExpertDistributionReq.STOP_RECORD:
|
elif action == ExpertDistributionReqType.STOP_RECORD:
|
||||||
get_global_expert_distribution_recorder().stop_record()
|
get_global_expert_distribution_recorder().stop_record()
|
||||||
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
|
elif action == ExpertDistributionReqType.DUMP_RECORD:
|
||||||
get_global_expert_distribution_recorder().dump_record()
|
get_global_expert_distribution_recorder().dump_record()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
|
raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
|
||||||
@@ -2774,7 +2776,8 @@ class IdleSleeper:
|
|||||||
|
|
||||||
|
|
||||||
def is_health_check_generate_req(recv_req):
|
def is_health_check_generate_req(recv_req):
|
||||||
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
|
rid = getattr(recv_req, "rid", None)
|
||||||
|
return rid is not None and rid.startswith("HEALTH_CHECK")
|
||||||
|
|
||||||
|
|
||||||
def is_work_request(recv_req):
|
def is_work_request(recv_req):
|
||||||
|
|||||||
@@ -9,7 +9,11 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
|
from sglang.srt.managers.io_struct import (
|
||||||
|
AbortReq,
|
||||||
|
BatchEmbeddingOutput,
|
||||||
|
BatchTokenIDOutput,
|
||||||
|
)
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -140,7 +144,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
|
f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
|
||||||
)
|
)
|
||||||
self.abort_request(AbortReq(req.rid))
|
self.abort_request(AbortReq(rid=req.rid))
|
||||||
req.grammar.finished = req.finished()
|
req.grammar.finished = req.finished()
|
||||||
else:
|
else:
|
||||||
# being chunked reqs' prefill is not finished
|
# being chunked reqs' prefill is not finished
|
||||||
@@ -292,7 +296,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
|
f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
|
||||||
)
|
)
|
||||||
self.abort_request(AbortReq(req.rid))
|
self.abort_request(AbortReq(rid=req.rid))
|
||||||
req.grammar.finished = req.finished()
|
req.grammar.finished = req.finished()
|
||||||
|
|
||||||
self.set_next_batch_sampling_info_done(batch)
|
self.set_next_batch_sampling_info_done(batch)
|
||||||
@@ -714,8 +718,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.send_to_detokenizer.send_pyobj(
|
self.send_to_detokenizer.send_pyobj(
|
||||||
BatchTokenIDOut(
|
BatchTokenIDOutput(
|
||||||
rids,
|
|
||||||
finished_reasons,
|
finished_reasons,
|
||||||
decoded_texts,
|
decoded_texts,
|
||||||
decode_ids_list,
|
decode_ids_list,
|
||||||
@@ -741,6 +744,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
output_token_ids_logprobs_val,
|
output_token_ids_logprobs_val,
|
||||||
output_token_ids_logprobs_idx,
|
output_token_ids_logprobs_idx,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
|
rids=rids,
|
||||||
placeholder_tokens_idx=None,
|
placeholder_tokens_idx=None,
|
||||||
placeholder_tokens_val=None,
|
placeholder_tokens_val=None,
|
||||||
)
|
)
|
||||||
@@ -761,12 +765,12 @@ class SchedulerOutputProcessorMixin:
|
|||||||
prompt_tokens.append(len(req.origin_input_ids))
|
prompt_tokens.append(len(req.origin_input_ids))
|
||||||
cached_tokens.append(req.cached_tokens)
|
cached_tokens.append(req.cached_tokens)
|
||||||
self.send_to_detokenizer.send_pyobj(
|
self.send_to_detokenizer.send_pyobj(
|
||||||
BatchEmbeddingOut(
|
BatchEmbeddingOutput(
|
||||||
rids,
|
|
||||||
finished_reasons,
|
finished_reasons,
|
||||||
embeddings,
|
embeddings,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
cached_tokens,
|
cached_tokens,
|
||||||
|
rids=rids,
|
||||||
placeholder_tokens_idx=None,
|
placeholder_tokens_idx=None,
|
||||||
placeholder_tokens_val=None,
|
placeholder_tokens_val=None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
DestroyWeightsUpdateGroupReqOutput,
|
DestroyWeightsUpdateGroupReqOutput,
|
||||||
ExpertDistributionReq,
|
ExpertDistributionReq,
|
||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
|
ExpertDistributionReqType,
|
||||||
FlushCacheReqInput,
|
FlushCacheReqInput,
|
||||||
FlushCacheReqOutput,
|
FlushCacheReqOutput,
|
||||||
GetInternalStateReq,
|
GetInternalStateReq,
|
||||||
@@ -44,7 +45,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
InitWeightsUpdateGroupReqOutput,
|
InitWeightsUpdateGroupReqOutput,
|
||||||
LoadLoRAAdapterReqInput,
|
LoadLoRAAdapterReqInput,
|
||||||
LoadLoRAAdapterReqOutput,
|
LoadLoRAAdapterReqOutput,
|
||||||
LoRAUpdateResult,
|
LoRAUpdateOutput,
|
||||||
MultiTokenizerWrapper,
|
MultiTokenizerWrapper,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
@@ -276,7 +277,7 @@ class TokenizerCommunicatorMixin:
|
|||||||
self.expert_distribution_communicator.handle_recv,
|
self.expert_distribution_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
LoRAUpdateResult,
|
LoRAUpdateOutput,
|
||||||
self.update_lora_adapter_communicator.handle_recv,
|
self.update_lora_adapter_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
@@ -335,15 +336,18 @@ class TokenizerCommunicatorMixin:
|
|||||||
|
|
||||||
async def start_expert_distribution_record(self: TokenizerManager):
|
async def start_expert_distribution_record(self: TokenizerManager):
|
||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
|
req = ExpertDistributionReq(action=ExpertDistributionReqType.START_RECORD)
|
||||||
|
await self.expert_distribution_communicator(req)
|
||||||
|
|
||||||
async def stop_expert_distribution_record(self: TokenizerManager):
|
async def stop_expert_distribution_record(self: TokenizerManager):
|
||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
|
req = ExpertDistributionReq(action=ExpertDistributionReqType.STOP_RECORD)
|
||||||
|
await self.expert_distribution_communicator(req)
|
||||||
|
|
||||||
async def dump_expert_distribution_record(self: TokenizerManager):
|
async def dump_expert_distribution_record(self: TokenizerManager):
|
||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
req = ExpertDistributionReq(action=ExpertDistributionReqType.DUMP_RECORD)
|
||||||
|
await self.expert_distribution_communicator(req)
|
||||||
|
|
||||||
async def init_weights_update_group(
|
async def init_weights_update_group(
|
||||||
self: TokenizerManager,
|
self: TokenizerManager,
|
||||||
|
|||||||
@@ -48,18 +48,17 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
get_tokenizer,
|
get_tokenizer,
|
||||||
get_tokenizer_from_processor,
|
get_tokenizer_from_processor,
|
||||||
)
|
)
|
||||||
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
|
from sglang.srt.lora.lora_registry import LoRARegistry
|
||||||
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
|
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
|
||||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOutput,
|
||||||
BatchMultimodalOut,
|
BatchMultimodalOutput,
|
||||||
BatchStrOut,
|
BatchStrOutput,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOutput,
|
||||||
BatchTokenizedEmbeddingReqInput,
|
BatchTokenizedEmbeddingReqInput,
|
||||||
BatchTokenizedGenerateReqInput,
|
BatchTokenizedGenerateReqInput,
|
||||||
CloseSessionReqInput,
|
|
||||||
ConfigureLoggingReq,
|
ConfigureLoggingReq,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
FreezeGCReq,
|
FreezeGCReq,
|
||||||
@@ -67,7 +66,6 @@ from sglang.srt.managers.io_struct import (
|
|||||||
GetLoadReqInput,
|
GetLoadReqInput,
|
||||||
HealthCheckOutput,
|
HealthCheckOutput,
|
||||||
MultiTokenizerWrapper,
|
MultiTokenizerWrapper,
|
||||||
OpenSessionReqInput,
|
|
||||||
OpenSessionReqOutput,
|
OpenSessionReqOutput,
|
||||||
SessionParams,
|
SessionParams,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
@@ -341,10 +339,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
[
|
[
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
BatchStrOut,
|
BatchStrOutput,
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOutput,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOutput,
|
||||||
BatchMultimodalOut,
|
BatchMultimodalOutput,
|
||||||
),
|
),
|
||||||
self._handle_batch_output,
|
self._handle_batch_output,
|
||||||
),
|
),
|
||||||
@@ -716,7 +714,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
tokenized_obj = TokenizedGenerateReqInput(
|
||||||
obj.rid,
|
|
||||||
input_text,
|
input_text,
|
||||||
input_ids,
|
input_ids,
|
||||||
mm_inputs,
|
mm_inputs,
|
||||||
@@ -726,6 +723,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
obj.top_logprobs_num,
|
obj.top_logprobs_num,
|
||||||
obj.token_ids_logprob,
|
obj.token_ids_logprob,
|
||||||
obj.stream,
|
obj.stream,
|
||||||
|
rid=obj.rid,
|
||||||
bootstrap_host=obj.bootstrap_host,
|
bootstrap_host=obj.bootstrap_host,
|
||||||
bootstrap_port=obj.bootstrap_port,
|
bootstrap_port=obj.bootstrap_port,
|
||||||
bootstrap_room=obj.bootstrap_room,
|
bootstrap_room=obj.bootstrap_room,
|
||||||
@@ -740,12 +738,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
)
|
)
|
||||||
elif isinstance(obj, EmbeddingReqInput):
|
elif isinstance(obj, EmbeddingReqInput):
|
||||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||||
obj.rid,
|
|
||||||
input_text,
|
input_text,
|
||||||
input_ids,
|
input_ids,
|
||||||
mm_inputs,
|
mm_inputs,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
|
rid=obj.rid,
|
||||||
priority=obj.priority,
|
priority=obj.priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1038,7 +1036,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
def abort_request(self, rid: str = "", abort_all: bool = False):
|
def abort_request(self, rid: str = "", abort_all: bool = False):
|
||||||
if not abort_all and rid not in self.rid_to_state:
|
if not abort_all and rid not in self.rid_to_state:
|
||||||
return
|
return
|
||||||
req = AbortReq(rid, abort_all)
|
req = AbortReq(rid=rid, abort_all=abort_all)
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
# TODO: also use custom_labels from the request
|
# TODO: also use custom_labels from the request
|
||||||
@@ -1303,7 +1301,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
def _handle_batch_output(
|
def _handle_batch_output(
|
||||||
self,
|
self,
|
||||||
recv_obj: Union[
|
recv_obj: Union[
|
||||||
BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
|
BatchStrOutput,
|
||||||
|
BatchEmbeddingOutput,
|
||||||
|
BatchMultimodalOutput,
|
||||||
|
BatchTokenIDOutput,
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
for i, rid in enumerate(recv_obj.rids):
|
for i, rid in enumerate(recv_obj.rids):
|
||||||
@@ -1337,7 +1338,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
i,
|
i,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(recv_obj, BatchEmbeddingOut):
|
if not isinstance(recv_obj, BatchEmbeddingOutput):
|
||||||
meta_info.update(
|
meta_info.update(
|
||||||
{
|
{
|
||||||
"completion_tokens": recv_obj.completion_tokens[i],
|
"completion_tokens": recv_obj.completion_tokens[i],
|
||||||
@@ -1348,7 +1349,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
if getattr(recv_obj, "output_hidden_states", None):
|
if getattr(recv_obj, "output_hidden_states", None):
|
||||||
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
|
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
|
||||||
|
|
||||||
if isinstance(recv_obj, BatchStrOut):
|
if isinstance(recv_obj, BatchStrOutput):
|
||||||
state.text += recv_obj.output_strs[i]
|
state.text += recv_obj.output_strs[i]
|
||||||
if state.obj.stream:
|
if state.obj.stream:
|
||||||
state.output_ids.extend(recv_obj.output_ids[i])
|
state.output_ids.extend(recv_obj.output_ids[i])
|
||||||
@@ -1363,7 +1364,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
"output_ids": output_token_ids,
|
"output_ids": output_token_ids,
|
||||||
"meta_info": meta_info,
|
"meta_info": meta_info,
|
||||||
}
|
}
|
||||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
elif isinstance(recv_obj, BatchTokenIDOutput):
|
||||||
if self.server_args.stream_output and state.obj.stream:
|
if self.server_args.stream_output and state.obj.stream:
|
||||||
state.output_ids.extend(recv_obj.output_ids[i])
|
state.output_ids.extend(recv_obj.output_ids[i])
|
||||||
output_token_ids = state.output_ids[state.last_output_offset :]
|
output_token_ids = state.output_ids[state.last_output_offset :]
|
||||||
@@ -1376,10 +1377,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
"output_ids": output_token_ids,
|
"output_ids": output_token_ids,
|
||||||
"meta_info": meta_info,
|
"meta_info": meta_info,
|
||||||
}
|
}
|
||||||
elif isinstance(recv_obj, BatchMultimodalOut):
|
elif isinstance(recv_obj, BatchMultimodalOutput):
|
||||||
raise NotImplementedError("BatchMultimodalOut not implemented")
|
raise NotImplementedError("BatchMultimodalOut not implemented")
|
||||||
else:
|
else:
|
||||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
assert isinstance(recv_obj, BatchEmbeddingOutput)
|
||||||
out_dict = {
|
out_dict = {
|
||||||
"embedding": recv_obj.embeddings[i],
|
"embedding": recv_obj.embeddings[i],
|
||||||
"meta_info": meta_info,
|
"meta_info": meta_info,
|
||||||
@@ -1418,7 +1419,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
top_logprobs_num: int,
|
top_logprobs_num: int,
|
||||||
token_ids_logprob: List[int],
|
token_ids_logprob: List[int],
|
||||||
return_text_in_logprobs: bool,
|
return_text_in_logprobs: bool,
|
||||||
recv_obj: BatchStrOut,
|
recv_obj: BatchStrOutput,
|
||||||
recv_obj_index: int,
|
recv_obj_index: int,
|
||||||
):
|
):
|
||||||
if recv_obj.input_token_logprobs_val is None:
|
if recv_obj.input_token_logprobs_val is None:
|
||||||
@@ -1536,7 +1537,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
ret.append(None)
|
ret.append(None)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
|
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
|
||||||
completion_tokens = (
|
completion_tokens = (
|
||||||
recv_obj.completion_tokens[i]
|
recv_obj.completion_tokens[i]
|
||||||
if getattr(recv_obj, "completion_tokens", None)
|
if getattr(recv_obj, "completion_tokens", None)
|
||||||
@@ -1632,7 +1633,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
|
|
||||||
asyncio.create_task(asyncio.to_thread(background_task))
|
asyncio.create_task(asyncio.to_thread(background_task))
|
||||||
|
|
||||||
def _handle_abort_req(self, recv_obj):
|
def _handle_abort_req(self, recv_obj: AbortReq):
|
||||||
if is_health_check_generate_req(recv_obj):
|
if is_health_check_generate_req(recv_obj):
|
||||||
return
|
return
|
||||||
state = self.rid_to_state[recv_obj.rid]
|
state = self.rid_to_state[recv_obj.rid]
|
||||||
|
|||||||
Reference in New Issue
Block a user