diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..2c9ea886e --- /dev/null +++ b/Makefile @@ -0,0 +1,12 @@ +.PHONY: check-deps install-deps format + +check-deps: + @command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort) + @command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black) + +install-deps: + pip install isort black + +format: check-deps + @echo "Formatting modified Python files..." + git diff --name-only --diff-filter=M | grep '\.py$$' | xargs -I {} sh -c 'isort {} && black {}' diff --git a/python/pyproject.toml b/python/pyproject.toml index 5e144f809..0fe373106 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", "torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2", "outlines>=0.0.44,<0.1.0", "modelscope"] -srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"] +srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"] # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20241022, not from public vllm whl diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 94d48e82b..6597ae215 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -32,12 +32,14 @@ from vllm.distributed import ( ) from vllm.model_executor.custom_op import CustomOp +from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import set_weight_attrs logger = logging.getLogger(__name__) +@register_custom_op("sglang_silu_and_mul") class SiluAndMul(CustomOp): def forward_native(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 @@ -51,6 +53,7 @@ class SiluAndMul(CustomOp): return out +@register_custom_op("sglang_gelu_and_mul") class GeluAndMul(CustomOp): def __init__(self, approximate="tanh"): super().__init__() diff --git a/python/sglang/srt/layers/custom_op_util.py b/python/sglang/srt/layers/custom_op_util.py new file mode 100644 index 000000000..3e790b273 --- /dev/null +++ b/python/sglang/srt/layers/custom_op_util.py @@ -0,0 +1,26 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from vllm.model_executor.custom_op import CustomOp + + +def register_custom_op(op_name): + def decorator(cls): + if hasattr(CustomOp, "register"): + return CustomOp.register(op_name)(cls) + else: + return cls + + return decorator diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 3ae392eb9..cf252f321 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -33,9 +33,12 @@ if is_flashinfer_available(): from vllm.model_executor.custom_op import CustomOp +from sglang.srt.layers.custom_op_util import register_custom_op + logger = logging.getLogger(__name__) +@register_custom_op("sglang_rmsnorm") class RMSNorm(CustomOp): def __init__( self, @@ -78,6 +81,7 @@ class RMSNorm(CustomOp): return x, residual +@register_custom_op("sglang_gemma_rmsnorm") class GemmaRMSNorm(CustomOp): def __init__( self, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index db185599f..91c6603a2 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -90,6 +90,8 @@ def set_torch_compile_config(): # FIXME: tmp workaround torch._dynamo.config.accumulated_cache_size_limit = 1024 + if hasattr(torch._dynamo.config, "cache_size_limit"): + torch._dynamo.config.cache_size_limit = 1024 @maybe_torch_compile(dynamic=True) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3c76f5ad7..036be8675 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -18,9 +18,9 @@ limitations under the License. import gc import importlib import importlib.resources +import inspect import json import logging -import os import pkgutil from functools import lru_cache from typing import Optional, Type @@ -60,6 +60,7 @@ from sglang.srt.utils import ( crash_on_warnings, enable_show_time_cost, get_available_gpu_memory, + monkey_patch_vllm_model_config, monkey_patch_vllm_p2p_access_check, ) @@ -226,6 +227,47 @@ class ModelRunner: return min_per_gpu_memory + def setup_model(self): + try: + from vllm.config import VllmConfig + + vllm_config = VllmConfig() + vllm_config.model_config = self.vllm_model_config + vllm_config.load_config = self.load_config + vllm_config.device_config = DeviceConfig(self.device) + vllm_config.quant_config = VllmConfig._get_quantization_config( + vllm_config.model_config, vllm_config.load_config + ) + return get_model(vllm_config=vllm_config) + except ImportError: + return get_model( + model_config=self.vllm_model_config, + load_config=self.load_config, + device_config=DeviceConfig(self.device), + parallel_config=None, + scheduler_config=None, + lora_config=None, + cache_config=None, + ) + + def get_model_config_params(self): + sig = inspect.signature(VllmModelConfig.__init__) + params = { + "model": self.server_args.model_path, + "quantization": self.server_args.quantization, + "tokenizer": None, + "tokenizer_mode": None, + "trust_remote_code": self.server_args.trust_remote_code, + "dtype": self.server_args.dtype, + "seed": self.server_args.random_seed, + "skip_tokenizer_init": True, + } + + if "task" in sig.parameters: + params["task"] = "" + + return params + def load_model(self): logger.info( f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" @@ -247,31 +289,15 @@ class ModelRunner: load_format=self.server_args.load_format, download_dir=self.server_args.download_dir, ) - self.vllm_model_config = VllmModelConfig( - model=self.server_args.model_path, - quantization=self.server_args.quantization, - tokenizer=None, - tokenizer_mode=None, - trust_remote_code=self.server_args.trust_remote_code, - dtype=self.server_args.dtype, - seed=self.server_args.random_seed, - skip_tokenizer_init=True, - ) + monkey_patch_vllm_model_config() + self.vllm_model_config = VllmModelConfig(**self.get_model_config_params()) if self.model_config.model_override_args is not None: self.vllm_model_config.hf_config.update( self.model_config.model_override_args ) - # Load the model - self.model = get_model( - model_config=self.vllm_model_config, - load_config=self.load_config, - device_config=DeviceConfig(self.device), - parallel_config=None, - scheduler_config=None, - lora_config=None, - cache_config=None, - ) + self.model = self.setup_model() + self.sliding_window_size = ( self.model.get_attention_sliding_window_size() if hasattr(self.model, "get_attention_sliding_window_size") @@ -303,17 +329,9 @@ class ModelRunner: target_device = torch.device(self.device) try: - # TODO: Use a better method to check this - vllm_model_config = VllmModelConfig( - model=model_path, - quantization=self.server_args.quantization, - tokenizer=None, - tokenizer_mode=None, - trust_remote_code=self.server_args.trust_remote_code, - dtype=self.server_args.dtype, - seed=self.server_args.random_seed, - skip_tokenizer_init=True, - ) + model_config_params = self.get_model_config_params() + model_config_params["model"] = model_path + vllm_model_config = VllmModelConfig(**model_config_params) except Exception as e: message = f"Failed to load model config: {e}." return False, message diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 84bf9a2e5..d177a0bf8 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -332,6 +332,7 @@ def suppress_other_loggers(): ) logging.getLogger("vllm.selector").setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.ERROR) + logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR) warnings.filterwarnings( "ignore", category=UserWarning, message="The given NumPy array is not writable" @@ -396,6 +397,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None): pass +def monkey_patch_vllm_model_config(): + from vllm.config import ModelConfig + + if not hasattr(ModelConfig, "_resolve_task"): + return + + def _resolve_task( + self, + task_option, + hf_config, + ): + supported_tasks = { + "generate": True, + "embedding": False, + } + selected_task = "generate" + return supported_tasks, selected_task + + setattr(ModelConfig, "_resolve_task", _resolve_task) + + def monkey_patch_vllm_p2p_access_check(gpu_id: int): """ Monkey patch the slow p2p access check in vllm. diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 5d878d6af..12cbbd883 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -2,6 +2,7 @@ import argparse import asyncio +import copy import os import random import subprocess @@ -529,6 +530,7 @@ def run_bench_serving( random_input_len=4096, random_output_len=2048, disable_stream=False, + need_warmup=False, ): # Launch the server base_url = DEFAULT_URL_FOR_TEST @@ -565,6 +567,10 @@ def run_bench_serving( ) try: + if need_warmup: + warmup_args = copy.deepcopy(args) + warmup_args.num_prompts = 16 + run_benchmark(warmup_args) res = run_benchmark(args) finally: kill_child_process(process.pid, include_self=True) diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index c3c6a7d13..ff4758633 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -32,6 +32,7 @@ class TestBenchServing(unittest.TestCase): random_input_len=None, random_output_len=None, disable_stream=True, + need_warmup=True, ) if is_in_ci():