feat: update torch 2.5.1 (#2069)
pyproject.toml
@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
     "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
     "torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
     "outlines>=0.0.44,<0.1.0", "modelscope"]
-srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
+srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
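Note: relaxing the pin from vllm==0.6.3.post1 to vllm>=0.6.3.post1 lets the resolver pick newer vLLM releases built against torch 2.5.1. A minimal sketch of what the new specifier admits, using the packaging library (already listed in runtime_common above); the candidate versions are examples only:

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# The relaxed constraint still accepts the old pin, plus anything newer.
spec = SpecifierSet(">=0.6.3.post1")
for candidate in ("0.6.3", "0.6.3.post1", "0.6.4.post1"):
    print(candidate, Version(candidate) in spec)
# 0.6.3 False / 0.6.3.post1 True / 0.6.4.post1 True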
python/sglang/srt/layers/activation.py
@@ -32,12 +32,14 @@ from vllm.distributed import (
 )
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs

 logger = logging.getLogger(__name__)


+@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
         return out


+@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
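Note: the native path of SiluAndMul, truncated in the hunk above, splits the last dimension in half and gates one half with the SiLU of the other. A minimal standalone sketch of that semantics (the silu_and_mul helper here is illustrative, not the class itself):

import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Split the fused gate/up projection and apply SiLU gating.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

out = silu_and_mul(torch.randn(2, 8))  # -> shape (2, 4)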
python/sglang/srt/layers/custom_op_util.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from vllm.model_executor.custom_op import CustomOp
+
+
+def register_custom_op(op_name):
+    def decorator(cls):
+        if hasattr(CustomOp, "register"):
+            return CustomOp.register(op_name)(cls)
+        else:
+            return cls
+
+    return decorator
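Note: CustomOp.register only exists in the newer vLLM releases that this torch 2.5.1 upgrade targets, so the shim registers the op when it can and otherwise returns the class untouched. A minimal sketch of that behavior using hypothetical stand-in classes rather than vLLM itself:

# Hypothetical stand-ins, for illustration only: _NewCustomOp mimics a
# vLLM CustomOp that has .register, _OldCustomOp one that does not.
class _OldCustomOp:
    pass

class _NewCustomOp:
    registry = {}

    @classmethod
    def register(cls, op_name):
        def wrap(op_cls):
            cls.registry[op_name] = op_cls
            return op_cls
        return wrap

def make_register(custom_op_cls):
    # Same branch as register_custom_op above, parameterized for the demo.
    def register_custom_op(op_name):
        def decorator(cls):
            if hasattr(custom_op_cls, "register"):
                return custom_op_cls.register(op_name)(cls)
            return cls
        return decorator
    return register_custom_op

new_register = make_register(_NewCustomOp)
old_register = make_register(_OldCustomOp)

@new_register("sglang_demo_op")
class DemoOp:
    pass

assert _NewCustomOp.registry["sglang_demo_op"] is DemoOp

@old_register("sglang_demo_op")
class DemoOpOld:
    pass  # no .register on old vLLM: the class is returned unchanged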
python/sglang/srt/layers/layernorm.py
@@ -33,9 +33,12 @@ if is_flashinfer_available():

 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.custom_op_util import register_custom_op
+
 logger = logging.getLogger(__name__)


+@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -78,6 +81,7 @@ class RMSNorm(CustomOp):
         return x, residual


+@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
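Note: the two norm classes are registered under SGLang-specific names the same way as the activation ops. For reference, a minimal native sketch of what the variants compute; the standalone functions below are illustrative only, and the registered classes also carry fused kernel paths:

import torch

def rms_norm(x, weight, eps=1e-6):
    # x / sqrt(mean(x^2) + eps), scaled by the learned weight.
    var = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * weight

def gemma_rms_norm(x, weight, eps=1e-6):
    # Gemma keeps its weight zero-centered, so the scale is (1 + weight).
    var = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * (1.0 + weight)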
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -90,6 +90,8 @@ def set_torch_compile_config():

     # FIXME: tmp workaround
     torch._dynamo.config.accumulated_cache_size_limit = 1024
+    if hasattr(torch._dynamo.config, "cache_size_limit"):
+        torch._dynamo.config.cache_size_limit = 1024


 @maybe_torch_compile(dynamic=True)
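Note: the hasattr guard keeps the same code importable across torch versions, since the set of torch._dynamo.config knobs shifts between releases. A minimal sketch of the same defensive pattern, generalized to a list of knobs:

import torch

# Only touch dynamo knobs that exist in the installed torch build.
for knob, value in (("accumulated_cache_size_limit", 1024),
                    ("cache_size_limit", 1024)):
    if hasattr(torch._dynamo.config, knob):
        setattr(torch._dynamo.config, knob, value)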
python/sglang/srt/model_executor/model_runner.py
@@ -18,9 +18,9 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
 import inspect
 import json
 import logging
 import os
 import pkgutil
 from functools import lru_cache
 from typing import Optional, Type
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -226,6 +227,47 @@ class ModelRunner:

         return min_per_gpu_memory

+    def setup_model(self):
+        try:
+            from vllm.config import VllmConfig
+
+            vllm_config = VllmConfig()
+            vllm_config.model_config = self.vllm_model_config
+            vllm_config.load_config = self.load_config
+            vllm_config.device_config = DeviceConfig(self.device)
+            vllm_config.quant_config = VllmConfig._get_quantization_config(
+                vllm_config.model_config, vllm_config.load_config
+            )
+            return get_model(vllm_config=vllm_config)
+        except ImportError:
+            return get_model(
+                model_config=self.vllm_model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+                parallel_config=None,
+                scheduler_config=None,
+                lora_config=None,
+                cache_config=None,
+            )
+
+    def get_model_config_params(self):
+        sig = inspect.signature(VllmModelConfig.__init__)
+        params = {
+            "model": self.server_args.model_path,
+            "quantization": self.server_args.quantization,
+            "tokenizer": None,
+            "tokenizer_mode": None,
+            "trust_remote_code": self.server_args.trust_remote_code,
+            "dtype": self.server_args.dtype,
+            "seed": self.server_args.random_seed,
+            "skip_tokenizer_init": True,
+        }
+
+        if "task" in sig.parameters:
+            params["task"] = ""
+
+        return params
+
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
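Note: get_model_config_params probes VllmModelConfig.__init__ so that the task keyword is only passed to vLLM versions whose constructor accepts it, while setup_model probes for the new VllmConfig entry point and falls back to the legacy keyword API on ImportError. A minimal sketch of the signature-gating trick, with hypothetical classes standing in for the two vLLM versions:

import inspect

def build_kwargs(ctor, base, optional):
    # Pass an optional keyword only when the constructor declares it.
    sig = inspect.signature(ctor)
    merged = dict(base)
    for name, value in optional.items():
        if name in sig.parameters:
            merged[name] = value
    return merged

class OldConfig:
    def __init__(self, model):
        self.model = model

class NewConfig:
    def __init__(self, model, task=""):
        self.model, self.task = model, task

assert "task" not in build_kwargs(OldConfig.__init__, {"model": "m"}, {"task": ""})
assert "task" in build_kwargs(NewConfig.__init__, {"model": "m"}, {"task": ""})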
@@ -247,31 +289,15 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        self.vllm_model_config = VllmModelConfig(
-            model=self.server_args.model_path,
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
-        )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )

         # Load the model
-        self.model = get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()

         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
@@ -303,17 +329,9 @@ class ModelRunner:
         target_device = torch.device(self.device)

         try:
-            # TODO: Use a better method to check this
-            vllm_model_config = VllmModelConfig(
-                model=model_path,
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
         except Exception as e:
             message = f"Failed to load model config: {e}."
             return False, message
python/sglang/srt/utils.py
@@ -332,6 +332,7 @@ def suppress_other_loggers():
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)
+    logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)

     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -396,6 +397,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
             pass


+def monkey_patch_vllm_model_config():
+    from vllm.config import ModelConfig
+
+    if not hasattr(ModelConfig, "_resolve_task"):
+        return
+
+    def _resolve_task(
+        self,
+        task_option,
+        hf_config,
+    ):
+        supported_tasks = {
+            "generate": True,
+            "embedding": False,
+        }
+        selected_task = "generate"
+        return supported_tasks, selected_task
+
+    setattr(ModelConfig, "_resolve_task", _resolve_task)
+
+
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
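Note: the patch only fires when the installed vLLM actually defines ModelConfig._resolve_task (the newer releases), and it pins SGLang to the generate task either way. A minimal sketch of this hasattr-guarded method replacement on a hypothetical class, not on vLLM itself:

class HypotheticalConfig:
    # Stand-in for a library class whose newer versions grow a method.
    def _resolve_task(self, task_option, hf_config):
        raise RuntimeError("unwanted default behavior")

def patch(cls):
    if not hasattr(cls, "_resolve_task"):
        return  # old library version: nothing to override

    def _resolve_task(self, task_option, hf_config):
        return {"generate": True, "embedding": False}, "generate"

    setattr(cls, "_resolve_task", _resolve_task)

patch(HypotheticalConfig)
tasks, selected = HypotheticalConfig()._resolve_task(None, None)
assert selected == "generate"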
python/sglang/test/test_utils.py
@@ -2,6 +2,7 @@

 import argparse
 import asyncio
+import copy
 import os
 import random
 import subprocess
@@ -529,6 +530,7 @@ def run_bench_serving(
     random_input_len=4096,
     random_output_len=2048,
     disable_stream=False,
+    need_warmup=False,
 ):
     # Launch the server
     base_url = DEFAULT_URL_FOR_TEST
@@ -565,6 +567,10 @@ def run_bench_serving(
     )

     try:
+        if need_warmup:
+            warmup_args = copy.deepcopy(args)
+            warmup_args.num_prompts = 16
+            run_benchmark(warmup_args)
         res = run_benchmark(args)
     finally:
         kill_child_process(process.pid, include_self=True)
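Note: the warmup pass reuses the benchmark arguments via copy.deepcopy so that shrinking num_prompts to 16 for the throwaway run cannot leak into the measured run. A minimal sketch of the pattern, with a hypothetical run_benchmark stand-in:

import copy
from types import SimpleNamespace

def run_benchmark(args):  # hypothetical stand-in for the real benchmark
    return {"num_prompts": args.num_prompts}

args = SimpleNamespace(num_prompts=512, need_warmup=True)
if args.need_warmup:
    warmup_args = copy.deepcopy(args)
    warmup_args.num_prompts = 16  # small throwaway run to warm caches
    run_benchmark(warmup_args)
res = run_benchmark(args)         # measured run keeps num_prompts=512
assert res["num_prompts"] == 512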