feat: update torch 2.5.1 (#2069)
This commit is contained in:
12
Makefile
Normal file
12
Makefile
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
.PHONY: check-deps install-deps format
|
||||||
|
|
||||||
|
check-deps:
|
||||||
|
@command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort)
|
||||||
|
@command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black)
|
||||||
|
|
||||||
|
install-deps:
|
||||||
|
pip install isort black
|
||||||
|
|
||||||
|
format: check-deps
|
||||||
|
@echo "Formatting modified Python files..."
|
||||||
|
git diff --name-only --diff-filter=M | grep '\.py$$' | xargs -I {} sh -c 'isort {} && black {}'
|
||||||
@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
|
|||||||
"orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
|
"orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
|
||||||
"torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
|
"torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
|
||||||
"outlines>=0.0.44,<0.1.0", "modelscope"]
|
"outlines>=0.0.44,<0.1.0", "modelscope"]
|
||||||
srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
|
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]
|
||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
||||||
|
|||||||
@@ -32,12 +32,14 @@ from vllm.distributed import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
|
from sglang.srt.layers.custom_op_util import register_custom_op
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@register_custom_op("sglang_silu_and_mul")
|
||||||
class SiluAndMul(CustomOp):
|
class SiluAndMul(CustomOp):
|
||||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
d = x.shape[-1] // 2
|
d = x.shape[-1] // 2
|
||||||
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@register_custom_op("sglang_gelu_and_mul")
|
||||||
class GeluAndMul(CustomOp):
|
class GeluAndMul(CustomOp):
|
||||||
def __init__(self, approximate="tanh"):
|
def __init__(self, approximate="tanh"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
26
python/sglang/srt/layers/custom_op_util.py
Normal file
26
python/sglang/srt/layers/custom_op_util.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
|
|
||||||
|
def register_custom_op(op_name):
|
||||||
|
def decorator(cls):
|
||||||
|
if hasattr(CustomOp, "register"):
|
||||||
|
return CustomOp.register(op_name)(cls)
|
||||||
|
else:
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
@@ -33,9 +33,12 @@ if is_flashinfer_available():
|
|||||||
|
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
|
from sglang.srt.layers.custom_op_util import register_custom_op
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@register_custom_op("sglang_rmsnorm")
|
||||||
class RMSNorm(CustomOp):
|
class RMSNorm(CustomOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -78,6 +81,7 @@ class RMSNorm(CustomOp):
|
|||||||
return x, residual
|
return x, residual
|
||||||
|
|
||||||
|
|
||||||
|
@register_custom_op("sglang_gemma_rmsnorm")
|
||||||
class GemmaRMSNorm(CustomOp):
|
class GemmaRMSNorm(CustomOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -90,6 +90,8 @@ def set_torch_compile_config():
|
|||||||
|
|
||||||
# FIXME: tmp workaround
|
# FIXME: tmp workaround
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
||||||
|
if hasattr(torch._dynamo.config, "cache_size_limit"):
|
||||||
|
torch._dynamo.config.cache_size_limit = 1024
|
||||||
|
|
||||||
|
|
||||||
@maybe_torch_compile(dynamic=True)
|
@maybe_torch_compile(dynamic=True)
|
||||||
|
|||||||
@@ -18,9 +18,9 @@ limitations under the License.
|
|||||||
import gc
|
import gc
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.resources
|
import importlib.resources
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import pkgutil
|
import pkgutil
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
|
|||||||
crash_on_warnings,
|
crash_on_warnings,
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
|
monkey_patch_vllm_model_config,
|
||||||
monkey_patch_vllm_p2p_access_check,
|
monkey_patch_vllm_p2p_access_check,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -226,6 +227,47 @@ class ModelRunner:
|
|||||||
|
|
||||||
return min_per_gpu_memory
|
return min_per_gpu_memory
|
||||||
|
|
||||||
|
def setup_model(self):
|
||||||
|
try:
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
vllm_config = VllmConfig()
|
||||||
|
vllm_config.model_config = self.vllm_model_config
|
||||||
|
vllm_config.load_config = self.load_config
|
||||||
|
vllm_config.device_config = DeviceConfig(self.device)
|
||||||
|
vllm_config.quant_config = VllmConfig._get_quantization_config(
|
||||||
|
vllm_config.model_config, vllm_config.load_config
|
||||||
|
)
|
||||||
|
return get_model(vllm_config=vllm_config)
|
||||||
|
except ImportError:
|
||||||
|
return get_model(
|
||||||
|
model_config=self.vllm_model_config,
|
||||||
|
load_config=self.load_config,
|
||||||
|
device_config=DeviceConfig(self.device),
|
||||||
|
parallel_config=None,
|
||||||
|
scheduler_config=None,
|
||||||
|
lora_config=None,
|
||||||
|
cache_config=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_model_config_params(self):
|
||||||
|
sig = inspect.signature(VllmModelConfig.__init__)
|
||||||
|
params = {
|
||||||
|
"model": self.server_args.model_path,
|
||||||
|
"quantization": self.server_args.quantization,
|
||||||
|
"tokenizer": None,
|
||||||
|
"tokenizer_mode": None,
|
||||||
|
"trust_remote_code": self.server_args.trust_remote_code,
|
||||||
|
"dtype": self.server_args.dtype,
|
||||||
|
"seed": self.server_args.random_seed,
|
||||||
|
"skip_tokenizer_init": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
if "task" in sig.parameters:
|
||||||
|
params["task"] = ""
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
@@ -247,31 +289,15 @@ class ModelRunner:
|
|||||||
load_format=self.server_args.load_format,
|
load_format=self.server_args.load_format,
|
||||||
download_dir=self.server_args.download_dir,
|
download_dir=self.server_args.download_dir,
|
||||||
)
|
)
|
||||||
self.vllm_model_config = VllmModelConfig(
|
monkey_patch_vllm_model_config()
|
||||||
model=self.server_args.model_path,
|
self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
|
||||||
quantization=self.server_args.quantization,
|
|
||||||
tokenizer=None,
|
|
||||||
tokenizer_mode=None,
|
|
||||||
trust_remote_code=self.server_args.trust_remote_code,
|
|
||||||
dtype=self.server_args.dtype,
|
|
||||||
seed=self.server_args.random_seed,
|
|
||||||
skip_tokenizer_init=True,
|
|
||||||
)
|
|
||||||
if self.model_config.model_override_args is not None:
|
if self.model_config.model_override_args is not None:
|
||||||
self.vllm_model_config.hf_config.update(
|
self.vllm_model_config.hf_config.update(
|
||||||
self.model_config.model_override_args
|
self.model_config.model_override_args
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load the model
|
self.model = self.setup_model()
|
||||||
self.model = get_model(
|
|
||||||
model_config=self.vllm_model_config,
|
|
||||||
load_config=self.load_config,
|
|
||||||
device_config=DeviceConfig(self.device),
|
|
||||||
parallel_config=None,
|
|
||||||
scheduler_config=None,
|
|
||||||
lora_config=None,
|
|
||||||
cache_config=None,
|
|
||||||
)
|
|
||||||
self.sliding_window_size = (
|
self.sliding_window_size = (
|
||||||
self.model.get_attention_sliding_window_size()
|
self.model.get_attention_sliding_window_size()
|
||||||
if hasattr(self.model, "get_attention_sliding_window_size")
|
if hasattr(self.model, "get_attention_sliding_window_size")
|
||||||
@@ -303,17 +329,9 @@ class ModelRunner:
|
|||||||
target_device = torch.device(self.device)
|
target_device = torch.device(self.device)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: Use a better method to check this
|
model_config_params = self.get_model_config_params()
|
||||||
vllm_model_config = VllmModelConfig(
|
model_config_params["model"] = model_path
|
||||||
model=model_path,
|
vllm_model_config = VllmModelConfig(**model_config_params)
|
||||||
quantization=self.server_args.quantization,
|
|
||||||
tokenizer=None,
|
|
||||||
tokenizer_mode=None,
|
|
||||||
trust_remote_code=self.server_args.trust_remote_code,
|
|
||||||
dtype=self.server_args.dtype,
|
|
||||||
seed=self.server_args.random_seed,
|
|
||||||
skip_tokenizer_init=True,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
message = f"Failed to load model config: {e}."
|
message = f"Failed to load model config: {e}."
|
||||||
return False, message
|
return False, message
|
||||||
|
|||||||
@@ -332,6 +332,7 @@ def suppress_other_loggers():
|
|||||||
)
|
)
|
||||||
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
||||||
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
|
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
|
||||||
|
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
"ignore", category=UserWarning, message="The given NumPy array is not writable"
|
"ignore", category=UserWarning, message="The given NumPy array is not writable"
|
||||||
@@ -396,6 +397,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def monkey_patch_vllm_model_config():
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
|
||||||
|
if not hasattr(ModelConfig, "_resolve_task"):
|
||||||
|
return
|
||||||
|
|
||||||
|
def _resolve_task(
|
||||||
|
self,
|
||||||
|
task_option,
|
||||||
|
hf_config,
|
||||||
|
):
|
||||||
|
supported_tasks = {
|
||||||
|
"generate": True,
|
||||||
|
"embedding": False,
|
||||||
|
}
|
||||||
|
selected_task = "generate"
|
||||||
|
return supported_tasks, selected_task
|
||||||
|
|
||||||
|
setattr(ModelConfig, "_resolve_task", _resolve_task)
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
||||||
"""
|
"""
|
||||||
Monkey patch the slow p2p access check in vllm.
|
Monkey patch the slow p2p access check in vllm.
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -529,6 +530,7 @@ def run_bench_serving(
|
|||||||
random_input_len=4096,
|
random_input_len=4096,
|
||||||
random_output_len=2048,
|
random_output_len=2048,
|
||||||
disable_stream=False,
|
disable_stream=False,
|
||||||
|
need_warmup=False,
|
||||||
):
|
):
|
||||||
# Launch the server
|
# Launch the server
|
||||||
base_url = DEFAULT_URL_FOR_TEST
|
base_url = DEFAULT_URL_FOR_TEST
|
||||||
@@ -565,6 +567,10 @@ def run_bench_serving(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if need_warmup:
|
||||||
|
warmup_args = copy.deepcopy(args)
|
||||||
|
warmup_args.num_prompts = 16
|
||||||
|
run_benchmark(warmup_args)
|
||||||
res = run_benchmark(args)
|
res = run_benchmark(args)
|
||||||
finally:
|
finally:
|
||||||
kill_child_process(process.pid, include_self=True)
|
kill_child_process(process.pid, include_self=True)
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ class TestBenchServing(unittest.TestCase):
|
|||||||
random_input_len=None,
|
random_input_len=None,
|
||||||
random_output_len=None,
|
random_output_len=None,
|
||||||
disable_stream=True,
|
disable_stream=True,
|
||||||
|
need_warmup=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
|
|||||||
Reference in New Issue
Block a user