Revert "Support Multi Process Tokenizer Manager" (#8960)
@@ -18,18 +18,14 @@ This file implements HTTP APIs for the inference engine via fastapi.
 """
 
 import asyncio
-import ctypes
 import dataclasses
 import json
 import logging
 import multiprocessing as multiprocessing
 import os
-import sys
-import tempfile
 import threading
 import time
 from http import HTTPStatus
-from multiprocessing import Lock, Manager, Value, shared_memory
 from typing import AsyncIterator, Callable, Dict, Optional
 
 # Fix a bug of Python threading
@@ -98,7 +94,7 @@ from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
 from sglang.srt.reasoning_parser import ReasoningParser
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
     add_prometheus_middleware,
@@ -133,165 +129,8 @@ def set_global_state(global_state: _GlobalState):
     _global_state = global_state
 
 
-def serialize_port_args(port_args: PortArgs) -> dict:
-    """Serialize PortArgs into a shareable dictionary"""
-    return {
-        "tokenizer_ipc_name": port_args.tokenizer_ipc_name,
-        "scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
-        "detokenizer_ipc_name": port_args.detokenizer_ipc_name,
-        "nccl_port": port_args.nccl_port,
-        "rpc_ipc_name": port_args.rpc_ipc_name,
-        "metrics_ipc_name": port_args.metrics_ipc_name,
-        "tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
-    }
-
-
-def deserialize_port_args(data: dict) -> PortArgs:
-    """Deserialize PortArgs from a shared dictionary"""
-    return PortArgs(**data)
-
-
-def serialize_server_args(server_args: ServerArgs) -> dict:
-    """Serialize ServerArgs into a shareable dictionary"""
-    return dataclasses.asdict(server_args)
-
-
-def deserialize_server_args(data: dict) -> ServerArgs:
-    """Deserialize ServerArgs from a shared dictionary"""
-    return ServerArgs(**data)
-
-
-def serialize_scheduler_info(scheduler_info: Dict) -> dict:
-    """Serialize scheduler_info into a shareable dictionary"""
-    return scheduler_info
-
-
-def deserialize_scheduler_info(data: dict) -> Dict:
-    """Deserialize scheduler_info from a shared dictionary"""
-    return data
-
-
-def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
-    """Write data to shared memory"""
-    serialized = json.dumps(data).encode("utf-8")
-    size = len(serialized)
-    try:
-        # Try to open existing shared memory
-        shm = shared_memory.SharedMemory(name=name)
-        # If size is insufficient, close and recreate
-        if shm.size < size:
-            shm.close()
-            shm.unlink()
-            shm = shared_memory.SharedMemory(create=True, size=size, name=name)
-    except FileNotFoundError:
-        # If not present, create new shared memory
-        shm = shared_memory.SharedMemory(create=True, size=size, name=name)
-
-    shm.buf[:size] = serialized
-    return shm
-
-
-def read_from_shared_memory(name: str) -> dict:
-    """Read data from shared memory"""
-    try:
-        shm = shared_memory.SharedMemory(name=name)
-        data = json.loads(bytes(shm.buf).decode("utf-8"))
-        shm.close()
-        return data
-    except FileNotFoundError:
-        raise FileNotFoundError(f"Shared memory {name} not found")
-
-
-def get_main_process_id() -> int:
-    """Get the main process ID"""
-    return multiprocessing.current_process()._parent_pid
-
-
-def write_data_for_multi_tokenizer(
-    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
-):
-    """Write args information to share memory for multi-tokenizer"""
-    # get main process ID
-    main_pid = get_main_process_id()
-    current_pid = os.getpid()
-    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
-
-    # Write port_args to shared memory
-    port_args_shm = write_to_shared_memory(
-        serialize_port_args(port_args), f"port_args_{current_pid}"
-    )
-    # Write server_args to shared memory
-    server_args_shm = write_to_shared_memory(
-        serialize_server_args(server_args), f"server_args_{current_pid}"
-    )
-    # Write scheduler_info to shared memory
-    scheduler_info_shm = write_to_shared_memory(
-        serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
-    )
-
-    port_args_shm.close()
-    server_args_shm.close()
-    scheduler_info_shm.close()
-
-    return port_args_shm, server_args_shm, scheduler_info_shm
-
-
-def init_multi_tokenizer() -> ServerArgs:
-    """Read args information from shm and init tokenizer manager for current process"""
-    pid = os.getpid()
-    main_pid = get_main_process_id()
-    logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
-
-    # Read port_args, server_args, and scheduler_info from shared memory
-    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
-    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
-    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
-    port_args = deserialize_port_args(port_args_data)
-    server_args = deserialize_server_args(server_args_data)
-    scheduler_info = deserialize_scheduler_info(scheduler_info_data)
-
-    port_args.tokenizer_ipc_name = (
-        f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
-    )
-
-    # Launch tokenizer process
-    tokenizer_manager = TokenizerManager(server_args, port_args, False)
-    template_manager = TemplateManager()
-    template_manager.initialize_templates(
-        tokenizer_manager=tokenizer_manager,
-        model_path=server_args.model_path,
-        chat_template=server_args.chat_template,
-        completion_template=server_args.completion_template,
-    )
-    # register multi tokenizer
-    tokenizer_manager.register_to_main_tokenizer_manager()
-
-    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
-    set_global_state(
-        _GlobalState(
-            tokenizer_manager=tokenizer_manager,
-            template_manager=template_manager,
-            scheduler_info=scheduler_info,
-        )
-    )
-    return server_args
-
-
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    server_args = getattr(fast_api_app, "server_args", None)
-    if server_args is None:
-        # for multi-tokenizer
-        fast_api_app.server_args = init_multi_tokenizer()
-        fast_api_app.warmup_thread = threading.Thread(
-            target=_wait_and_warmup,
-            args=(
-                fast_api_app.server_args,
-                None,  # pipe_finish_writer not needed in worker
-                None,  # launch_callback not needed in worker
-            ),
-        )
-
     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
         _global_state.tokenizer_manager, _global_state.template_manager
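The helpers deleted above are what carried `PortArgs`, `ServerArgs`, and `scheduler_info` from the main process to the extra tokenizer workers: the writer serialized the dataclasses to JSON, stored the bytes in a named `multiprocessing.shared_memory` block keyed by the main process pid, and each worker attached to the same name and rebuilt the objects. A minimal, self-contained sketch of that round-trip (illustrative names only, not part of this commit) looks like this:

```python
import json
from multiprocessing import shared_memory


def write_json_to_shm(data: dict, name: str) -> shared_memory.SharedMemory:
    """Create (or re-create) a named shared-memory block holding JSON-encoded data."""
    payload = json.dumps(data).encode("utf-8")
    try:
        shm = shared_memory.SharedMemory(name=name)
        if shm.size < len(payload):
            shm.close()
            shm.unlink()
            shm = shared_memory.SharedMemory(create=True, size=len(payload), name=name)
    except FileNotFoundError:
        shm = shared_memory.SharedMemory(create=True, size=len(payload), name=name)
    shm.buf[: len(payload)] = payload
    return shm


def read_json_from_shm(name: str) -> dict:
    """Attach to a named block and decode the JSON stored in it.

    Like the removed helper, this assumes the block size matches the payload
    (true on Linux, where attach reports the exact truncated file size).
    """
    shm = shared_memory.SharedMemory(name=name)
    data = json.loads(bytes(shm.buf).decode("utf-8"))
    shm.close()
    return data


if __name__ == "__main__":
    # Writer side (main process): publish args under a well-known name.
    block = write_json_to_shm({"tokenizer_ipc_name": "ipc:///tmp/example"}, "demo_args")
    # Reader side (worker process): attach by the same name and deserialize.
    print(read_json_from_shm("demo_args"))
    block.close()
    block.unlink()
```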
@@ -352,15 +191,7 @@ async def lifespan(fast_api_app: FastAPI):
     warmup_thread = getattr(fast_api_app, "warmup_thread", None)
     if warmup_thread is not None:
         warmup_thread.start()
-
-    try:
-        yield
-    finally:
-        if server_args.tokenizer_worker_num > 1:
-            pid = os.getpid()
-            logger.info(f"uvicorn worker {pid} ending...")
-            warmup_thread.join()
-            logger.info(f"uvicorn {pid} ended")
+    yield
 
 
 # Fast API
@@ -377,30 +208,6 @@ app.add_middleware(
 )
 
 
-# Function to setup all middlewares for multi-process compatibility
-def setup_middlewares():
-    """Setup all middlewares for both single and multi-process modes"""
-    worker_pid = os.getpid()
-
-    # Setup API key middleware
-    api_key = os.environ.get("SGLANG_API_KEY", "")
-    if api_key:
-        add_api_key_middleware(app, api_key)
-        logger.info(f"Worker {worker_pid} added API key middleware")
-
-    # Setup prometheus middleware
-    # Check if metrics are enabled via environment variable
-    enable_metrics = get_bool_env_var("SGLANG_ENABLE_METRICS", "false")
-    if enable_metrics:
-        add_prometheus_middleware(app)
-        enable_func_timer()
-        logger.info(f"Worker {worker_pid} added prometheus middleware")
-
-
-# Call setup function at module level for multi-process compatibility
-setup_middlewares()
-
-
 @app.exception_handler(HTTPException)
 async def validation_exception_handler(request: Request, exc: HTTPException):
     """Enrich HTTP exception with status code and other details"""
@@ -1186,19 +993,9 @@ def launch_server(
     1. The HTTP server, Engine, and TokenizerManager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    if server_args.tokenizer_worker_num > 1:
-        port_args = PortArgs.init_new(server_args)
-        port_args.tokenizer_worker_ipc_name = (
-            f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
-        )
-        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args, port_args=port_args
-        )
-    else:
-        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args,
-        )
-
+    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+        server_args=server_args
+    )
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
@@ -1207,83 +1004,42 @@ def launch_server(
         )
     )
 
-    if server_args.tokenizer_worker_num > 1:
-        # Set environment variables for middlewares in main process
-        if server_args.api_key:
-            os.environ["SGLANG_API_KEY"] = server_args.api_key
-            logger.info("Main process set SGLANG_API_KEY")
-
-        if server_args.enable_metrics:
-            os.environ["SGLANG_ENABLE_METRICS"] = "true"
-            logger.info("Main process set SGLANG_ENABLE_METRICS=true")
-
-        port_args_shm, server_args_shm, scheduler_info_shm = (
-            write_data_for_multi_tokenizer(
-                port_args,
-                server_args,
-                scheduler_info,
-            )
-        )
-    else:
-        # Add api key authorization
-        if server_args.api_key:
-            add_api_key_middleware(app, server_args.api_key)
-
-        # Add prometheus middleware
-        if server_args.enable_metrics:
-            add_prometheus_middleware(app)
-            enable_func_timer()
-
-        # Send a warmup request - we will create the thread launch it
-        # in the lifespan after all other warmups have fired.
-        warmup_thread = threading.Thread(
-            target=_wait_and_warmup,
-            args=(
-                server_args,
-                pipe_finish_writer,
-                launch_callback,
-            ),
-        )
-        app.warmup_thread = warmup_thread
+    # Add api key authorization
+    if server_args.api_key:
+        add_api_key_middleware(app, server_args.api_key)
+
+    # Add prometheus middleware
+    if server_args.enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+
+    # Send a warmup request - we will create the thread launch it
+    # in the lifespan after all other warmups have fired.
+    warmup_thread = threading.Thread(
+        target=_wait_and_warmup,
+        args=(
+            server_args,
+            pipe_finish_writer,
+            launch_callback,
+        ),
+    )
+    app.warmup_thread = warmup_thread
 
     try:
         # Update logging configs
         set_uvicorn_logging_configs()
         app.server_args = server_args
         # Listen for HTTP requests
-        if server_args.tokenizer_worker_num > 1:
-            from uvicorn.config import LOGGING_CONFIG
-
-            LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
-                "handlers": ["default"],
-                "level": "INFO",
-                "propagate": False,
-            }
-            uvicorn.run(
-                "sglang.srt.entrypoints.http_server:app",
-                host=server_args.host,
-                port=server_args.port,
-                log_level=server_args.log_level_http or server_args.log_level,
-                timeout_keep_alive=5,
-                loop="uvloop",
-                workers=server_args.tokenizer_worker_num,
-            )
-        else:
-            uvicorn.run(
-                app,
-                host=server_args.host,
-                port=server_args.port,
-                log_level=server_args.log_level_http or server_args.log_level,
-                timeout_keep_alive=5,
-                loop="uvloop",
-            )
+        uvicorn.run(
+            app,
+            host=server_args.host,
+            port=server_args.port,
+            log_level=server_args.log_level_http or server_args.log_level,
+            timeout_keep_alive=5,
+            loop="uvloop",
+        )
     finally:
-        if server_args.tokenizer_worker_num > 1:
-            port_args_shm.unlink()
-            server_args_shm.unlink()
-            scheduler_info_shm.unlink()
-        else:
-            warmup_thread.join()
+        warmup_thread.join()
 
 
 def _execute_server_warmup(
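The hunk above also reverts how the HTTP server is launched. With `tokenizer_worker_num > 1`, the removed branch had to hand uvicorn an import string (`"sglang.srt.entrypoints.http_server:app"`) plus `workers=N`, because uvicorn re-imports the application inside each forked worker process; the restored single-process path passes the in-memory `app` object directly. A standalone sketch of the two invocation styles, assuming this file is saved as `demo_server.py` (not sglang's real entrypoint):

```python
import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
async def health():
    return {"ok": True}


if __name__ == "__main__":
    # Single-process serving: the app object can be passed directly.
    # uvicorn.run(app, host="0.0.0.0", port=8000)

    # Multi-worker serving (what the reverted branch used for tokenizer_worker_num > 1):
    # the app must be given as an import string so each child process can re-import it.
    # "demo_server:app" is a placeholder module path for this sketch.
    uvicorn.run("demo_server:app", host="0.0.0.0", port=8000, workers=2)
```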
@@ -31,12 +31,10 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    MultiTokenizerRegisterReq,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
-    get_workerids_from_rids,
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
@@ -83,6 +81,7 @@ class DetokenizerManager:
         self.send_to_tokenizer = get_zmq_socket(
             context, zmq.PUSH, port_args.tokenizer_ipc_name, False
         )
+
         if server_args.skip_tokenizer_init:
             self.tokenizer = None
         else:
@@ -95,208 +94,21 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
         self.is_dummy = server_args.load_format == "dummy"
-        self.tokenizer_worker_num = server_args.tokenizer_worker_num
         self._request_dispatcher = TypeBasedDispatcher(
             [
                 (BatchEmbeddingOut, self.handle_batch_embedding_out),
                 (BatchTokenIDOut, self.handle_batch_token_id_out),
                 (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
-                (MultiTokenizerRegisterReq, lambda x: None),
             ]
         )
 
     def event_loop(self):
         """The event loop that handles requests"""
 
         while True:
-            try:
-                recv_obj = self.recv_from_scheduler.recv_pyobj()
-                output = self._request_dispatcher(recv_obj)
-                if self.tokenizer_worker_num <= 1:
-                    self.send_to_tokenizer.send_pyobj(output)
-                else:
-                    # Extract worker_id from rid
-                    if isinstance(recv_obj.rids, list):
-                        worker_ids = get_workerids_from_rids(recv_obj.rids)
-                    else:
-                        raise RuntimeError(
-                            f"tokenizer_worker_num > 1, recv_obj.rids must be list"
-                        )
-
-                    if not hasattr(self, "tokenizer_mapping"):
-                        self.tokenizer_mapping = {}
-
-                    # Create ZMQ context if needed
-                    if not hasattr(self, "_zmq_context"):
-                        self._zmq_context = zmq.Context()
-
-                    # Send data using the corresponding socket
-                    for i, worker_id in enumerate(worker_ids):
-                        if worker_id not in self.tokenizer_mapping:
-                            # register the worker if not already done
-                            if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                                self.init_tokenizer_mapping(recv_obj, worker_id)
-                            else:
-                                logger.error(
-                                    f"Worker {worker_id} not registered and not found in tokenizer mapping . "
-                                    "Please ensure the worker is registered correctly."
-                                )
-                                continue
-                        else:
-                            if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                                continue
-
-                        # Create a new output object based on the type
-                        if isinstance(output, BatchEmbeddingOut):
-                            new_output = BatchEmbeddingOut(
-                                rids=[output.rids[i]],
-                                finished_reasons=[output.finished_reasons[i]],
-                                embeddings=[output.embeddings[i]],
-                                prompt_tokens=[output.prompt_tokens[i]],
-                                cached_tokens=[output.cached_tokens[i]],
-                            )
-                        elif isinstance(output, BatchStrOut):
-                            new_output = BatchStrOut(
-                                rids=[output.rids[i]],
-                                finished_reasons=(
-                                    [output.finished_reasons[i]]
-                                    if len(output.finished_reasons) > i
-                                    else None
-                                ),
-                                output_strs=(
-                                    [output.output_strs[i]]
-                                    if len(output.output_strs) > i
-                                    else None
-                                ),
-                                output_ids=(
-                                    [output.output_ids[i]]
-                                    if output.output_ids and len(output.output_ids) > i
-                                    else None
-                                ),
-                                prompt_tokens=(
-                                    [output.prompt_tokens[i]]
-                                    if len(output.prompt_tokens) > i
-                                    else None
-                                ),
-                                completion_tokens=(
-                                    [output.completion_tokens[i]]
-                                    if len(output.completion_tokens) > i
-                                    else None
-                                ),
-                                cached_tokens=(
-                                    [output.cached_tokens[i]]
-                                    if len(output.cached_tokens) > i
-                                    else None
-                                ),
-                                spec_verify_ct=(
-                                    [output.spec_verify_ct[i]]
-                                    if len(output.spec_verify_ct) > i
-                                    else None
-                                ),
-                                input_token_logprobs_val=(
-                                    [output.input_token_logprobs_val[i]]
-                                    if output.input_token_logprobs_val
-                                    else None
-                                ),
-                                input_token_logprobs_idx=(
-                                    [output.input_token_logprobs_idx[i]]
-                                    if output.input_token_logprobs_idx
-                                    else None
-                                ),
-                                output_token_logprobs_val=(
-                                    [output.output_token_logprobs_val[i]]
-                                    if output.output_token_logprobs_val
-                                    else None
-                                ),
-                                output_token_logprobs_idx=(
-                                    [output.output_token_logprobs_idx[i]]
-                                    if output.output_token_logprobs_idx
-                                    else None
-                                ),
-                                input_top_logprobs_val=(
-                                    [output.input_top_logprobs_val[i]]
-                                    if output.input_top_logprobs_val
-                                    else None
-                                ),
-                                input_top_logprobs_idx=(
-                                    [output.input_top_logprobs_idx[i]]
-                                    if output.input_top_logprobs_idx
-                                    else None
-                                ),
-                                output_top_logprobs_val=(
-                                    [output.output_top_logprobs_val[i]]
-                                    if output.output_top_logprobs_val
-                                    else None
-                                ),
-                                output_top_logprobs_idx=(
-                                    [output.output_top_logprobs_idx[i]]
-                                    if output.output_top_logprobs_idx
-                                    else None
-                                ),
-                                input_token_ids_logprobs_val=(
-                                    [output.input_token_ids_logprobs_val[i]]
-                                    if output.input_token_ids_logprobs_val
-                                    else None
-                                ),
-                                input_token_ids_logprobs_idx=(
-                                    [output.input_token_ids_logprobs_idx[i]]
-                                    if output.input_token_ids_logprobs_idx
-                                    else None
-                                ),
-                                output_token_ids_logprobs_val=(
-                                    [output.output_token_ids_logprobs_val[i]]
-                                    if output.output_token_ids_logprobs_val
-                                    else None
-                                ),
-                                output_token_ids_logprobs_idx=(
-                                    [output.output_token_ids_logprobs_idx[i]]
-                                    if output.output_token_ids_logprobs_idx
-                                    else None
-                                ),
-                                output_hidden_states=(
-                                    [output.output_hidden_states[i]]
-                                    if output.output_hidden_states
-                                    else None
-                                ),
-                            )
-                        elif isinstance(output, BatchMultimodalOut):
-                            new_output = BatchMultimodalOut(
-                                rids=[output.rids[i]],
-                                finished_reasons=[output.finished_reasons[i]],
-                                prompt_tokens=[output.prompt_tokens[i]],
-                                completion_tokens=[output.completion_tokens[i]],
-                                cached_tokens=[output.cached_tokens[i]],
-                            )
-                        else:
-                            new_output = output
-
-                        try:
-                            self.tokenizer_mapping[worker_id].send_pyobj(new_output)
-                        except zmq.error.ZMQError as e:
-                            logger.info(
-                                f"ZMQ error when sending to worker {worker_id}: {e}"
-                            )
-            except Exception as e:
-                logger.error(f"Error in detokenizer event loop: {e}")
-                raise e
-
-    def init_tokenizer_mapping(
-        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
-    ):
-        """init tokenizer mapping from register request"""
-        ipc_name = recv_obj.ipc_name
-        worker_id_int = int(worker_id)
-
-        if worker_id_int not in self.tokenizer_mapping:
-            socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
-            self.tokenizer_mapping[worker_id_int] = socket
-            logger.info(
-                f"Detokenizer Manager Created ZMQ socket for worker {worker_id} with ipc_name {ipc_name}"
-            )
-        else:
-            logger.info(
-                f"ZMQ socket for worker {worker_id} already exists, skipping creation"
-            )
+            recv_obj = self.recv_from_scheduler.recv_pyobj()
+            output = self._request_dispatcher(recv_obj)
+            self.send_to_tokenizer.send_pyobj(output)
 
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
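The removed `DetokenizerManager` logic above kept a lazily built `worker_id -> PUSH socket` table so each decoded batch could be routed back to the uvicorn worker that issued the request. A stripped-down sketch of that mapping, using plain pyzmq (this is an illustration, not the sglang API):

```python
import zmq


class WorkerFanOut:
    """Keep one PUSH socket per registered worker and route replies by worker id.

    Hypothetical, simplified version of the mapping the removed
    DetokenizerManager code maintained.
    """

    def __init__(self) -> None:
        self._ctx = zmq.Context()
        self._sockets: dict[int, zmq.Socket] = {}

    def register(self, worker_id: int, ipc_name: str) -> None:
        # Lazily create a PUSH socket that connects to the worker's IPC endpoint,
        # e.g. "ipc:///tmp/tokenizer_worker_1234".
        if worker_id not in self._sockets:
            sock = self._ctx.socket(zmq.PUSH)
            sock.connect(ipc_name)
            self._sockets[worker_id] = sock

    def send(self, worker_id: int, obj) -> bool:
        sock = self._sockets.get(worker_id)
        if sock is None:
            # Worker never registered; the real code logged an error and skipped.
            return False
        sock.send_pyobj(obj)
        return True
```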
@@ -782,13 +782,12 @@ class BatchEmbeddingOut:
 
 @dataclass
 class FlushCacheReqInput:
-    rids: Optional[Union[List[str], str]] = None
+    pass
 
 
 @dataclass
 class FlushCacheReqOutput:
     success: bool
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
@@ -799,7 +798,6 @@ class UpdateWeightFromDiskReqInput:
     load_format: Optional[str] = None
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
@@ -808,7 +806,6 @@ class UpdateWeightFromDiskReqOutput:
     message: str
     # Number of paused requests during weight sync.
     num_paused_requests: Optional[int] = 0
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
@@ -822,14 +819,12 @@ class UpdateWeightsFromDistributedReqInput:
     flush_cache: bool = True
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class UpdateWeightsFromDistributedReqOutput:
     success: bool
     message: str
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
@@ -847,14 +842,12 @@ class UpdateWeightsFromTensorReqInput:
     flush_cache: bool = True
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class UpdateWeightsFromTensorReqOutput:
     success: bool
     message: str
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
@@ -871,27 +864,23 @@ class InitWeightsUpdateGroupReqInput:
     group_name: str = "weight_update_group"
     # The backend
     backend: str = "nccl"
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class InitWeightsUpdateGroupReqOutput:
     success: bool
     message: str
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class GetWeightsByNameReqInput:
     name: str
     truncate_size: int = 100
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class GetWeightsByNameReqOutput:
     parameter: list
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
@@ -899,12 +888,11 @@ class ReleaseMemoryOccupationReqInput:
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class ReleaseMemoryOccupationReqOutput:
-    rids: Optional[Union[List[str], str]] = None
+    pass
 
 
 @dataclass
@@ -912,23 +900,21 @@ class ResumeMemoryOccupationReqInput:
     # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class ResumeMemoryOccupationReqOutput:
-    rids: Optional[Union[List[str], str]] = None
+    pass
 
 
 @dataclass
 class SlowDownReqInput:
     forward_sleep_time: Optional[float]
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class SlowDownReqOutput:
-    rids: Optional[Union[List[str], str]] = None
+    pass
 
 
 @dataclass
@@ -937,37 +923,29 @@ class AbortReq:
     rid: str = ""
     # Whether to abort all requests
     abort_all: bool = False
-    rids: Optional[Union[List[str], str]] = None
+    # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
 
-    def __post_init__(self):
-        self.rids = self.rid
-
 
 @dataclass
 class GetInternalStateReq:
-    rids: Optional[Union[List[str], str]] = None
+    pass
 
 
 @dataclass
 class GetInternalStateReqOutput:
     internal_state: Dict[Any, Any]
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class SetInternalStateReq:
     server_args: Dict[str, Any]
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class SetInternalStateReqOutput:
     updated: bool
     server_args: Dict[str, Any]
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
@@ -983,7 +961,6 @@ class ProfileReqInput:
     profile_by_stage: bool = False
     with_stack: Optional[bool] = None
     record_shapes: Optional[bool] = None
-    rids: Optional[Union[List[str], str]] = None
 
 
 class ProfileReqType(Enum):
@@ -1002,14 +979,12 @@ class ProfileReq:
     with_stack: Optional[bool] = None
     record_shapes: Optional[bool] = None
     profile_id: Optional[str] = None
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class ProfileReqOutput:
     success: bool
     message: str
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
@@ -1018,32 +993,27 @@ class ConfigureLoggingReq:
     log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
     dump_requests_threshold: Optional[int] = None
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
     session_id: Optional[str] = None
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class CloseSessionReqInput:
     session_id: str
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class OpenSessionReqOutput:
     session_id: Optional[str]
     success: bool
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class HealthCheckOutput:
-    rids: Optional[Union[List[str], str]] = None
     pass
 
 
@@ -1055,7 +1025,7 @@ class ExpertDistributionReq(Enum):
 
 @dataclass
 class ExpertDistributionReqOutput:
-    rids: Optional[Union[List[str], str]] = None
+    pass
 
 
 @dataclass
@@ -1080,21 +1050,18 @@ class ParseFunctionCallReq:
     tool_call_parser: Optional[str] = (
         None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
     )
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class SeparateReasoningReqInput:
     text: str  # The text to parse.
     reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
 class VertexGenerateReqInput:
     instances: List[dict]
     parameters: Optional[dict] = None
-    rids: Optional[Union[List[str], str]] = None
 
 
 @dataclass
@@ -1119,7 +1086,6 @@ class LoadLoRAAdapterReqInput:
     pinned: bool = False
     # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
     lora_id: Optional[str] = None
-    rids: Optional[Union[List[str], str]] = None
 
     def to_ref(self) -> LoRARef:
         return LoRARef(
@@ -1136,7 +1102,6 @@ class UnloadLoRAAdapterReqInput:
     lora_name: str
     # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
     lora_id: Optional[str] = None
-    rids: Optional[Union[List[str], str]] = None
 
     def to_ref(self) -> LoRARef:
         return LoRARef(
@@ -1150,18 +1115,11 @@ class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
     loaded_adapters: Optional[Dict[str, LoRARef]] = None
-    rids: Optional[Union[List[str], str]] = None
 
 
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
 
 
-@dataclass
-class MultiTokenizerRegisterReq:
-    rids: Optional[Union[List[str], str]] = None
-    ipc_name: Optional[str] = None
-
-
 class BlockReqType(Enum):
     BLOCK = 1
     UNBLOCK = 2
@@ -79,7 +79,6 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
-    MultiTokenizerRegisterReq,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -248,6 +247,7 @@ class Scheduler(
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
+
         if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -522,7 +522,6 @@ class Scheduler(
                 (ExpertDistributionReq, self.expert_distribution_handle),
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
-                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
             ]
         )
 
@@ -1065,8 +1064,6 @@ class Scheduler(
         if self.recv_from_rpc is not None:
             self.recv_from_rpc.send_pyobj(output)
         else:
-            if recv_req.rids is not None:
-                output.rids = recv_req.rids
             self.send_to_tokenizer.send_pyobj(output)
 
     def handle_generate_request(
@@ -2407,10 +2404,6 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result
 
-    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
-        self.send_to_detokenizer.send_pyobj(recv_req)
-        return recv_req
-
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -89,7 +89,6 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     LoRAUpdateResult,
-    MultiTokenizerRegisterReq,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -125,8 +124,6 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     dataclass_to_string_truncated,
     get_bool_env_var,
-    get_origin_rid,
-    get_workerids_from_rids,
     get_zmq_socket,
     kill_process_tree,
 )
@@ -174,9 +171,6 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
-_global_tokenizer_worker_num = 1
-
-
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
 
@@ -184,7 +178,6 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        is_main: Optional[bool] = True,
     ):
         # Parse args
         self.server_args = server_args
@@ -198,9 +191,6 @@ class TokenizerManager:
         )
         self.crash_dump_folder = server_args.crash_dump_folder
 
-        self.is_main = is_main
-        self.worker_id = os.getpid()
-
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
@@ -265,41 +255,13 @@ class TokenizerManager:
         )
 
         # Init inter-process communication
-        context = zmq.asyncio.Context(3)
+        context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        global _global_tokenizer_worker_num
-        _global_tokenizer_worker_num = server_args.tokenizer_worker_num
-        if server_args.tokenizer_worker_num > 1:
-            self.tokenizer_ipc_name = port_args.tokenizer_ipc_name
-            if self.is_main:
-                self.send_to_scheduler = get_zmq_socket(
-                    context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-                )
-                self.receive_from_worker = get_zmq_socket(
-                    context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
-                )
-                self._loop = asyncio.new_event_loop()
-                self._thread = threading.Thread(target=self._run_loop, daemon=True)
-                self._thread.start()
-                self._task = asyncio.run_coroutine_threadsafe(
-                    self.router_worker_obj(), self._loop
-                )
-                # Start handle_loop simultaneously
-                self._handle_task = asyncio.run_coroutine_threadsafe(
-                    print_exception_wrapper(self.handle_loop), self._loop
-                )
-
-            else:
-                # actual send to main receiver_from_worker
-                self.send_to_scheduler = get_zmq_socket(
-                    context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
-                )
-        else:
-            self.send_to_scheduler = get_zmq_socket(
-                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-            )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+        )
 
         # Request states
         self.no_create_loop = False
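The wiring above is plain ZMQ PUSH/PULL over `ipc://` endpoints: the receiving side binds and the sending side connects, which is what the boolean `bind` argument of sglang's `get_zmq_socket` selects. A minimal pair, assuming only pyzmq and the standard library (endpoint name is illustrative):

```python
import os
import tempfile

import zmq

# Reserve a unique filesystem path for the IPC endpoint, then let zmq create
# the actual unix socket when the receiver binds.
path = tempfile.NamedTemporaryFile(delete=False).name
os.unlink(path)
endpoint = f"ipc://{path}"

ctx = zmq.Context(2)

receiver = ctx.socket(zmq.PULL)   # e.g. the scheduler input side: binds
receiver.bind(endpoint)

sender = ctx.socket(zmq.PUSH)     # e.g. the tokenizer side: connects
sender.connect(endpoint)

sender.send_pyobj({"rid": "example-request"})
print(receiver.recv_pyobj())      # {'rid': 'example-request'}
```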
@@ -353,27 +315,26 @@ class TokenizerManager:
         # Start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
-            if self.is_main:
-                kv_bootstrap_server_class = get_kv_class(
-                    self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-                )
-                self.bootstrap_server = kv_bootstrap_server_class(
-                    self.server_args.disaggregation_bootstrap_port
-                )
-                is_create_store = (
-                    self.server_args.node_rank == 0
-                    and self.server_args.disaggregation_transfer_backend == "ascend"
-                )
-                if is_create_store:
-                    try:
-                        from mf_adapter import create_config_store
+            kv_bootstrap_server_class = get_kv_class(
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
 
-                        ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                        create_config_store(ascend_url)
-                    except Exception as e:
-                        error_message = f"Failed create mf store, invalid ascend_url."
-                        error_message += f" With exception {e}"
-                        raise error_message
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = f"Failed create mf store, invalid ascend_url."
+                    error_message += f" With exception {e}"
+                    raise error_message
 
         # For load balancing
         self.current_load = 0
@@ -506,14 +467,6 @@ class TokenizerManager:
             ]
         )
 
-    def _run_loop(self):
-        self._loop.run_forever()
-
-    async def router_worker_obj(self):
-        while True:
-            recv_obj = await self.receive_from_worker.recv_pyobj()
-            await self.send_to_scheduler.send_pyobj(recv_obj)
-
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -526,15 +479,6 @@ class TokenizerManager:
         async with self._is_updating_cond:
             await self._is_updating_cond.wait_for(lambda: not self._is_updating)
 
-        if self.server_args.tokenizer_worker_num > 1:
-            # Modify rid, add worker_id
-            if isinstance(obj.rid, list):
-                # If it's an array, add worker_id prefix to each element
-                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
-            else:
-                # If it's a single value, add worker_id prefix
-                obj.rid = f"{self.worker_id}_{obj.rid}"
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
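The deleted block above tagged every outgoing request id with the worker's pid (`f"{self.worker_id}_{rid}"`) so that replies could later be routed by `get_workerids_from_rids`. Those helpers lived in `sglang.srt.utils` and their bodies are not part of this diff; the sketch below is only a hypothetical re-implementation of the naming scheme for illustration:

```python
import os
import uuid

# Hypothetical helpers mirroring the "{worker_id}_{rid}" convention used by the
# reverted multi-tokenizer code. Not the real sglang utilities.


def tag_rid(rid: str, worker_id: int) -> str:
    """Prefix a request id with the issuing worker's id."""
    return f"{worker_id}_{rid}"


def split_rid(tagged: str) -> tuple[int, str]:
    """Recover (worker_id, original_rid) from a tagged request id."""
    worker_id, origin = tagged.split("_", 1)
    return int(worker_id), origin


worker_id = os.getpid()
tagged = tag_rid(uuid.uuid4().hex, worker_id)
assert split_rid(tagged)[0] == worker_id
```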
@@ -1561,377 +1505,11 @@ class TokenizerManager:
 
     async def handle_loop(self):
         """The event loop that handles requests"""
 
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-            # In multi-worker mode, distribute results to corresponding workers
-            if self.server_args.tokenizer_worker_num > 1 and self.is_main:
-                await self._distribute_result_to_workers(recv_obj)
-            else:
-                # In single worker mode, process directly
-                self._result_dispatcher(recv_obj)
-
-            self.last_receive_tstamp = time.time()
-
-    def init_tokenizer_mapping(self, recv_obj: MultiTokenizerRegisterReq):
-        """init tokenizer mapping from register request"""
-        if isinstance(recv_obj.rids, list):
-            worker_ids = get_workerids_from_rids(recv_obj.rids)
-        else:
-            raise RuntimeError(f"tokenizer_worker_num > 1, recv_obj.rids must be list")
-
-        for worker_id in worker_ids:
-            ipc_name = recv_obj.ipc_name
-            worker_id_int = int(worker_id)
-
-            if worker_id_int not in self.tokenizer_mapping:
-                socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
-                self.tokenizer_mapping[worker_id_int] = socket
-                logger.info(
-                    f"Main Tokenizer Manager Created ZMQ socket for worker {worker_id} with ipc_name {ipc_name}"
-                )
-            else:
-                logger.info(
-                    f"ZMQ socket for worker {worker_id} already exists, skipping creation"
-                )
-
-    async def _distribute_result_to_workers(self, recv_obj):
-        """Distribute result to corresponding workers based on rid"""
-
-        worker_ids = get_workerids_from_rids(recv_obj.rids)
-        if len(worker_ids) == 0:
-            self._result_dispatcher(recv_obj)
-            return
+            self._result_dispatcher(recv_obj)
+            self.last_receive_tstamp = time.time()
-
-        if not hasattr(self, "tokenizer_mapping"):
-            self.tokenizer_mapping = {}
-
-        # Create ZMQ context if needed
-        if not hasattr(self, "_zmq_context"):
-            self._zmq_context = zmq.Context()
-
-        # Distribute result to each worker
-        for i, worker_id in enumerate(worker_ids):
-            if worker_id not in self.tokenizer_mapping:
-                if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                    self.init_tokenizer_mapping(recv_obj)
-                else:
-                    logger.error(
-                        f"Worker {worker_id} not registered and not found in tokenizer mapping . "
-                        "Please ensure the worker is registered correctly."
-                    )
-                    continue
-            else:
-                if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                    continue
-
-            if not isinstance(
-                recv_obj,
-                (
-                    BatchStrOut,
-                    BatchEmbeddingOut,
-                    BatchTokenIDOut,
-                    BatchMultimodalOut,
-                ),
-            ):
-                # Send to worker
-                self.tokenizer_mapping[worker_id].send_pyobj(recv_obj)
-            else:
-                if isinstance(recv_obj, BatchTokenIDOut):
-                    new_recv_obj = BatchTokenIDOut(
-                        [recv_obj.rids[i]],
-                        (
-                            [recv_obj.finished_reasons[i]]
-                            if len(recv_obj.finished_reasons) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.decoded_texts[i]]
-                            if len(recv_obj.decoded_texts) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.decode_ids[i]]
-                            if len(recv_obj.decode_ids) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.read_offsets[i]]
-                            if len(recv_obj.read_offsets) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_ids[i]]
-                            if recv_obj.output_ids and len(recv_obj.output_ids) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.skip_special_tokens[i]]
-                            if len(recv_obj.skip_special_tokens) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.spaces_between_special_tokens[i]]
-                            if len(recv_obj.spaces_between_special_tokens) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.no_stop_trim[i]]
-                            if len(recv_obj.no_stop_trim) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.prompt_tokens[i]]
-                            if len(recv_obj.prompt_tokens) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.completion_tokens[i]]
-                            if len(recv_obj.completion_tokens) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.cached_tokens[i]]
-                            if len(recv_obj.cached_tokens) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.spec_verify_ct[i]]
-                            if len(recv_obj.spec_verify_ct) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_token_logprobs_val[i]]
-                            if recv_obj.input_token_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_token_logprobs_idx[i]]
-                            if recv_obj.input_token_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_token_logprobs_val[i]]
-                            if recv_obj.output_token_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_token_logprobs_idx[i]]
-                            if recv_obj.output_token_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_top_logprobs_val[i]]
-                            if recv_obj.input_top_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_top_logprobs_idx[i]]
-                            if recv_obj.input_top_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_top_logprobs_val[i]]
-                            if recv_obj.output_top_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_top_logprobs_idx[i]]
-                            if recv_obj.output_top_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_token_ids_logprobs_val[i]]
-                            if recv_obj.input_token_ids_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_token_ids_logprobs_idx[i]]
-                            if recv_obj.input_token_ids_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_token_ids_logprobs_val[i]]
-                            if recv_obj.output_token_ids_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_token_ids_logprobs_idx[i]]
-                            if recv_obj.output_token_ids_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_hidden_states[i]]
-                            if recv_obj.output_hidden_states
-                            else None
-                        ),
-                    )
-                elif isinstance(recv_obj, BatchEmbeddingOut):
-                    new_recv_obj = BatchEmbeddingOut(
-                        [recv_obj.rids[i]],
-                        (
-                            [recv_obj.finished_reasons[i]]
-                            if len(recv_obj.finished_reasons) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.embeddings[i]]
-                            if len(recv_obj.embeddings) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.prompt_tokens[i]]
-                            if len(recv_obj.prompt_tokens) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.cached_tokens[i]]
-                            if len(recv_obj.cached_tokens) > i
-                            else None
-                        ),
-                    )
-                elif isinstance(recv_obj, BatchStrOut):
-                    new_recv_obj = BatchStrOut(
-                        [recv_obj.rids[i]],
-                        (
-                            [recv_obj.finished_reasons[i]]
-                            if len(recv_obj.finished_reasons) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_strs[i]]
-                            if len(recv_obj.output_strs) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_ids[i]]
-                            if recv_obj.output_ids and len(recv_obj.output_ids) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.prompt_tokens[i]]
-                            if len(recv_obj.prompt_tokens) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.completion_tokens[i]]
-                            if len(recv_obj.completion_tokens) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.cached_tokens[i]]
-                            if len(recv_obj.cached_tokens) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.spec_verify_ct[i]]
-                            if len(recv_obj.spec_verify_ct) > i
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_token_logprobs_val[i]]
-                            if recv_obj.input_token_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_token_logprobs_idx[i]]
-                            if recv_obj.input_token_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_token_logprobs_val[i]]
-                            if recv_obj.output_token_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_token_logprobs_idx[i]]
-                            if recv_obj.output_token_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_top_logprobs_val[i]]
-                            if recv_obj.input_top_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_top_logprobs_idx[i]]
-                            if recv_obj.input_top_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_top_logprobs_val[i]]
-                            if recv_obj.output_top_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_top_logprobs_idx[i]]
-                            if recv_obj.output_top_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_token_ids_logprobs_val[i]]
-                            if recv_obj.input_token_ids_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.input_token_ids_logprobs_idx[i]]
-                            if recv_obj.input_token_ids_logprobs_idx
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_token_ids_logprobs_val[i]]
-                            if recv_obj.output_token_ids_logprobs_val
-                            else None
-                        ),
-                        (
-                            [recv_obj.output_token_ids_logprobs_idx[i]]
if recv_obj.output_token_ids_logprobs_idx
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
(
|
|
||||||
[recv_obj.output_hidden_states[i]]
|
|
||||||
if recv_obj.output_hidden_states
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
elif isinstance(recv_obj, BatchMultimodalOut):
|
|
||||||
new_recv_obj = BatchMultimodalOut(
|
|
||||||
[recv_obj.rids[i]],
|
|
||||||
(
|
|
||||||
[recv_obj.finished_reasons[i]]
|
|
||||||
if len(recv_obj.finished_reasons) > i
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
([recv_obj.outputs[i]] if len(recv_obj.outputs) > i else None),
|
|
||||||
(
|
|
||||||
[recv_obj.prompt_tokens[i]]
|
|
||||||
if len(recv_obj.prompt_tokens) > i
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
(
|
|
||||||
[recv_obj.completion_tokens[i]]
|
|
||||||
if len(recv_obj.completion_tokens) > i
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
(
|
|
||||||
[recv_obj.cached_tokens[i]]
|
|
||||||
if len(recv_obj.cached_tokens) > i
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
|
|
||||||
except zmq.ZMQError as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Failed to send result to worker {worker_id}: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
def register_to_main_tokenizer_manager(self):
|
|
||||||
"""Register this worker to the main TokenizerManager"""
|
|
||||||
req = MultiTokenizerRegisterReq()
|
|
||||||
req.rids = [f"{self.worker_id}_registertokenizer"]
|
|
||||||
req.ipc_name = self.tokenizer_ipc_name
|
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
     def _handle_batch_output(
         self,
@@ -1946,12 +1524,10 @@ class TokenizerManager:
                     f"Received output for {rid=} but the state was deleted in TokenizerManager."
                 )
                 continue
-            originRid = rid
-            if self.server_args.tokenizer_worker_num > 1:
-                originRid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": originRid,
+                "id": rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
             }
@@ -2252,9 +1828,6 @@ class TokenizerManager:
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
-        rid = recv_obj.rid
-        if self.server_args.tokenizer_worker_num > 1:
-            rid = get_origin_rid(rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -2267,7 +1840,7 @@ class TokenizerManager:
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": rid,
+                    "id": recv_obj.rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -2456,7 +2029,6 @@ class _Communicator(Generic[T]):
         self._ready_queue: Deque[asyncio.Future] = deque()

     async def __call__(self, obj):
-        global _global_tokenizer_worker_num
         ready_event = asyncio.Event()
         if self._result_event is not None or len(self._ready_queue) > 0:
             self._ready_queue.append(ready_event)
@@ -2465,14 +2037,6 @@ class _Communicator(Generic[T]):
         assert self._result_values is None

         if obj:
-            if _global_tokenizer_worker_num > 1:
-                if obj.rids is None:
-                    obj.rids = f"{os.getpid()}_{uuid.uuid4().hex}_Communicator"
-                else:
-                    if isinstance(obj.rids, str):
-                        obj.rids = f"{os.getpid()}_{obj.rids}"
-                    elif isinstance(obj.rids, list):
-                        obj.rids = [f"{os.getpid()}_{rid}" for rid in obj.rids]
             self._sender.send_pyobj(obj)

         self._result_event = asyncio.Event()
@@ -2487,19 +2051,6 @@ class _Communicator(Generic[T]):
         return result_values

     def handle_recv(self, recv_obj: T):
-        global _global_tokenizer_worker_num
-        if _global_tokenizer_worker_num > 1:
-            # If rids is a string and not empty, remove the prefix
-            if (
-                hasattr(recv_obj, "rids")
-                and isinstance(recv_obj.rids, str)
-                and recv_obj.rids
-            ):
-                recv_obj.rids = get_origin_rid(recv_obj.rids)
-            # If rids is a list, remove prefix from each element
-            elif hasattr(recv_obj, "rids") and isinstance(recv_obj.rids, list):
-                recv_obj.rids = [get_origin_rid(rid) for rid in recv_obj.rids]
-
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
             self._result_event.set()

@@ -51,7 +51,6 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
-    tokenizer_worker_num: int = 1
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     model_loader_extra_config: str = "{}"
@@ -732,12 +731,6 @@ class ServerArgs:
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument(
-            "--tokenizer-worker-num",
-            type=int,
-            default=ServerArgs.tokenizer_worker_num,
-            help="The worker num of the tokenizer manager.",
-        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -2096,9 +2089,6 @@ class ServerArgs:
             self.chunked_prefill_size % self.page_size == 0
         ), "chunked_prefill_size must be divisible by page_size"

-        # Check multi tokenizer
-        assert self.tokenizer_worker_num > 0, "Tokenizer worker num must >= 1"
-
     def check_lora_server_args(self):
         assert (
             self.max_loras_per_batch > 0
@@ -2264,9 +2254,6 @@ class PortArgs:
     # The ipc filename for Scheduler to send metrics
     metrics_ipc_name: str

-    # The ipc filename for Tokenizer and worker tokenizer
-    tokenizer_worker_ipc_name: Optional[str]
-
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         if server_args.nccl_port is None:
@@ -2290,7 +2277,6 @@ class PortArgs:
                 nccl_port=nccl_port,
                 rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
-                tokenizer_worker_ipc_name=None,
             )
         else:
             # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -2324,7 +2310,6 @@ class PortArgs:
                 nccl_port=nccl_port,
                 rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
                 metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
-                tokenizer_worker_ipc_name=None,
             )


@@ -2754,20 +2754,6 @@ def lru_cache_frozenset(maxsize=128):
     return decorator


-def get_workerids_from_rids(rids):
-    if isinstance(rids, list):
-        worker_ids = [int(rid.split("_")[0]) for rid in rids]
-    elif isinstance(rids, str):
-        worker_ids = [int(rids.split("_")[0])]
-    else:
-        worker_ids = []
-    return worker_ids
-
-
-def get_origin_rid(rid):
-    return rid.split("_", 1)[1] if "_" in rid else rid
-
-
 def apply_module_patch(target_module, target_function, wrappers):
     original_module, original_function = parse_module_path(
         target_module, target_function, False

@@ -78,7 +78,6 @@ suites = {
         TestFile("test_mla_int8_deepseek_v3.py", 429),
         TestFile("test_mla_flashinfer.py", 302),
         TestFile("test_mla_fp8.py", 93),
-        TestFile("test_multi_tokenizer.py", 200),
         TestFile("test_no_chunked_prefill.py", 108),
         TestFile("test_no_overlap_scheduler.py", 234),
         TestFile("test_penalty.py", 41),

@@ -1,100 +0,0 @@
-import inspect
-import unittest
-from dataclasses import fields, is_dataclass
-from types import SimpleNamespace
-
-import sglang.srt.managers.io_struct as io_struct
-from sglang.srt.utils import kill_process_tree
-from sglang.test.run_eval import run_eval
-from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
-    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-    DEFAULT_URL_FOR_TEST,
-    CustomTestCase,
-    auto_config_device,
-    get_benchmark_args,
-    is_in_ci,
-    popen_launch_server,
-    run_benchmark,
-    write_github_step_summary,
-)
-
-
-class TestMultiTokenizer(CustomTestCase):
-    # from test_hicache.py
-    @classmethod
-    def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=[
-                "--tokenizer-worker-num",
-                8,
-                "--mem-fraction-static",
-                0.7,
-            ],
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        kill_process_tree(cls.process.pid)
-
-    def test_mmlu(self):
-        args = SimpleNamespace(
-            base_url=self.base_url,
-            model=self.model,
-            eval_name="mmlu",
-            num_examples=64,
-            num_threads=32,
-        )
-        metrics = run_eval(args)
-        self.assertGreaterEqual(metrics["score"], 0.65)
-
-    def test_all_io_struct(self):
-        print("check all req types in io_struct.py")
-        result = []
-        for name, obj in inspect.getmembers(io_struct):
-            if inspect.isclass(obj) and is_dataclass(obj):
-                field_names = [f.name for f in fields(obj)]
-                if "rids" in field_names or "rid" in field_names:
-                    continue
-                result.append(name)
-        print(f"WARNING:Some Request types in io_struct.py have no rids: {result}")
-        print(
-            "If a special request type can't work, check the rids field which is needed for multi-tokenizer."
-        )
-
-    def test_multi_tokenizer_ttft(self):
-        # from test_bench_serving.py run_bench_serving
-        args = get_benchmark_args(
-            base_url=self.base_url,
-            dataset_name="random",
-            dataset_path="",
-            tokenizer=None,
-            num_prompts=100,
-            random_input_len=4096,
-            random_output_len=2048,
-            sharegpt_context_len=None,
-            request_rate=1,
-            disable_stream=False,
-            disable_ignore_eos=False,
-            seed=0,
-            device=auto_config_device(),
-            lora_name=None,
-        )
-        res = run_benchmark(args)
-        if is_in_ci():
-            write_github_step_summary(
-                f"### test_multi_tokenizer_ttft\n"
-                f"median_e2e_latency_ms: {res['median_e2e_latency_ms']:.2f} ms\n"
-            )
-        self.assertLess(res["median_e2e_latency_ms"], 11000)
-        self.assertLess(res["median_ttft_ms"], 86)
-        self.assertLess(res["median_itl_ms"], 10)
-
-
-if __name__ == "__main__":
-    unittest.main()