Minor improvements of TokenizerManager / health check (#6327)
This commit is contained in:
6
.github/workflows/pr-test-amd.yml
vendored
6
.github/workflows/pr-test-amd.yml
vendored
@@ -4,14 +4,16 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
paths:
|
paths:
|
||||||
- "python/sglang/**"
|
- "python/**"
|
||||||
|
- "scripts/**"
|
||||||
- "test/**"
|
- "test/**"
|
||||||
- "sgl-kernel/**"
|
- "sgl-kernel/**"
|
||||||
- ".github/workflows/pr-test-amd.yml"
|
- ".github/workflows/pr-test-amd.yml"
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
paths:
|
paths:
|
||||||
- "python/sglang/**"
|
- "python/**"
|
||||||
|
- "scripts/**"
|
||||||
- "test/**"
|
- "test/**"
|
||||||
- "sgl-kernel/**"
|
- "sgl-kernel/**"
|
||||||
- ".github/workflows/pr-test-amd.yml"
|
- ".github/workflows/pr-test-amd.yml"
|
||||||
|
|||||||
@@ -96,12 +96,14 @@ anthropic = ["anthropic>=0.20.0"]
|
|||||||
litellm = ["litellm>=1.0.0"]
|
litellm = ["litellm>=1.0.0"]
|
||||||
torch_memory_saver = ["torch_memory_saver>=0.0.4"]
|
torch_memory_saver = ["torch_memory_saver>=0.0.4"]
|
||||||
test = [
|
test = [
|
||||||
|
"accelerate",
|
||||||
|
"torchaudio",
|
||||||
"jsonlines",
|
"jsonlines",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"pandas",
|
"pandas",
|
||||||
"sentence_transformers",
|
|
||||||
"accelerate",
|
|
||||||
"peft",
|
"peft",
|
||||||
|
"timm",
|
||||||
|
"sentence_transformers",
|
||||||
]
|
]
|
||||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]", "sglang[torch_memory_saver]"]
|
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]", "sglang[torch_memory_saver]"]
|
||||||
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import torch.distributed as dist
|
|||||||
|
|
||||||
from sglang.srt.utils import get_ip
|
from sglang.srt.utils import get_ip
|
||||||
|
|
||||||
|
FakeBootstrapHost = "2.2.2.2"
|
||||||
|
|
||||||
|
|
||||||
class DisaggregationMode(Enum):
|
class DisaggregationMode(Enum):
|
||||||
NULL = "null"
|
NULL = "null"
|
||||||
@@ -20,9 +22,6 @@ class DisaggregationMode(Enum):
|
|||||||
DECODE = "decode"
|
DECODE = "decode"
|
||||||
|
|
||||||
|
|
||||||
FakeBootstrapHost = "2.2.2.2"
|
|
||||||
|
|
||||||
|
|
||||||
def poll_and_all_reduce(pollers, gloo_group):
|
def poll_and_all_reduce(pollers, gloo_group):
|
||||||
polls = [int(poller.poll()) for poller in pollers]
|
polls = [int(poller.poll()) for poller in pollers]
|
||||||
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
|
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
|
||||||
|
|||||||
@@ -189,6 +189,7 @@ async def health_generate(request: Request) -> Response:
|
|||||||
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||||
|
_global_state.tokenizer_manager.health_check_failed = False
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
task.cancel()
|
task.cancel()
|
||||||
@@ -202,6 +203,7 @@ async def health_generate(request: Request) -> Response:
|
|||||||
f"last_heartbeat time: {last_receive_time}"
|
f"last_heartbeat time: {last_receive_time}"
|
||||||
)
|
)
|
||||||
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||||
|
_global_state.tokenizer_manager.health_check_failed = True
|
||||||
return Response(status_code=503)
|
return Response(status_code=503)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import warnings
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Type, Union
|
from typing import Dict, Optional, Type, Union
|
||||||
|
|
||||||
import transformers
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
|
|||||||
@@ -129,7 +129,6 @@ from sglang.srt.utils import (
|
|||||||
DynamicGradMode,
|
DynamicGradMode,
|
||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
crash_on_warnings,
|
|
||||||
disable_request_logging,
|
disable_request_logging,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@@ -90,6 +91,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqOutput,
|
ResumeMemoryOccupationReqOutput,
|
||||||
SessionParams,
|
SessionParams,
|
||||||
|
SetInternalStateReq,
|
||||||
|
SetInternalStateReqOutput,
|
||||||
SlowDownReqInput,
|
SlowDownReqInput,
|
||||||
SlowDownReqOutput,
|
SlowDownReqOutput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
@@ -169,6 +172,11 @@ class TokenizerManager:
|
|||||||
self.enable_metrics = server_args.enable_metrics
|
self.enable_metrics = server_args.enable_metrics
|
||||||
self.log_requests = server_args.log_requests
|
self.log_requests = server_args.log_requests
|
||||||
self.log_requests_level = server_args.log_requests_level
|
self.log_requests_level = server_args.log_requests_level
|
||||||
|
self.preferred_sampling_params = (
|
||||||
|
json.loads(server_args.preferred_sampling_params)
|
||||||
|
if server_args.preferred_sampling_params
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.asyncio.Context(2)
|
context = zmq.asyncio.Context(2)
|
||||||
@@ -228,6 +236,7 @@ class TokenizerManager:
|
|||||||
# Store states
|
# Store states
|
||||||
self.no_create_loop = False
|
self.no_create_loop = False
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
|
self.health_check_failed = False
|
||||||
self.gracefully_exit = False
|
self.gracefully_exit = False
|
||||||
self.last_receive_tstamp = 0
|
self.last_receive_tstamp = 0
|
||||||
self.dump_requests_folder = "" # By default do not dump
|
self.dump_requests_folder = "" # By default do not dump
|
||||||
@@ -255,6 +264,10 @@ class TokenizerManager:
|
|||||||
"model_name": self.server_args.served_model_name,
|
"model_name": self.server_args.served_model_name,
|
||||||
# TODO: Add lora name/path in the future,
|
# TODO: Add lora name/path in the future,
|
||||||
},
|
},
|
||||||
|
bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
|
||||||
|
bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
|
||||||
|
bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
|
||||||
|
collect_tokens_histogram=self.server_args.collect_tokens_histogram,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Communicators
|
# Communicators
|
||||||
@@ -285,9 +298,13 @@ class TokenizerManager:
|
|||||||
self.start_profile_communicator = _Communicator(
|
self.start_profile_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
|
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
|
||||||
self.get_internal_state_communicator = _Communicator(
|
self.get_internal_state_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
|
self.set_internal_state_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
self.expert_distribution_communicator = _Communicator(
|
self.expert_distribution_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
@@ -349,6 +366,10 @@ class TokenizerManager:
|
|||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
self.get_internal_state_communicator.handle_recv,
|
self.get_internal_state_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
SetInternalStateReqOutput,
|
||||||
|
self.set_internal_state_communicator.handle_recv,
|
||||||
|
),
|
||||||
(
|
(
|
||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
self.expert_distribution_communicator.handle_recv,
|
self.expert_distribution_communicator.handle_recv,
|
||||||
@@ -508,7 +529,14 @@ class TokenizerManager:
|
|||||||
"Please set `--enable-custom-logits-processor` to enable this feature."
|
"Please set `--enable-custom-logits-processor` to enable this feature."
|
||||||
)
|
)
|
||||||
|
|
||||||
sampling_params = SamplingParams(**obj.sampling_params)
|
# Parse sampling parameters
|
||||||
|
# Note: if there are preferred sampling params, we use them if they are not
|
||||||
|
# explicitly passed in sampling_params
|
||||||
|
if self.preferred_sampling_params:
|
||||||
|
sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
|
||||||
|
else:
|
||||||
|
sampling_kwargs = obj.sampling_params
|
||||||
|
sampling_params = SamplingParams(**sampling_kwargs)
|
||||||
sampling_params.normalize(self.tokenizer)
|
sampling_params.normalize(self.tokenizer)
|
||||||
sampling_params.verify()
|
sampling_params.verify()
|
||||||
|
|
||||||
@@ -667,7 +695,6 @@ class TokenizerManager:
|
|||||||
|
|
||||||
generators = []
|
generators = []
|
||||||
rids = []
|
rids = []
|
||||||
|
|
||||||
if getattr(obj, "parallel_sample_num", 1) == 1:
|
if getattr(obj, "parallel_sample_num", 1) == 1:
|
||||||
if self.server_args.enable_tokenizer_batch_encode:
|
if self.server_args.enable_tokenizer_batch_encode:
|
||||||
# Validate batch tokenization constraints
|
# Validate batch tokenization constraints
|
||||||
@@ -857,7 +884,7 @@ class TokenizerManager:
|
|||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
assert (
|
assert (
|
||||||
self.server_args.dp_size == 1
|
self.server_args.dp_size == 1
|
||||||
), "dp_size must be for update weights from distributed"
|
), "dp_size must be 1 for update weights from distributed"
|
||||||
|
|
||||||
# This means that weight sync
|
# This means that weight sync
|
||||||
# cannot run while requests are in progress.
|
# cannot run while requests are in progress.
|
||||||
@@ -946,6 +973,14 @@ class TokenizerManager:
|
|||||||
# Many DP ranks
|
# Many DP ranks
|
||||||
return [res.internal_state for res in responses]
|
return [res.internal_state for res in responses]
|
||||||
|
|
||||||
|
async def set_internal_state(
|
||||||
|
self, obj: SetInternalStateReq
|
||||||
|
) -> SetInternalStateReqOutput:
|
||||||
|
responses: List[SetInternalStateReqOutput] = (
|
||||||
|
await self.set_internal_state_communicator(obj)
|
||||||
|
)
|
||||||
|
return [res.internal_state for res in responses]
|
||||||
|
|
||||||
def get_log_request_metadata(self):
|
def get_log_request_metadata(self):
|
||||||
max_length = None
|
max_length = None
|
||||||
skip_names = None
|
skip_names = None
|
||||||
@@ -1015,11 +1050,17 @@ class TokenizerManager:
|
|||||||
loop.create_task(print_exception_wrapper(self.handle_loop))
|
loop.create_task(print_exception_wrapper(self.handle_loop))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.event_loop = loop
|
||||||
|
|
||||||
# We cannot add signal handler when the tokenizer manager is not in
|
# We cannot add signal handler when the tokenizer manager is not in
|
||||||
# the main thread due to the CPython limitation.
|
# the main thread due to the CPython limitation.
|
||||||
if threading.current_thread() is threading.main_thread():
|
if threading.current_thread() is threading.main_thread():
|
||||||
signal_handler = SignalHandler(self)
|
signal_handler = SignalHandler(self)
|
||||||
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
|
loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
|
||||||
|
# Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
|
||||||
|
loop.add_signal_handler(
|
||||||
|
signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Signal handler is not added because the tokenizer manager is "
|
"Signal handler is not added because the tokenizer manager is "
|
||||||
@@ -1037,6 +1078,15 @@ class TokenizerManager:
|
|||||||
# Drain requests
|
# Drain requests
|
||||||
while True:
|
while True:
|
||||||
remain_num_req = len(self.rid_to_state)
|
remain_num_req = len(self.rid_to_state)
|
||||||
|
|
||||||
|
if self.health_check_failed:
|
||||||
|
# if health check failed, we should exit immediately
|
||||||
|
logger.error(
|
||||||
|
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
||||||
|
remain_num_req,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Gracefully exiting... remaining number of requests {remain_num_req}"
|
f"Gracefully exiting... remaining number of requests {remain_num_req}"
|
||||||
)
|
)
|
||||||
@@ -1120,7 +1170,16 @@ class TokenizerManager:
|
|||||||
"meta_info": meta_info,
|
"meta_info": meta_info,
|
||||||
}
|
}
|
||||||
elif isinstance(recv_obj, BatchMultimodalOut):
|
elif isinstance(recv_obj, BatchMultimodalOut):
|
||||||
raise NotImplementedError()
|
if isinstance(recv_obj.outputs[i], str):
|
||||||
|
out_dict = {
|
||||||
|
"text": recv_obj.outputs[i],
|
||||||
|
"meta_info": meta_info,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
out_dict = {
|
||||||
|
"outputs": json.dumps(recv_obj.outputs[i]),
|
||||||
|
"meta_info": meta_info,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||||
out_dict = {
|
out_dict = {
|
||||||
@@ -1366,12 +1425,18 @@ class SignalHandler:
|
|||||||
def __init__(self, tokenizer_manager: TokenizerManager):
|
def __init__(self, tokenizer_manager: TokenizerManager):
|
||||||
self.tokenizer_manager = tokenizer_manager
|
self.tokenizer_manager = tokenizer_manager
|
||||||
|
|
||||||
def signal_handler(self, signum=None, frame=None):
|
def sigterm_handler(self, signum=None, frame=None):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
||||||
)
|
)
|
||||||
self.tokenizer_manager.gracefully_exit = True
|
self.tokenizer_manager.gracefully_exit = True
|
||||||
|
|
||||||
|
def running_phase_sigquit_handler(self, signum=None, frame=None):
|
||||||
|
logger.error(
|
||||||
|
"Received sigquit from a child process. It usually means the child failed."
|
||||||
|
)
|
||||||
|
kill_process_tree(os.getpid())
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ class ServerArgs:
|
|||||||
tokenizer_path: Optional[str] = None
|
tokenizer_path: Optional[str] = None
|
||||||
tokenizer_mode: str = "auto"
|
tokenizer_mode: str = "auto"
|
||||||
skip_tokenizer_init: bool = False
|
skip_tokenizer_init: bool = False
|
||||||
enable_tokenizer_batch_encode: bool = False
|
|
||||||
load_format: str = "auto"
|
load_format: str = "auto"
|
||||||
trust_remote_code: bool = False
|
trust_remote_code: bool = False
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
@@ -59,6 +58,7 @@ class ServerArgs:
|
|||||||
chat_template: Optional[str] = None
|
chat_template: Optional[str] = None
|
||||||
completion_template: Optional[str] = None
|
completion_template: Optional[str] = None
|
||||||
is_embedding: bool = False
|
is_embedding: bool = False
|
||||||
|
enable_multimodal: Optional[bool] = None
|
||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
|
|
||||||
# Port for the HTTP server
|
# Port for the HTTP server
|
||||||
@@ -97,6 +97,10 @@ class ServerArgs:
|
|||||||
log_requests_level: int = 0
|
log_requests_level: int = 0
|
||||||
show_time_cost: bool = False
|
show_time_cost: bool = False
|
||||||
enable_metrics: bool = False
|
enable_metrics: bool = False
|
||||||
|
bucket_time_to_first_token: Optional[List[float]] = None
|
||||||
|
bucket_e2e_request_latency: Optional[List[float]] = None
|
||||||
|
bucket_inter_token_latency: Optional[List[float]] = None
|
||||||
|
collect_tokens_histogram: bool = False
|
||||||
decode_log_interval: int = 40
|
decode_log_interval: int = 40
|
||||||
enable_request_time_stats_logging: bool = False
|
enable_request_time_stats_logging: bool = False
|
||||||
|
|
||||||
@@ -120,6 +124,7 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Model override args in JSON
|
# Model override args in JSON
|
||||||
json_model_override_args: str = "{}"
|
json_model_override_args: str = "{}"
|
||||||
|
preferred_sampling_params: Optional[str] = None
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
lora_paths: Optional[List[str]] = None
|
lora_paths: Optional[List[str]] = None
|
||||||
@@ -154,9 +159,9 @@ class ServerArgs:
|
|||||||
disable_cuda_graph: bool = False
|
disable_cuda_graph: bool = False
|
||||||
disable_cuda_graph_padding: bool = False
|
disable_cuda_graph_padding: bool = False
|
||||||
enable_nccl_nvls: bool = False
|
enable_nccl_nvls: bool = False
|
||||||
|
enable_tokenizer_batch_encode: bool = False
|
||||||
disable_outlines_disk_cache: bool = False
|
disable_outlines_disk_cache: bool = False
|
||||||
disable_custom_all_reduce: bool = False
|
disable_custom_all_reduce: bool = False
|
||||||
enable_multimodal: Optional[bool] = None
|
|
||||||
disable_overlap_schedule: bool = False
|
disable_overlap_schedule: bool = False
|
||||||
enable_mixed_chunk: bool = False
|
enable_mixed_chunk: bool = False
|
||||||
enable_dp_attention: bool = False
|
enable_dp_attention: bool = False
|
||||||
@@ -474,11 +479,6 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="If set, skip init tokenizer and pass input_ids in generate request.",
|
help="If set, skip init tokenizer and pass input_ids in generate request.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--enable-tokenizer-batch-encode",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load-format",
|
"--load-format",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -603,6 +603,12 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether to use a CausalLM as an embedding model.",
|
help="Whether to use a CausalLM as an embedding model.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-multimodal",
|
||||||
|
default=ServerArgs.enable_multimodal,
|
||||||
|
action="store_true",
|
||||||
|
help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--revision",
|
"--revision",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -780,6 +786,33 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable log prometheus metrics.",
|
help="Enable log prometheus metrics.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--bucket-time-to-first-token",
|
||||||
|
type=float,
|
||||||
|
nargs="+",
|
||||||
|
default=ServerArgs.bucket_time_to_first_token,
|
||||||
|
help="The buckets of time to first token, specified as a list of floats.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--bucket-inter-token-latency",
|
||||||
|
type=float,
|
||||||
|
nargs="+",
|
||||||
|
default=ServerArgs.bucket_inter_token_latency,
|
||||||
|
help="The buckets of inter-token latency, specified as a list of floats.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--bucket-e2e-request-latency",
|
||||||
|
type=float,
|
||||||
|
nargs="+",
|
||||||
|
default=ServerArgs.bucket_e2e_request_latency,
|
||||||
|
help="The buckets of end-to-end request latency, specified as a list of floats.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--collect-tokens-histogram",
|
||||||
|
action="store_true",
|
||||||
|
default=ServerArgs.collect_tokens_histogram,
|
||||||
|
help="Collect prompt/generation tokens histogram.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decode-log-interval",
|
"--decode-log-interval",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -868,6 +901,11 @@ class ServerArgs:
|
|||||||
help="A dictionary in JSON string format used to override default model configurations.",
|
help="A dictionary in JSON string format used to override default model configurations.",
|
||||||
default=ServerArgs.json_model_override_args,
|
default=ServerArgs.json_model_override_args,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--preferred-sampling-params",
|
||||||
|
type=str,
|
||||||
|
help="json-formatted sampling settings that will be returned in /get_model_info",
|
||||||
|
)
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -1043,6 +1081,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable NCCL NVLS for prefill heavy requests when available.",
|
help="Enable NCCL NVLS for prefill heavy requests when available.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-tokenizer-batch-encode",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-outlines-disk-cache",
|
"--disable-outlines-disk-cache",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -1053,12 +1096,6 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable the custom all-reduce kernel and fall back to NCCL.",
|
help="Disable the custom all-reduce kernel and fall back to NCCL.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--enable-multimodal",
|
|
||||||
default=ServerArgs.enable_multimodal,
|
|
||||||
action="store_true",
|
|
||||||
help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-overlap-schedule",
|
"--disable-overlap-schedule",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# Install the dependency in CI.
|
# Install the dependency in CI.
|
||||||
set -euxo pipefail
|
set -euxo pipefail
|
||||||
|
|
||||||
|
# Kill existing processes
|
||||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||||
bash "${SCRIPT_DIR}/killall_sglang.sh"
|
bash "${SCRIPT_DIR}/killall_sglang.sh"
|
||||||
|
|
||||||
@@ -16,13 +17,10 @@ rm -rf /usr/local/lib/python3.10/dist-packages/flashinfer*
|
|||||||
rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*
|
rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*
|
||||||
|
|
||||||
# Install the main package
|
# Install the main package
|
||||||
pip install -e "python[all]"
|
pip install -e "python[dev]"
|
||||||
|
|
||||||
# Install additional dependencies
|
# Install additional dependencies
|
||||||
pip install transformers==4.51.0 timm torchaudio==2.6.0 sentence_transformers accelerate peft pandas datasets mooncake-transfer-engine==0.3.0
|
pip install mooncake-transfer-engine==0.3.0 nvidia-cuda-nvrtc-cu12
|
||||||
|
|
||||||
# For compiling xgrammar kernels
|
|
||||||
pip install cuda-python nvidia-cuda-nvrtc-cu12
|
|
||||||
|
|
||||||
# For lmms_evals evaluating MMMU
|
# For lmms_evals evaluating MMMU
|
||||||
git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
||||||
|
|||||||
Reference in New Issue
Block a user