[Minor] Improve the function organization in TokenizerManager & improve loggers (#1208)
@@ -142,17 +142,6 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False

-    if (
-        "llama" in tokenizer_name.lower()
-        and kwargs.get("use_fast", True)
-        and tokenizer_name != _FAST_LLAMA_TOKENIZER
-    ):
-        warnings.warn(
-            "For some LLaMA V1 models, initializing the fast tokenizer may "
-            "take a long time. To reduce the initialization time, consider "
-            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-            "tokenizer."
-        )
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,

@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

@@ -193,10 +193,7 @@ def start_controller_process(
 ):
     """Start a controller process."""

-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)

     try:
         controller = ControllerMulti(server_args, port_args, model_overide_args)

@@ -27,7 +27,7 @@ from sglang.srt.managers.tp_worker import (
     launch_tp_servers,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

@@ -52,7 +52,7 @@ class ControllerSingle:
         self.dp_worker_id = dp_worker_id
         self.mp_queue = mp_queue

-        # Init communication
+        # Init inter-process communication
         context = zmq.Context(2)

         if not self.is_dp_worker:

@@ -133,11 +133,11 @@ def start_controller_process(
     queue: multiprocessing.connection.Connection = None,
 ):
     """Start a controller process."""
-
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    if is_data_parallel_worker:
+        logger_prefix = f" DP{dp_worker_id} TP0"
+    else:
+        logger_prefix = " TP0"
+    configure_logger(server_args, prefix=logger_prefix)

     if not is_data_parallel_worker:
         tp_size_local = server_args.tp_size // server_args.nnodes
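
Note on the two start_controller_process hunks above: every process now routes its logging setup through the shared configure_logger helper (defined in sglang/srt/utils.py at the end of this commit), and data-parallel workers pass a " DPk TP0" prefix so the process identity lives in the log format instead of in each message. Illustration only, assuming sglang at or after this commit and a minimal server_args stand-in with log_level="info":

    import logging
    from types import SimpleNamespace

    from sglang.srt.utils import configure_logger  # helper added by this commit

    server_args = SimpleNamespace(log_level="info")
    configure_logger(server_args, prefix=" DP0 TP0")
    logging.getLogger(__name__).info("Init nccl begin.")
    # prints something like: [14:02:07 DP0 TP0] Init nccl begin.
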
@@ -56,6 +56,7 @@ class DetokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_router = context.socket(zmq.PULL)
         self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")

@@ -75,10 +76,13 @@ class DetokenizerManager:
         self.decode_status = {}

     async def handle_loop(self):
+        """The event loop that handles requests"""
+
         while True:
-            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            recv_obj = await self.recv_from_router.recv_pyobj()

             if isinstance(recv_obj, BatchEmbeddingOut):
+                # If it is embedding model, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         rids=recv_obj.rids,

@@ -88,19 +92,18 @@ class DetokenizerManager:
                     )
                 )
                 continue

-            if isinstance(recv_obj, UpdateWeightReqOutput):
+            elif isinstance(recv_obj, UpdateWeightReqOutput):
                 # If it is a weight update request, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue
+            elif self.tokenizer is None:
+                # If the tokenizer is skipped, no detokenization is needed
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue

             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)

-            if self.tokenizer is None:
-                # Send BatchTokenIDOut if no tokenizer init'ed.
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue
-
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):

@@ -134,6 +137,7 @@ class DetokenizerManager:
             spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
         )

+        # Incremental decoding
         output_strs = []
        for i in range(bs):
            s = self.decode_status[recv_obj.rids[i]]
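
Note on the handle_loop hunks above: the independent if-checks become one if/elif chain, and the tokenizer-is-None short-circuit now runs before the token-id path, so the assert can pin down the only remaining case. A self-contained skeleton of the new dispatch order (message types stubbed with placeholder classes; the real loop is DetokenizerManager.handle_loop):

    class BatchEmbeddingOut: ...
    class UpdateWeightReqOutput: ...
    class BatchTokenIDOut: ...

    def dispatch(recv_obj, tokenizer, forward):
        if isinstance(recv_obj, BatchEmbeddingOut):
            forward(recv_obj)  # embedding output: nothing to detokenize
        elif isinstance(recv_obj, UpdateWeightReqOutput):
            forward(recv_obj)  # weight-update ack: pass through unchanged
        elif tokenizer is None:
            forward(recv_obj)  # tokenizer init skipped: forward raw token ids
        else:
            assert isinstance(recv_obj, BatchTokenIDOut)
            # ...incremental decoding path continues here...
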
@@ -21,7 +21,7 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import transformers

@@ -80,6 +80,7 @@ class TokenizerManager:
     ):
         self.server_args = server_args

+        # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

@@ -87,6 +88,7 @@ class TokenizerManager:
         self.send_to_router = context.socket(zmq.PUSH)
         self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

+        # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
         self.hf_config = get_config(

@@ -104,6 +106,7 @@ class TokenizerManager:
         else:
             self.context_len = get_context_length(self.hf_config)

+        # Create tokenizer
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:

@@ -127,6 +130,7 @@ class TokenizerManager:
                 trust_remote_code=server_args.trust_remote_code,
             )

+        # Store states
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

@@ -134,63 +138,6 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None

-    async def get_pixel_values(self, image_data, aspect_ratio=None):
-        aspect_ratio = (
-            getattr(self.hf_config, "image_aspect_ratio", None)
-            if aspect_ratio is None
-            else aspect_ratio
-        )
-        grid_pinpoints = (
-            self.hf_config.image_grid_pinpoints
-            if hasattr(self.hf_config, "image_grid_pinpoints")
-            and "anyres" in aspect_ratio
-            else None
-        )
-
-        if isinstance(image_data, list) and len(image_data) > 0:
-            pixel_values, image_hash, image_size = [], [], []
-            if len(image_data) > 1:
-                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-                for img_data in image_data:
-                    pixel_v, image_h, image_s = await self._process_single_image(
-                        img_data, aspect_ratio, grid_pinpoints
-                    )
-                    pixel_values.append(pixel_v)
-                    image_hash.append(image_h)
-                    image_size.append(image_s)
-                pixel_values = np.stack(pixel_values, axis=0)
-            else:
-                pixel_values, image_hash, image_size = await self._process_single_image(
-                    image_data[0], aspect_ratio, grid_pinpoints
-                )
-                image_hash = [image_hash]
-                image_size = [image_size]
-        elif isinstance(image_data, str):
-            pixel_values, image_hash, image_size = await self._process_single_image(
-                image_data, aspect_ratio, grid_pinpoints
-            )
-            image_hash = [image_hash]
-            image_size = [image_size]
-        else:
-            pixel_values, image_hash, image_size = None, None, None
-
-        return pixel_values, image_hash, image_size
-
-    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            return await loop.run_in_executor(
-                self.executor,
-                get_pixel_values,
-                image_data,
-                aspect_ratio,
-                grid_pinpoints,
-            )
-        else:
-            return get_pixel_values(
-                image_data, aspect_ratio, grid_pinpoints, self.processor
-            )
-
     async def generate_request(
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
     ):
@@ -198,7 +145,7 @@ class TokenizerManager:
         self.create_handle_loop()

         while self.model_update_lock.locked():
-            await asyncio.sleep(0)
+            await asyncio.sleep(0.001)

         obj.post_init()
         is_single = obj.is_single
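
Note on the sleep change above: asyncio.sleep(0) yields to the event loop but reschedules the coroutine immediately, so the wait loop busy-spins while update_weights holds the lock; sleep(0.001) caps polling at roughly once per millisecond. A minimal sketch of the pattern:

    import asyncio

    async def wait_until_released(lock: asyncio.Lock):
        while lock.locked():
            await asyncio.sleep(0.001)  # back off ~1 ms instead of spinning
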
@@ -214,8 +161,8 @@ class TokenizerManager:
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request,
-        index=None,
-        is_cache_for_prefill=False,
+        index: Optional[int] = None,
+        is_cache_for_prefill: Optional[bool] = False,
     ):
         if not is_cache_for_prefill:  # The normal case with a single prompt
             not_use_index = index is None

@@ -235,7 +182,7 @@ class TokenizerManager:
             )

             if self.is_generation:
-                pixel_values, image_hash, image_size = await self.get_pixel_values(
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
                     obj.image_data
                 )
                 return_logprob = (

@@ -345,7 +292,7 @@ class TokenizerManager:
             parallel_sample_num = obj.parallel_sample_num

             if parallel_sample_num != 1:
-                # Send prefill requests to cache the common input
+                # Send prefill requests to cache the common prefix
                 parallel_sample_num += 1
                 input_id_result = [] if obj.input_ids is None else None
                 for i in range(batch_size):

@@ -436,7 +383,6 @@ class TokenizerManager:
         )

         # Then process the responses based on streaming option
-
         is_stream = hasattr(obj, "stream") and obj.stream

         tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]

@@ -482,9 +428,9 @@ class TokenizerManager:

     async def _get_pixel_values(self, image_data):
         if isinstance(image_data, list) and len(image_data) > 0:
-            return await self.get_pixel_values(image_data[0])
+            return await self._get_pixel_values_internal(image_data[0])
         elif isinstance(image_data, str):
-            return await self.get_pixel_values(image_data)
+            return await self._get_pixel_values_internal(image_data)
         else:
             return None, None, None

@@ -563,6 +509,13 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)

+    def abort_request(self, rid: str):
+        if rid not in self.rid_to_state:
+            return
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
     async def update_weights(self, obj: UpdateWeightReqInput, request):
         if self.to_create_loop:
             self.create_handle_loop()

@@ -587,13 +540,6 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."

-    def abort_request(self, rid: str):
-        if rid not in self.rid_to_state:
-            return
-        del self.rid_to_state[rid]
-        req = AbortReq(rid)
-        self.send_to_router.send_pyobj(req)
-
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():

@@ -617,6 +563,8 @@ class TokenizerManager:
             loop.create_task(self.handle_loop())

     async def handle_loop(self):
+        """The event loop that handles requests"""
+
         while True:
             recv_obj: Union[
                 BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
@@ -713,11 +661,69 @@ class TokenizerManager:
             )
         return top_logprobs

+    async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
+        aspect_ratio = (
+            getattr(self.hf_config, "image_aspect_ratio", None)
+            if aspect_ratio is None
+            else aspect_ratio
+        )
+        grid_pinpoints = (
+            self.hf_config.image_grid_pinpoints
+            if hasattr(self.hf_config, "image_grid_pinpoints")
+            and "anyres" in aspect_ratio
+            else None
+        )
+
+        if isinstance(image_data, list) and len(image_data) > 0:
+            pixel_values, image_hash, image_size = [], [], []
+            if len(image_data) > 1:
+                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                for img_data in image_data:
+                    pixel_v, image_h, image_s = await self._process_single_image(
+                        img_data, aspect_ratio, grid_pinpoints
+                    )
+                    pixel_values.append(pixel_v)
+                    image_hash.append(image_h)
+                    image_size.append(image_s)
+                pixel_values = np.stack(pixel_values, axis=0)
+            else:
+                pixel_values, image_hash, image_size = await self._process_single_image(
+                    image_data[0], aspect_ratio, grid_pinpoints
+                )
+                image_hash = [image_hash]
+                image_size = [image_size]
+        elif isinstance(image_data, str):
+            pixel_values, image_hash, image_size = await self._process_single_image(
+                image_data, aspect_ratio, grid_pinpoints
+            )
+            image_hash = [image_hash]
+            image_size = [image_size]
+        else:
+            pixel_values, image_hash, image_size = None, None, None
+
+        return pixel_values, image_hash, image_size
+
+    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                _process_single_image_task,
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+            )
+        else:
+            return _process_single_image_task(
+                image_data, aspect_ratio, grid_pinpoints, self.processor
+            )
+

 global global_processor


 def init_global_processor(server_args: ServerArgs):
     """Init the global processor for multi modal models."""
     global global_processor
     transformers.logging.set_verbosity_error()
     global_processor = get_processor(

@@ -727,7 +733,7 @@ def init_global_processor(server_args: ServerArgs):
     )


-def get_pixel_values(
+def _process_single_image_task(
     image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
 ):
     try:

@@ -759,4 +765,4 @@ def get_pixel_values(
         pixel_values = pixel_values.astype(np.float16)
         return pixel_values, image_hash, image.size
     except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
+        logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
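
Note on the relocated image helpers above: the per-image work stays in a module-level function (now named _process_single_image_task) fed by a global_processor that init_global_processor builds once per worker. That shape is what run_in_executor needs when the executor is a process pool: the callable and its arguments must be picklable, and the heavyweight HF processor is rebuilt inside each worker rather than shipped with every call. A generic, self-contained sketch of that pattern (names here are illustrative, not from the diff):

    import asyncio
    import concurrent.futures

    _worker_state = None  # plays the role of global_processor

    def _init_worker():
        global _worker_state
        _worker_state = "expensive-to-build processor"

    def _task(item):
        # runs inside the worker process, reusing the per-worker state
        return f"{_worker_state} handled {item}"

    async def main():
        loop = asyncio.get_event_loop()
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=2, initializer=_init_worker
        ) as pool:
            print(await loop.run_in_executor(pool, _task, "img0"))

    if __name__ == "__main__":
        asyncio.run(main())
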
@@ -56,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    configure_logger,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,

@@ -145,7 +146,6 @@ class ModelTpServer:

         # Print info
         logger.info(
-            f"[gpu={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "

@@ -284,7 +284,7 @@ class ModelTpServer:
             self.num_generated_tokens = 0
             self.last_stats_tic = time.time()
             logger.info(
-                f"[gpu={self.gpu_id}] Decode batch. "
+                f"Decode batch. "
                 f"#running-req: {len(self.running_batch.reqs)}, "
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "

@@ -443,7 +443,7 @@ class ModelTpServer:

         if num_mixed_running > 0:
             logger.info(
-                f"[gpu={self.gpu_id}] Prefill batch"
+                f"Prefill batch"
                 f"(mixed #running-req: {num_mixed_running}). "
                 f"#new-seq: {len(can_run_list)}, "
                 f"#new-token: {adder.log_input_tokens}, "

@@ -453,7 +453,7 @@ class ModelTpServer:
             )
         else:
             logger.info(
-                f"[gpu={self.gpu_id}] Prefill batch. "
+                f"Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "

@@ -631,7 +631,7 @@ class ModelTpServer:
             self.new_token_ratio = new_token_ratio

             logger.info(
-                "decode out of memory happened, "
+                "Decode out of memory happened. "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )

@@ -848,7 +848,9 @@ def run_tp_server(
     nccl_port: int,
     model_overide_args: dict,
 ):
-    """Run a tensor parallel server."""
+    """Run a tensor parallel model server."""
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+
     try:
         model_server = ModelTpServer(
             gpu_id,

@@ -109,7 +109,7 @@ class ModelRunner:
     def init_torch_distributed(self):
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
-        logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
+        logger.info("Init nccl begin.")

         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)

@@ -152,8 +152,7 @@ class ModelRunner:

     def load_model(self):
         logger.info(
-            f"[gpu={self.gpu_id}] Load weight begin. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(

@@ -208,7 +207,7 @@ class ModelRunner:
         )

         logger.info(
-            f"[gpu={self.gpu_id}] Load weight end. "
+            f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"

@@ -224,7 +223,7 @@ class ModelRunner:
         from vllm.model_executor.model_loader.utils import set_default_torch_dtype

         logger.info(
-            f"[gpu={self.gpu_id}] Update weights begin. "
+            f"Update weights begin. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

@@ -298,7 +297,7 @@ class ModelRunner:
         self.load_config = load_config
         self.model_config.path = model_path

-        logger.info(f"[gpu={self.gpu_id}] Update weights end.")
+        logger.info("Update weights end.")
         return True, "Succeeded to update model weights"

     def profile_max_num_token(self, total_gpu_memory: int):

@@ -387,7 +386,7 @@ class ModelRunner:
             layer_num=self.model_config.num_hidden_layers,
         )
         logger.info(
-            f"[gpu={self.gpu_id}] Memory pool end. "
+            f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

@@ -473,9 +472,7 @@ class ModelRunner:
             self.cuda_graph_runner = None
             return

-        logger.info(
-            f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
-        )
+        logger.info("Capture cuda graph begin. This can take up to several minutes.")

         if self.server_args.disable_cuda_graph_padding:
             batch_size_list = list(range(1, 32)) + [64, 128]

@@ -123,7 +123,7 @@ def create_streaming_error_response(
 def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name

-    print(f"Use chat template: {chat_template_arg}")
+    logger.info(f"Use chat template: {chat_template_arg}")
     if not chat_template_exists(chat_template_arg):
         if not os.path.exists(chat_template_arg):
             raise RuntimeError(
@@ -355,7 +355,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         }

     except Exception as e:
-        print("error in SGLang:", e)
+        logger.error("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
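
Reviewer note on the hunk above: the stdlib logging methods use lazy %-style formatting, not print-style argument joining, so logger.error("error in SGLang:", e) will fail to render at emit time (the handler reports a formatting error and the message is lost). The idiomatic forms would be:

    import logging

    logger = logging.getLogger(__name__)
    try:
        raise RuntimeError("boom")
    except Exception as e:
        logger.error("error in SGLang: %s", e)  # lazy %-formatting
        logger.exception("error in SGLang")     # also records the traceback
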
@@ -74,6 +74,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
+    configure_logger,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,

@@ -270,15 +271,12 @@ def launch_server(
     """Launch an HTTP server."""
     global tokenizer_manager

-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)

     server_args.check_server_args()
     _set_envs_and_config(server_args)

-    # Allocate ports
+    # Allocate ports for inter-process communications
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,

@@ -418,7 +418,7 @@ class ServerArgs:
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
-            help="Enabling mixing prefill and decode in a chunked batch.",
+            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
         )
         parser.add_argument(
             "--enable-torch-compile",

@@ -692,7 +692,7 @@ def monkey_patch_vllm_qvk_linear_loader():
     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)


-def add_api_key_middleware(app, api_key):
+def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
         if request.method == "OPTIONS":

@@ -704,7 +704,7 @@ def add_api_key_middleware(app, api_key):
         return await call_next(request)


-def prepare_model(model_path):
+def prepare_model(model_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
             from modelscope import snapshot_download

@@ -713,7 +713,7 @@ def prepare_model(model_path):
     return model_path


-def prepare_tokenizer(tokenizer_path):
+def prepare_tokenizer(tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(tokenizer_path):
             from modelscope import snapshot_download
@@ -722,3 +722,13 @@ def prepare_tokenizer(tokenizer_path):
             tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
         )
     return tokenizer_path
+
+
+def configure_logger(server_args, prefix: str = ""):
+    format = f"[%(asctime)s{prefix}] %(message)s"
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format=format,
+        datefmt="%H:%M:%S",
+        force=True,
+    )
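
Note on configure_logger: force=True (available since Python 3.8) removes any handlers an earlier basicConfig attached to the root logger, which is what lets a worker process override the format its parent configured before forking. Usage sketch, assuming the function above is in scope and a minimal server_args stand-in:

    import logging
    from types import SimpleNamespace

    args = SimpleNamespace(log_level="info")
    logging.basicConfig(format="%(message)s", level=logging.INFO)  # parent setup
    configure_logger(args, prefix=" TP3")  # worker re-config wins via force=True
    logging.getLogger(__name__).info("Load weight begin.")
    # prints something like: [15:40:12 TP3] Load weight begin.
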