Simplify tokenizer manager (#2254)
This commit is contained in:
@@ -21,14 +21,13 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from sglang.api import Engine
|
|
||||||
from sglang.bench_serving import (
|
from sglang.bench_serving import (
|
||||||
get_dataset,
|
get_dataset,
|
||||||
get_tokenizer,
|
get_tokenizer,
|
||||||
sample_random_requests,
|
sample_random_requests,
|
||||||
set_ulimit,
|
set_ulimit,
|
||||||
)
|
)
|
||||||
from sglang.srt.server import Runtime, start_profile, stop_profile
|
from sglang.srt.server import Engine, Runtime
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
|
||||||
@@ -204,12 +203,12 @@ def throughput_test_once(
|
|||||||
|
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
if profile:
|
if profile:
|
||||||
start_profile()
|
backend.start_profile()
|
||||||
|
|
||||||
gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
|
gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
|
||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
stop_profile()
|
backend.stop_profile()
|
||||||
monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
|
monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
|
||||||
|
|
||||||
latency = time.perf_counter() - st
|
latency = time.perf_counter() - st
|
||||||
|
|||||||
@@ -338,7 +338,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
|||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"image_hashes": image_hashes,
|
"image_hashes": image_hashes,
|
||||||
"image_sizes": image_sizes,
|
"image_sizes": image_sizes,
|
||||||
"modalities": request_obj.modalities,
|
"modalities": request_obj.modalities or ["image"],
|
||||||
"image_grid_thws": image_grid_thws,
|
"image_grid_thws": image_grid_thws,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -376,16 +376,6 @@ class ProfileReq(Enum):
|
|||||||
STOP_PROFILE = 2
|
STOP_PROFILE = 2
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GetMemPoolSizeReq:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GetMemPoolSizeReqOutput:
|
|
||||||
size: int
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OpenSessionReqInput:
|
class OpenSessionReqInput:
|
||||||
capacity_of_str_len: int
|
capacity_of_str_len: int
|
||||||
|
|||||||
@@ -38,8 +38,6 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GetMemPoolSizeReq,
|
|
||||||
GetMemPoolSizeReqOutput,
|
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
OpenSessionReqOutput,
|
OpenSessionReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
@@ -521,10 +519,6 @@ class Scheduler:
|
|||||||
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
|
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
|
||||||
elif isinstance(recv_req, CloseSessionReqInput):
|
elif isinstance(recv_req, CloseSessionReqInput):
|
||||||
self.close_session(recv_req)
|
self.close_session(recv_req)
|
||||||
elif isinstance(recv_req, GetMemPoolSizeReq):
|
|
||||||
self.send_to_tokenizer.send_pyobj(
|
|
||||||
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid request: {recv_req}")
|
raise ValueError(f"Invalid request: {recv_req}")
|
||||||
|
|
||||||
|
|||||||
@@ -45,8 +45,6 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetMemPoolSizeReq,
|
|
||||||
GetMemPoolSizeReqOutput,
|
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
OpenSessionReqOutput,
|
OpenSessionReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
@@ -218,7 +216,7 @@ class TokenizerManager:
|
|||||||
input_ids = obj.input_ids
|
input_ids = obj.input_ids
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
image_inputs = await self.image_processor.process_images_async(
|
image_inputs: Dict = await self.image_processor.process_images_async(
|
||||||
obj.image_data, input_text or input_ids, obj
|
obj.image_data, input_text or input_ids, obj
|
||||||
)
|
)
|
||||||
if image_inputs and "input_ids" in image_inputs:
|
if image_inputs and "input_ids" in image_inputs:
|
||||||
@@ -406,25 +404,6 @@ class TokenizerManager:
|
|||||||
req = ProfileReq.STOP_PROFILE
|
req = ProfileReq.STOP_PROFILE
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
|
||||||
async def get_memory_pool_size(self):
|
|
||||||
if self.to_create_loop:
|
|
||||||
self.create_handle_loop()
|
|
||||||
|
|
||||||
req = GetMemPoolSizeReq()
|
|
||||||
|
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
|
||||||
self.mem_pool_size = asyncio.Future()
|
|
||||||
|
|
||||||
# FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
|
|
||||||
if self.server_args.dp_size == 1:
|
|
||||||
res = await self.mem_pool_size
|
|
||||||
return res.size
|
|
||||||
else: # self.server_args.dp_size > 1
|
|
||||||
self.mem_pool_size_tmp = []
|
|
||||||
res = await self.mem_pool_size
|
|
||||||
ret = [r.size for r in res]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
async def update_weights(
|
async def update_weights(
|
||||||
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
||||||
):
|
):
|
||||||
@@ -552,15 +531,6 @@ class TokenizerManager:
|
|||||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||||
self.model_update_result.set_result(self.model_update_tmp)
|
self.model_update_result.set_result(self.model_update_tmp)
|
||||||
continue
|
continue
|
||||||
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
|
||||||
if self.server_args.dp_size == 1:
|
|
||||||
self.mem_pool_size.set_result(recv_obj)
|
|
||||||
else: # self.server_args.dp_size > 1
|
|
||||||
self.mem_pool_size_tmp.append(recv_obj)
|
|
||||||
# set future if the all results are received
|
|
||||||
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
|
|
||||||
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
|
|
||||||
continue
|
|
||||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||||
self.session_futures[recv_obj.session_id].set_result(
|
self.session_futures[recv_obj.session_id].set_result(
|
||||||
recv_obj.session_id
|
recv_obj.session_id
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import logging
|
|||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
@@ -94,7 +93,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
# Fast API
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
@@ -105,7 +104,7 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tokenizer_manager: TokenizerManager = None
|
tokenizer_manager: TokenizerManager = None
|
||||||
_max_total_num_tokens = None
|
scheduler_info: Dict = None
|
||||||
|
|
||||||
##### Native API endpoints #####
|
##### Native API endpoints #####
|
||||||
|
|
||||||
@@ -171,16 +170,6 @@ async def flush_cache():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def start_profile():
|
|
||||||
"""Start profiling."""
|
|
||||||
tokenizer_manager.start_profile()
|
|
||||||
|
|
||||||
|
|
||||||
def stop_profile():
|
|
||||||
"""Stop profiling."""
|
|
||||||
tokenizer_manager.stop_profile()
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/start_profile")
|
@app.get("/start_profile")
|
||||||
@app.post("/start_profile")
|
@app.post("/start_profile")
|
||||||
async def start_profile_async():
|
async def start_profile_async():
|
||||||
@@ -245,6 +234,8 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||||
|
@app.api_route("/generate", methods=["POST", "PUT"])
|
||||||
@time_func_latency
|
@time_func_latency
|
||||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||||
"""Handle a generate request."""
|
"""Handle a generate request."""
|
||||||
@@ -278,11 +269,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||||
app.post("/generate")(generate_request)
|
|
||||||
app.put("/generate")(generate_request)
|
|
||||||
|
|
||||||
|
|
||||||
@time_func_latency
|
@time_func_latency
|
||||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||||
"""Handle an embedding request."""
|
"""Handle an embedding request."""
|
||||||
@@ -295,10 +282,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
app.post("/encode")(encode_request)
|
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||||
app.put("/encode")(encode_request)
|
|
||||||
|
|
||||||
|
|
||||||
@time_func_latency
|
@time_func_latency
|
||||||
async def classify_request(obj: EmbeddingReqInput, request: Request):
|
async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||||
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
||||||
@@ -311,10 +295,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
app.post("/classify")(classify_request)
|
|
||||||
app.put("/classify")(classify_request)
|
|
||||||
|
|
||||||
|
|
||||||
##### OpenAI-compatible API endpoints #####
|
##### OpenAI-compatible API endpoints #####
|
||||||
|
|
||||||
|
|
||||||
@@ -392,11 +372,11 @@ def launch_engine(
|
|||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
|
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
global tokenizer_manager
|
global tokenizer_manager
|
||||||
global _max_total_num_tokens
|
global scheduler_info
|
||||||
|
|
||||||
# Configure global environment
|
# Configure global environment
|
||||||
configure_logger(server_args)
|
configure_logger(server_args)
|
||||||
@@ -462,8 +442,8 @@ def launch_engine(
|
|||||||
if server_args.chat_template:
|
if server_args.chat_template:
|
||||||
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
|
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
|
||||||
|
|
||||||
# Wait for model to finish loading & get max token nums
|
# Wait for model to finish loading
|
||||||
scheduler_info = []
|
scheduler_infos = []
|
||||||
for i in range(len(scheduler_pipe_readers)):
|
for i in range(len(scheduler_pipe_readers)):
|
||||||
data = scheduler_pipe_readers[i].recv()
|
data = scheduler_pipe_readers[i].recv()
|
||||||
|
|
||||||
@@ -471,10 +451,10 @@ def launch_engine(
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Initialization failed. Please see the error messages above."
|
"Initialization failed. Please see the error messages above."
|
||||||
)
|
)
|
||||||
scheduler_info.append(data)
|
scheduler_infos.append(data)
|
||||||
|
|
||||||
# Assume all schedulers have same max_total_num_tokens
|
# Assume all schedulers have same max_total_num_tokens
|
||||||
_max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
scheduler_info = scheduler_infos[0]
|
||||||
|
|
||||||
|
|
||||||
def launch_server(
|
def launch_server(
|
||||||
@@ -488,12 +468,12 @@ def launch_server(
|
|||||||
|
|
||||||
1. HTTP server: A FastAPI server that routes requests to the engine.
|
1. HTTP server: A FastAPI server that routes requests to the engine.
|
||||||
2. SRT engine:
|
2. SRT engine:
|
||||||
1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
|
1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
|
||||||
2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
|
2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
|
||||||
3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
|
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
1. The HTTP server and Tokenizer Manager both run in the main process.
|
1. The HTTP server and TokenizerManager both run in the main process.
|
||||||
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
|
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
|
||||||
"""
|
"""
|
||||||
launch_engine(server_args=server_args)
|
launch_engine(server_args=server_args)
|
||||||
@@ -502,7 +482,7 @@ def launch_server(
|
|||||||
if server_args.api_key:
|
if server_args.api_key:
|
||||||
add_api_key_middleware(app, server_args.api_key)
|
add_api_key_middleware(app, server_args.api_key)
|
||||||
|
|
||||||
# add prometheus middleware
|
# Add prometheus middleware
|
||||||
if server_args.enable_metrics:
|
if server_args.enable_metrics:
|
||||||
add_prometheus_middleware(app)
|
add_prometheus_middleware(app)
|
||||||
enable_func_timer()
|
enable_func_timer()
|
||||||
@@ -514,7 +494,7 @@ def launch_server(
|
|||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Listen for HTTP requests
|
# Update logging configs
|
||||||
LOGGING_CONFIG["formatters"]["default"][
|
LOGGING_CONFIG["formatters"]["default"][
|
||||||
"fmt"
|
"fmt"
|
||||||
] = "[%(asctime)s] %(levelprefix)s %(message)s"
|
] = "[%(asctime)s] %(levelprefix)s %(message)s"
|
||||||
@@ -523,6 +503,8 @@ def launch_server(
|
|||||||
"fmt"
|
"fmt"
|
||||||
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
|
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
|
||||||
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
||||||
|
|
||||||
|
# Listen for HTTP requests
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
host=server_args.host,
|
host=server_args.host,
|
||||||
@@ -538,8 +520,7 @@ def launch_server(
|
|||||||
async def _get_server_info():
|
async def _get_server_info():
|
||||||
return {
|
return {
|
||||||
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
||||||
"memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size
|
**scheduler_info,
|
||||||
"max_total_num_tokens": _max_total_num_tokens, # max total num tokens
|
|
||||||
"version": __version__,
|
"version": __version__,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -645,6 +626,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
|||||||
kill_process_tree(os.getpid())
|
kill_process_tree(os.getpid())
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Debug print
|
||||||
# logger.info(f"{res.json()=}")
|
# logger.info(f"{res.json()=}")
|
||||||
|
|
||||||
logger.info("The server is fired up and ready to roll!")
|
logger.info("The server is fired up and ready to roll!")
|
||||||
@@ -821,18 +803,11 @@ class Engine:
|
|||||||
launching the HTTP server adds unnecessary complexity or overhead,
|
launching the HTTP server adds unnecessary complexity or overhead,
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, log_level: str = "error", *args, **kwargs):
|
||||||
|
|
||||||
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
|
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
|
||||||
atexit.register(self.shutdown)
|
atexit.register(self.shutdown)
|
||||||
|
|
||||||
# runtime server default log level is log
|
server_args = ServerArgs(*args, log_level=log_level, **kwargs)
|
||||||
# offline engine works in scripts, so we set it to error
|
|
||||||
|
|
||||||
if "log_level" not in kwargs:
|
|
||||||
kwargs["log_level"] = "error"
|
|
||||||
|
|
||||||
server_args = ServerArgs(*args, **kwargs)
|
|
||||||
launch_engine(server_args=server_args)
|
launch_engine(server_args=server_args)
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
@@ -955,5 +930,11 @@ class Engine:
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(encode_request(obj, None))
|
return loop.run_until_complete(encode_request(obj, None))
|
||||||
|
|
||||||
|
def start_profile(self):
|
||||||
|
tokenizer_manager.start_profile()
|
||||||
|
|
||||||
|
def stop_profile(self):
|
||||||
|
tokenizer_manager.stop_profile()
|
||||||
|
|
||||||
async def get_server_info(self):
|
async def get_server_info(self):
|
||||||
return await _get_server_info()
|
return await _get_server_info()
|
||||||
|
|||||||
@@ -220,9 +220,6 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
max_total_num_tokens = response_json["max_total_num_tokens"]
|
max_total_num_tokens = response_json["max_total_num_tokens"]
|
||||||
self.assertIsInstance(max_total_num_tokens, int)
|
self.assertIsInstance(max_total_num_tokens, int)
|
||||||
|
|
||||||
memory_pool_size = response_json["memory_pool_size"]
|
|
||||||
self.assertIsInstance(memory_pool_size, int)
|
|
||||||
|
|
||||||
attention_backend = response_json["attention_backend"]
|
attention_backend = response_json["attention_backend"]
|
||||||
self.assertIsInstance(attention_backend, str)
|
self.assertIsInstance(attention_backend, str)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user