feat: update torch 2.5.1 (#2069)

Yineng Zhang
2024-11-18 21:29:13 +08:00
committed by GitHub
parent 2a3992b6f1
commit 766192610e
10 changed files with 127 additions and 33 deletions


@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
"orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
"torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
"outlines>=0.0.44,<0.1.0", "modelscope"]
srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
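The packaging change above is the pivot of the commit: the srt extra no longer pins vllm to exactly 0.6.3.post1, so environments built around torch 2.5.1 can resolve a matching newer vllm. A minimal sketch, not part of the commit, of verifying the resolved environment at runtime; packaging is already a declared dependency in the list above:

from importlib.metadata import version

from packaging.version import Version

# Hedged sketch: confirm the installed vllm satisfies the relaxed constraint
# before importing sglang.srt.
assert Version(version("vllm")) >= Version("0.6.3.post1"), "vllm too old for sglang[srt]"
print("torch", version("torch"), "| vllm", version("vllm"))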


@@ -32,12 +32,14 @@ from vllm.distributed import (
)
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import set_weight_attrs

logger = logging.getLogger(__name__)


@register_custom_op("sglang_silu_and_mul")
class SiluAndMul(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
        return out


@register_custom_op("sglang_gelu_and_mul")
class GeluAndMul(CustomOp):
    def __init__(self, approximate="tanh"):
        super().__init__()
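These hunks wrap the activation ops in the new registration shim (defined in the next file) so that, on vllm builds where CustomOp keeps a registry, SiluAndMul and GeluAndMul register under stable sglang-prefixed names. For reference, a hedged sketch of the math behind SiluAndMul.forward_native, with shapes chosen only for illustration:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)                   # last dim holds [gate | up]
d = x.shape[-1] // 2                    # same split as forward_native above
out = F.silu(x[..., :d]) * x[..., d:]   # SiLU(gate) * up, shape (2, 4)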


@@ -0,0 +1,26 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from vllm.model_executor.custom_op import CustomOp


def register_custom_op(op_name):
    def decorator(cls):
        if hasattr(CustomOp, "register"):
            return CustomOp.register(op_name)(cls)
        else:
            return cls

    return decorator
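This new helper is the compatibility shim the rest of the commit builds on: the vllm generation paired with torch 2.5.1 exposes a CustomOp.register classmethod, while older vllm does not, so the decorator registers the class when it can and otherwise returns it unchanged. A hedged usage sketch, where MyOp is a hypothetical op, not something in the commit:

from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.custom_op_util import register_custom_op

@register_custom_op("sglang_my_op")  # registry entry on newer vllm, no-op otherwise
class MyOp(CustomOp):                # MyOp is hypothetical, for illustration only
    def forward_native(self, x):
        return x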


@@ -33,9 +33,12 @@ if is_flashinfer_available():

from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.custom_op_util import register_custom_op

logger = logging.getLogger(__name__)


@register_custom_op("sglang_rmsnorm")
class RMSNorm(CustomOp):
    def __init__(
        self,
@@ -78,6 +81,7 @@ class RMSNorm(CustomOp):
        return x, residual


@register_custom_op("sglang_gemma_rmsnorm")
class GemmaRMSNorm(CustomOp):
    def __init__(
        self,
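Same pattern for the normalization ops. As a reminder of the math behind RMSNorm's native forward, a hedged sketch from the standard definition rather than this file; the eps value is an assumption:

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.pow(2).mean(dim=-1, keepdim=True)   # mean of squares over the hidden dim
    return x * torch.rsqrt(variance + eps) * weight  # normalize, then scale by the learned weight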


@@ -90,6 +90,8 @@ def set_torch_compile_config():
    # FIXME: tmp workaround
    torch._dynamo.config.accumulated_cache_size_limit = 1024
    if hasattr(torch._dynamo.config, "cache_size_limit"):
        torch._dynamo.config.cache_size_limit = 1024


@maybe_torch_compile(dynamic=True)
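The hasattr guard keeps set_torch_compile_config working across torch versions: cache_size_limit exists on the torch 2.5 line, and guarding avoids a hard failure on builds without that knob. The same defensive pattern generalized, as a sketch; the loop is illustrative, not the commit's code:

import torch._dynamo.config as dynamo_config

for knob in ("accumulated_cache_size_limit", "cache_size_limit"):
    if hasattr(dynamo_config, knob):      # only touch knobs this torch build has
        setattr(dynamo_config, knob, 1024)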


@@ -18,9 +18,9 @@ limitations under the License.
import gc
import importlib
import importlib.resources
import inspect
import json
import logging
import os
import pkgutil
from functools import lru_cache
from typing import Optional, Type
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
    crash_on_warnings,
    enable_show_time_cost,
    get_available_gpu_memory,
    monkey_patch_vllm_model_config,
    monkey_patch_vllm_p2p_access_check,
)
@@ -226,6 +227,47 @@ class ModelRunner:
        return min_per_gpu_memory

    def setup_model(self):
        try:
            from vllm.config import VllmConfig

            vllm_config = VllmConfig()
            vllm_config.model_config = self.vllm_model_config
            vllm_config.load_config = self.load_config
            vllm_config.device_config = DeviceConfig(self.device)
            vllm_config.quant_config = VllmConfig._get_quantization_config(
                vllm_config.model_config, vllm_config.load_config
            )
            return get_model(vllm_config=vllm_config)
        except ImportError:
            return get_model(
                model_config=self.vllm_model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
                parallel_config=None,
                scheduler_config=None,
                lora_config=None,
                cache_config=None,
            )
    def get_model_config_params(self):
        sig = inspect.signature(VllmModelConfig.__init__)
        params = {
            "model": self.server_args.model_path,
            "quantization": self.server_args.quantization,
            "tokenizer": None,
            "tokenizer_mode": None,
            "trust_remote_code": self.server_args.trust_remote_code,
            "dtype": self.server_args.dtype,
            "seed": self.server_args.random_seed,
            "skip_tokenizer_init": True,
        }
        if "task" in sig.parameters:
            params["task"] = ""
        return params
    def load_model(self):
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -247,31 +289,15 @@ class ModelRunner:
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
-        self.vllm_model_config = VllmModelConfig(
-            model=self.server_args.model_path,
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
-        )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )

        # Load the model
-        self.model = get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
@@ -303,17 +329,9 @@ class ModelRunner:
        target_device = torch.device(self.device)

        try:
            # TODO: Use a better method to check this
-            vllm_model_config = VllmModelConfig(
-                model=model_path,
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message
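With that, load_model and the weight-update path share get_model_config_params; the update path overrides only the model entry before rebuilding the config, as the last hunk shows. The override pattern in isolation, as a hedged sketch; the import alias is an assumption about how this module names vllm's ModelConfig:

from vllm.config import ModelConfig as VllmModelConfig  # alias assumed

def rebuild_model_config(runner, model_path: str):
    # Reuse the shared kwargs, overriding only the checkpoint path.
    params = runner.get_model_config_params()
    params["model"] = model_path
    return VllmModelConfig(**params)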


@@ -332,6 +332,7 @@ def suppress_other_loggers():
    )
    logging.getLogger("vllm.selector").setLevel(logging.WARN)
    logging.getLogger("vllm.utils").setLevel(logging.ERROR)
    logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(
        logging.ERROR
    )
    warnings.filterwarnings(
        "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -396,6 +397,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
        pass

def monkey_patch_vllm_model_config():
    from vllm.config import ModelConfig

    if not hasattr(ModelConfig, "_resolve_task"):
        return

    def _resolve_task(
        self,
        task_option,
        hf_config,
    ):
        supported_tasks = {
            "generate": True,
            "embedding": False,
        }
        selected_task = "generate"
        return supported_tasks, selected_task

    setattr(ModelConfig, "_resolve_task", _resolve_task)
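monkey_patch_vllm_model_config pins vllm's newer task resolution to "generate": the hasattr check makes the patch a no-op on older vllm, and on newer vllm the replacement _resolve_task skips the autodetection that the load path would otherwise run. The general setattr-patching pattern reduced to a runnable toy; Target and its method are hypothetical stand-ins:

class Target:                       # toy stand-in for vllm's ModelConfig
    def resolve(self):
        return "slow autodetection"

def fast_resolve(self):
    return "generate"

if hasattr(Target, "resolve"):      # patch only what actually exists, as above
    setattr(Target, "resolve", fast_resolve)

assert Target().resolve() == "generate"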

def monkey_patch_vllm_p2p_access_check(gpu_id: int):
    """
    Monkey patch the slow p2p access check in vllm.


@@ -2,6 +2,7 @@
import argparse
import asyncio
import copy
import os
import random
import subprocess
@@ -529,6 +530,7 @@ def run_bench_serving(
    random_input_len=4096,
    random_output_len=2048,
    disable_stream=False,
    need_warmup=False,
):
    # Launch the server
    base_url = DEFAULT_URL_FOR_TEST
@@ -565,6 +567,10 @@ def run_bench_serving(
    )

    try:
        if need_warmup:
            warmup_args = copy.deepcopy(args)
            warmup_args.num_prompts = 16
            run_benchmark(warmup_args)

        res = run_benchmark(args)
    finally:
        kill_child_process(process.pid, include_self=True)
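The test helper gains an opt-in warmup: a deep-copied argument set with num_prompts forced to 16 runs once before the measured benchmark, so reported numbers exclude cold-start effects. A hedged call sketch; only the keywords visible in this diff are used, and any parameters not shown here are assumed to have defaults:

res = run_bench_serving(
    random_input_len=4096,      # defaults visible in the diff
    random_output_len=2048,
    disable_stream=False,
    need_warmup=True,           # one 16-prompt warmup run before the measured run
)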