[minor] Improve code style and compatibility (#1961)
This commit is contained in:
@@ -21,6 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
|
|||||||
"torchao", "uvicorn", "uvloop", "zmq",
|
"torchao", "uvicorn", "uvloop", "zmq",
|
||||||
"outlines>=0.0.44", "modelscope"]
|
"outlines>=0.0.44", "modelscope"]
|
||||||
srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
|
srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
|
||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
||||||
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
|
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
|
||||||
|
|||||||
@@ -461,7 +461,7 @@ class TokenizerManager:
|
|||||||
break
|
break
|
||||||
|
|
||||||
kill_child_process(include_self=True)
|
kill_child_process(include_self=True)
|
||||||
sys.exit(-1)
|
sys.exit(0)
|
||||||
|
|
||||||
async def handle_loop(self):
|
async def handle_loop(self):
|
||||||
"""The event loop that handles requests"""
|
"""The event loop that handles requests"""
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from sglang.srt.layers.logits_processor import (
|
|||||||
LogitsProcessorOutput,
|
LogitsProcessorOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
@@ -92,7 +92,7 @@ def set_torch_compile_config():
|
|||||||
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True)
|
@maybe_torch_compile(dynamic=True)
|
||||||
def clamp_position(seq_lens):
|
def clamp_position(seq_lens):
|
||||||
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
||||||
|
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ from sglang.srt.utils import (
|
|||||||
add_api_key_middleware,
|
add_api_key_middleware,
|
||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
|
delete_directory,
|
||||||
is_port_available,
|
is_port_available,
|
||||||
kill_child_process,
|
kill_child_process,
|
||||||
maybe_set_triton_cache_manager,
|
maybe_set_triton_cache_manager,
|
||||||
@@ -97,8 +98,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
tokenizer_manager: TokenizerManager = None
|
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=["*"],
|
||||||
@@ -107,6 +106,10 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tokenizer_manager: TokenizerManager = None
|
||||||
|
|
||||||
|
##### Native API endpoints #####
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health() -> Response:
|
async def health() -> Response:
|
||||||
@@ -275,6 +278,9 @@ app.post("/classify")(classify_request)
|
|||||||
app.put("/classify")(classify_request)
|
app.put("/classify")(classify_request)
|
||||||
|
|
||||||
|
|
||||||
|
##### OpenAI-compatible API endpoints #####
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
async def openai_v1_completions(raw_request: Request):
|
async def openai_v1_completions(raw_request: Request):
|
||||||
return await v1_completions(tokenizer_manager, raw_request)
|
return await v1_completions(tokenizer_manager, raw_request)
|
||||||
@@ -420,19 +426,6 @@ def launch_engine(
|
|||||||
scheduler_pipe_readers[i].recv()
|
scheduler_pipe_readers[i].recv()
|
||||||
|
|
||||||
|
|
||||||
def add_prometheus_middleware(app: FastAPI):
|
|
||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
|
|
||||||
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
|
|
||||||
|
|
||||||
registry = CollectorRegistry()
|
|
||||||
multiprocess.MultiProcessCollector(registry)
|
|
||||||
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
|
|
||||||
|
|
||||||
# Workaround for 307 Redirect for /metrics
|
|
||||||
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
|
|
||||||
app.routes.append(metrics_route)
|
|
||||||
|
|
||||||
|
|
||||||
def launch_server(
|
def launch_server(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
pipe_finish_writer: Optional[mp.connection.Connection] = None,
|
pipe_finish_writer: Optional[mp.connection.Connection] = None,
|
||||||
@@ -492,6 +485,19 @@ def launch_server(
|
|||||||
t.join()
|
t.join()
|
||||||
|
|
||||||
|
|
||||||
|
def add_prometheus_middleware(app: FastAPI):
    """Attach a ``/metrics`` mount serving Prometheus multiprocess metrics to ``app``.

    A fresh ``CollectorRegistry`` is populated through
    ``multiprocess.MultiProcessCollector`` and wrapped in an ASGI app, which is
    then appended to ``app.routes``.

    Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
    """
    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

    collector_registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(collector_registry)

    metrics_app = make_asgi_app(registry=collector_registry)
    metrics_mount = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics: match the path with or
    # without anything after "/metrics".
    metrics_mount.path_regex = re.compile("^/metrics(?P<path>.*)$")

    app.routes.append(metrics_mount)
|
||||||
|
|
||||||
|
|
||||||
def _set_prometheus_env():
|
def _set_prometheus_env():
|
||||||
# Set prometheus multiprocess directory
|
# Set prometheus multiprocess directory
|
||||||
# sglang uses prometheus multiprocess mode
|
# sglang uses prometheus multiprocess mode
|
||||||
@@ -565,6 +571,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
|||||||
return
|
return
|
||||||
|
|
||||||
model_info = res.json()
|
model_info = res.json()
|
||||||
|
|
||||||
# Send a warmup request
|
# Send a warmup request
|
||||||
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
||||||
max_new_tokens = 8 if model_info["is_generation"] else 1
|
max_new_tokens = 8 if model_info["is_generation"] else 1
|
||||||
@@ -602,6 +609,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
|||||||
if pipe_finish_writer is not None:
|
if pipe_finish_writer is not None:
|
||||||
pipe_finish_writer.send("ready")
|
pipe_finish_writer.send("ready")
|
||||||
|
|
||||||
|
if server_args.delete_ckpt_after_loading:
|
||||||
|
delete_directory(server_args.model_path)
|
||||||
|
|
||||||
|
|
||||||
class Runtime:
|
class Runtime:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class ServerArgs:
|
|||||||
stream_interval: int = 1
|
stream_interval: int = 1
|
||||||
random_seed: Optional[int] = None
|
random_seed: Optional[int] = None
|
||||||
constrained_json_whitespace_pattern: Optional[str] = None
|
constrained_json_whitespace_pattern: Optional[str] = None
|
||||||
decode_log_interval: int = 40
|
watchdog_timeout: float = 300
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
log_level: str = "info"
|
log_level: str = "info"
|
||||||
@@ -71,18 +71,18 @@ class ServerArgs:
|
|||||||
log_requests: bool = False
|
log_requests: bool = False
|
||||||
show_time_cost: bool = False
|
show_time_cost: bool = False
|
||||||
enable_metrics: bool = False
|
enable_metrics: bool = False
|
||||||
|
decode_log_interval: int = 40
|
||||||
|
|
||||||
# Other
|
# API related
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
file_storage_pth: str = "SGLang_storage"
|
file_storage_pth: str = "SGLang_storage"
|
||||||
enable_cache_report: bool = False
|
enable_cache_report: bool = False
|
||||||
watchdog_timeout: float = 600
|
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
dp_size: int = 1
|
dp_size: int = 1
|
||||||
load_balance_method: str = "round_robin"
|
load_balance_method: str = "round_robin"
|
||||||
|
|
||||||
# Distributed args
|
# Multi-node distributed serving
|
||||||
dist_init_addr: Optional[str] = None
|
dist_init_addr: Optional[str] = None
|
||||||
nnodes: int = 1
|
nnodes: int = 1
|
||||||
node_rank: int = 0
|
node_rank: int = 0
|
||||||
@@ -128,6 +128,7 @@ class ServerArgs:
|
|||||||
enable_p2p_check: bool = False
|
enable_p2p_check: bool = False
|
||||||
triton_attention_reduce_in_fp32: bool = False
|
triton_attention_reduce_in_fp32: bool = False
|
||||||
num_continuous_decode_steps: int = 1
|
num_continuous_decode_steps: int = 1
|
||||||
|
delete_ckpt_after_loading: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Set missing default values
|
# Set missing default values
|
||||||
@@ -205,6 +206,7 @@ class ServerArgs:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
|
# Model and port args
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model-path",
|
"--model-path",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -324,6 +326,8 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether to use a CausalLM as an embedding model.",
|
help="Whether to use a CausalLM as an embedding model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Memory and scheduling
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -368,6 +372,8 @@ class ServerArgs:
|
|||||||
default=ServerArgs.schedule_conservativeness,
|
default=ServerArgs.schedule_conservativeness,
|
||||||
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
|
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Other runtime options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tensor-parallel-size",
|
"--tensor-parallel-size",
|
||||||
"--tp-size",
|
"--tp-size",
|
||||||
@@ -393,6 +399,14 @@ class ServerArgs:
|
|||||||
default=ServerArgs.constrained_json_whitespace_pattern,
|
default=ServerArgs.constrained_json_whitespace_pattern,
|
||||||
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
|
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--watchdog-timeout",
|
||||||
|
type=float,
|
||||||
|
default=ServerArgs.watchdog_timeout,
|
||||||
|
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Logging
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-level",
|
"--log-level",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -420,7 +434,14 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable log prometheus metrics.",
|
help="Enable log prometheus metrics.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-log-interval",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.decode_log_interval,
|
||||||
|
help="The log interval of decode batch",
|
||||||
|
)
|
||||||
|
|
||||||
|
# API related
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--api-key",
|
"--api-key",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -438,18 +459,6 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
|
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--watchdog-timeout",
|
|
||||||
type=float,
|
|
||||||
default=ServerArgs.watchdog_timeout,
|
|
||||||
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--decode-log-interval",
|
|
||||||
type=int,
|
|
||||||
default=ServerArgs.decode_log_interval,
|
|
||||||
help="The log interval of decode batch",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -470,7 +479,7 @@ class ServerArgs:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Multi-node distributed serving args
|
# Multi-node distributed serving
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dist-init-addr",
|
"--dist-init-addr",
|
||||||
"--nccl-init-addr", # For backward compatbility. This will be removed in the future.
|
"--nccl-init-addr", # For backward compatbility. This will be removed in the future.
|
||||||
@@ -677,6 +686,12 @@ class ServerArgs:
|
|||||||
"This can potentially increase throughput but may also increase time-to-first-token latency. "
|
"This can potentially increase throughput but may also increase time-to-first-token latency. "
|
||||||
"The default value is 1, meaning only run one decoding step at a time.",
|
"The default value is 1, meaning only run one decoding step at a time.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--delete-ckpt-after-loading",
|
||||||
|
default=ServerArgs.delete_ckpt_after_loading,
|
||||||
|
action="store_true",
|
||||||
|
help="Delete the model checkpoint after loading the model.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
import resource
|
import resource
|
||||||
|
import shutil
|
||||||
|
import signal
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
@@ -35,6 +37,7 @@ import psutil
|
|||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import triton
|
||||||
import zmq
|
import zmq
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
@@ -379,6 +382,10 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
|
|||||||
if include_self:
|
if include_self:
|
||||||
try:
|
try:
|
||||||
itself.kill()
|
itself.kill()
|
||||||
|
|
||||||
|
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
|
||||||
|
# so we send an additional signal to kill them.
|
||||||
|
itself.send_signal(signal.SIGINT)
|
||||||
except psutil.NoSuchProcess:
|
except psutil.NoSuchProcess:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -704,3 +711,44 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
|
|||||||
raise ValueError(f"Unsupported socket type: {socket_type}")
|
raise ValueError(f"Unsupported socket type: {socket_type}")
|
||||||
|
|
||||||
return socket
|
return socket
|
||||||
|
|
||||||
|
|
||||||
|
def dump_to_file(dirpath, name, value):
    """Save a tensor to ``<dirpath>/pytorch_dump_<name>.npy``.

    Only tensor-parallel rank 0 writes; every other rank returns immediately.

    Args:
        dirpath: Output directory; created if it does not exist.
        name: Suffix used in the output file name.
        value: Tensor to dump. It must expose ``dtype``, ``float()``, ``cpu()``
            and ``numpy()`` (presumably a ``torch.Tensor`` -- confirm against callers).
    """
    from vllm.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() != 0:
        return

    os.makedirs(dirpath, exist_ok=True)

    # bfloat16 tensors are widened with .float() before the NumPy conversion
    # (presumably because NumPy has no native bfloat16 dtype).
    tensor = value.float() if value.dtype is torch.bfloat16 else value
    array = tensor.cpu().numpy()

    target_path = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
    logger.info(f"Dump a tensor to {target_path}. Shape = {array.shape}")
    np.save(target_path, array)
|
||||||
|
|
||||||
|
|
||||||
|
def is_triton_3():
    """Return True when the installed triton's version string starts with ``"3."``."""
    installed_version = triton.__version__
    return installed_version.startswith("3.")
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_torch_compile(*args, **kwargs):
    """
    Decorator factory: apply ``torch.compile(*args, **kwargs)`` only on triton 3.x.

    torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
    Therefore, when ``is_triton_3()`` is false the decorated function is
    returned unchanged instead of being compiled.
    """

    def decorator(func):
        # The triton check runs once, at decoration time.
        if not is_triton_3():
            return func
        compile_with_options = torch.compile(*args, **kwargs)
        return compile_with_options(func)

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def delete_directory(dirpath):
    """Recursively delete ``dirpath`` and all of its contents (best effort).

    Failures are reported as a warning on stdout instead of being raised, so
    the caller keeps running even if the directory cannot be removed.

    Args:
        dirpath: Path of the directory tree to remove.
    """
    try:
        # This will remove the directory and all its contents
        shutil.rmtree(dirpath)
    except OSError as e:
        # ``strerror`` is None when the OSError was raised without an errno
        # (e.g. rmtree refusing to operate on a symbolic link); fall back to
        # the exception text so the warning never reads "... : None".
        reason = e.strerror if e.strerror is not None else str(e)
        print(f"Warning: {dirpath} : {reason}")
|
||||||
|
|||||||
Reference in New Issue
Block a user