[minor] Improve code style and compatibility (#1961)
@@ -21,6 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
     "torchao", "uvicorn", "uvloop", "zmq",
     "outlines>=0.0.44", "modelscope"]
 srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
+# => base docker rocm/vllm-dev:20241022, not from public vllm whl
 srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
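The CUDA build keeps the public `vllm==0.6.3.post1` wheel, while the ROCm extra pins a `dev` build that only ships inside the `rocm/vllm-dev:20241022` base image. A quick sanity check for which pinned build actually landed in an environment, as a minimal sketch (standard library only):

```python
# Sketch: verify the installed vllm pin matches the extra you installed.
from importlib.metadata import PackageNotFoundError, version

try:
    print("vllm:", version("vllm"))  # e.g. "0.6.3.post1" for sglang[srt]
except PackageNotFoundError:
    print("vllm not installed; `pip install 'sglang[srt]'` pulls it in")
```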
@@ -461,7 +461,7 @@ class TokenizerManager:
                 break

         kill_child_process(include_self=True)
-        sys.exit(-1)
+        sys.exit(0)

     async def handle_loop(self):
         """The event loop that handles requests"""
@@ -32,7 +32,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessorOutput,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import monkey_patch_vllm_all_gather
+from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -92,7 +92,7 @@ def set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 1024


-@torch.compile(dynamic=True)
+@maybe_torch_compile(dynamic=True)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
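`clamp_position` maps sequence lengths to the index of the last valid token, clamping at zero so that empty sequences do not produce `-1`. Swapping `@torch.compile` for `@maybe_torch_compile` (defined in `utils.py` later in this diff) keeps that behavior but only compiles when the Triton version supports it. A small illustration of the function's arithmetic:

```python
# Illustration: last-token positions from sequence lengths, floored at 0.
import torch

seq_lens = torch.tensor([5, 1, 0])
positions = torch.clamp(seq_lens - 1, min=0).to(torch.int64)
print(positions)  # tensor([4, 0, 0]) -- the empty sequence maps to 0, not -1
```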
@@ -79,6 +79,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     assert_pkg_version,
     configure_logger,
+    delete_directory,
     is_port_available,
     kill_child_process,
     maybe_set_triton_cache_manager,
@@ -97,8 +98,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


 app = FastAPI()
-tokenizer_manager: TokenizerManager = None
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -107,6 +106,10 @@ app.add_middleware(
     allow_headers=["*"],
 )

+tokenizer_manager: TokenizerManager = None
+

 ##### Native API endpoints #####


 @app.get("/health")
 async def health() -> Response:
@@ -275,6 +278,9 @@ app.post("/classify")(classify_request)
 app.put("/classify")(classify_request)


+##### OpenAI-compatible API endpoints #####
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
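The new `#####` section banners only group the routes; the OpenAI-compatible surface itself is unchanged. A hedged usage sketch against the completions route shown above (assumes a local server on sglang's default port 30000):

```python
# Sketch: call the OpenAI-compatible completions endpoint.
# Assumes an sglang server is already running on localhost:30000.
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={"model": "default", "prompt": "Hello", "max_tokens": 8},
)
print(resp.json())
```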
@@ -420,19 +426,6 @@ def launch_engine(
         scheduler_pipe_readers[i].recv()


-def add_prometheus_middleware(app: FastAPI):
-    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
-    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
-
-    registry = CollectorRegistry()
-    multiprocess.MultiProcessCollector(registry)
-    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
-
-    # Workaround for 307 Redirect for /metrics
-    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
-    app.routes.append(metrics_route)
-
-
 def launch_server(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
@@ -492,6 +485,19 @@ def launch_server(
         t.join()


+def add_prometheus_middleware(app: FastAPI):
+    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
+    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
+
+    registry = CollectorRegistry()
+    multiprocess.MultiProcessCollector(registry)
+    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+    app.routes.append(metrics_route)
+
+
 def _set_prometheus_env():
     # Set prometheus multiprocess directory
     # sglang uses prometheus multiprocess mode
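Relocating `add_prometheus_middleware` below `launch_server` is purely organizational; the mounted `/metrics` route behaves the same. For context, a minimal sketch of what that route serves, assuming `PROMETHEUS_MULTIPROC_DIR` is set before any metrics are created (which is what `_set_prometheus_env` arranges):

```python
# Sketch: render the same text the mounted /metrics route returns.
# prometheus_client multiprocess mode needs PROMETHEUS_MULTIPROC_DIR set
# (to an existing directory) before any counters are created.
import os
os.makedirs("/tmp/prom_mp", exist_ok=True)
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", "/tmp/prom_mp")

from prometheus_client import CollectorRegistry, generate_latest, multiprocess

registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
print(generate_latest(registry).decode())
```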
@@ -565,6 +571,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         return

     model_info = res.json()

     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -602,6 +609,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")

+    if server_args.delete_ckpt_after_loading:
+        delete_directory(server_args.model_path)
+

 class Runtime:
     """
@@ -63,7 +63,7 @@ class ServerArgs:
     stream_interval: int = 1
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
-    decode_log_interval: int = 40
+    watchdog_timeout: float = 300

     # Logging
     log_level: str = "info"
@@ -71,18 +71,18 @@ class ServerArgs:
     log_requests: bool = False
     show_time_cost: bool = False
     enable_metrics: bool = False
+    decode_log_interval: int = 40

-    # Other
+    # API related
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
     enable_cache_report: bool = False
-    watchdog_timeout: float = 600

     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"

-    # Distributed args
+    # Multi-node distributed serving
     dist_init_addr: Optional[str] = None
     nnodes: int = 1
     node_rank: int = 0
@@ -128,6 +128,7 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
     num_continuous_decode_steps: int = 1
+    delete_ckpt_after_loading: bool = False

     def __post_init__(self):
         # Set missing default values
@@ -205,6 +206,7 @@ class ServerArgs:

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
+        # Model and port args
         parser.add_argument(
             "--model-path",
             type=str,
@@ -324,6 +326,8 @@ class ServerArgs:
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+
+        # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -368,6 +372,8 @@ class ServerArgs:
             default=ServerArgs.schedule_conservativeness,
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
+
+        # Other runtime options
         parser.add_argument(
             "--tensor-parallel-size",
             "--tp-size",
@@ -393,6 +399,14 @@ class ServerArgs:
             default=ServerArgs.constrained_json_whitespace_pattern,
             help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
         )
+        parser.add_argument(
+            "--watchdog-timeout",
+            type=float,
+            default=ServerArgs.watchdog_timeout,
+            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
+        )
+
+        # Logging
         parser.add_argument(
             "--log-level",
             type=str,
@@ -420,7 +434,14 @@ class ServerArgs:
             action="store_true",
             help="Enable log prometheus metrics.",
         )
+        parser.add_argument(
+            "--decode-log-interval",
+            type=int,
+            default=ServerArgs.decode_log_interval,
+            help="The log interval of decode batch",
+        )

+        # API related
         parser.add_argument(
             "--api-key",
             type=str,
@@ -438,18 +459,6 @@ class ServerArgs:
             action="store_true",
             help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
         )
-        parser.add_argument(
-            "--watchdog-timeout",
-            type=float,
-            default=ServerArgs.watchdog_timeout,
-            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
-        )
-        parser.add_argument(
-            "--decode-log-interval",
-            type=int,
-            default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch",
-        )

         # Data parallelism
         parser.add_argument(
@@ -470,7 +479,7 @@ class ServerArgs:
             ],
         )

-        # Multi-node distributed serving args
+        # Multi-node distributed serving
         parser.add_argument(
             "--dist-init-addr",
             "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
@@ -677,6 +686,12 @@ class ServerArgs:
             "This can potentially increase throughput but may also increase time-to-first-token latency. "
             "The default value is 1, meaning only run one decoding step at a time.",
         )
+        parser.add_argument(
+            "--delete-ckpt-after-loading",
+            default=ServerArgs.delete_ckpt_after_loading,
+            action="store_true",
+            help="Delete the model checkpoint after loading the model.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
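Each flag added in this diff follows the same convention: the dataclass field supplies the default (`default=ServerArgs.delete_ckpt_after_loading`), which lets `from_cli_args` rebuild the dataclass mechanically from the parsed namespace. A simplified sketch of that round trip (not the exact sglang implementation):

```python
# Sketch: rebuild a dataclass from argparse output by matching field names.
import argparse
import dataclasses

@dataclasses.dataclass
class Args:
    delete_ckpt_after_loading: bool = False

parser = argparse.ArgumentParser()
parser.add_argument(
    "--delete-ckpt-after-loading",
    default=Args.delete_ckpt_after_loading,
    action="store_true",
)
ns = parser.parse_args([])

field_names = [f.name for f in dataclasses.fields(Args)]
args = Args(**{name: getattr(ns, name) for name in field_names})
print(args)  # Args(delete_ckpt_after_loading=False)
```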
@@ -23,6 +23,8 @@ import os
 import pickle
 import random
 import resource
+import shutil
+import signal
 import socket
 import time
 import warnings
@@ -35,6 +37,7 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
+import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
@@ -379,6 +382,10 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
     if include_self:
         try:
             itself.kill()
+
+            # Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by kubernetes),
+            # so we send an additional signal to kill them.
+            itself.send_signal(signal.SIGINT)
         except psutil.NoSuchProcess:
             pass
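The follow-up SIGINT covers processes that a plain SIGKILL may not reach, the comment's example being a PID 1 process in a container: the kernel ignores signals for which PID 1 has no handler, and a Python process does install a SIGINT handler by default (`KeyboardInterrupt`). A minimal sketch of the same fallback as a hypothetical helper, using only documented psutil calls:

```python
# Sketch: kill a process tree, then send SIGINT as a fallback for targets
# (e.g. a containerized PID 1) that SIGKILL alone may not terminate.
import signal

import psutil

def kill_tree(pid: int) -> None:
    parent = psutil.Process(pid)
    for child in parent.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass
    try:
        parent.kill()
        parent.send_signal(signal.SIGINT)  # fallback for unkillable PID 1 cases
    except psutil.NoSuchProcess:
        pass
```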
@@ -704,3 +711,44 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
         raise ValueError(f"Unsupported socket type: {socket_type}")

     return socket
+
+
+def dump_to_file(dirpath, name, value):
+    from vllm.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() != 0:
+        return
+
+    os.makedirs(dirpath, exist_ok=True)
+    if value.dtype is torch.bfloat16:
+        value = value.float()
+    value = value.cpu().numpy()
+    output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
+    logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}")
+    np.save(output_filename, value)
+
+
+def is_triton_3():
+    return triton.__version__.startswith("3.")
+
+
+def maybe_torch_compile(*args, **kwargs):
+    """
+    torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
+    Therefore, we disable it here.
+    """
+
+    def decorator(func):
+        if is_triton_3():
+            return torch.compile(*args, **kwargs)(func)
+        return func
+
+    return decorator
+
+
+def delete_directory(dirpath):
+    try:
+        # This will remove the directory and all its contents
+        shutil.rmtree(dirpath)
+    except OSError as e:
+        print(f"Warning: {dirpath} : {e.strerror}")