feat: update torch 2.5.1 (#2069)

This commit is contained in:
Yineng Zhang
2024-11-18 21:29:13 +08:00
committed by GitHub
parent 2a3992b6f1
commit 766192610e
10 changed files with 127 additions and 33 deletions

Makefile — new file (+12 lines)
View File

@@ -0,0 +1,12 @@
# Targets are commands, not files, so mark them phony.
.PHONY: check-deps install-deps format
# Install isort/black on demand if they are missing from PATH.
check-deps:
@command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort)
@command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black)
# Explicitly install both formatters (unconditional, unlike check-deps).
install-deps:
pip install isort black
# Run isort then black on every *modified* tracked .py file.
# NOTE(review): --diff-filter=M skips newly added (A) and staged files — confirm intended.
format: check-deps
@echo "Formatting modified Python files..."
git diff --name-only --diff-filter=M | grep '\.py$$' | xargs -I {} sh -c 'isort {} && black {}'

View File

@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
"orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
"torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
"outlines>=0.0.44,<0.1.0", "modelscope"] "outlines>=0.0.44,<0.1.0", "modelscope"]
srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"] srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]
# HIP (Heterogeneous-computing Interface for Portability) for AMD # HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl # => base docker rocm/vllm-dev:20241022, not from public vllm whl

View File

@@ -32,12 +32,14 @@ from vllm.distributed import (
) )
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@register_custom_op("sglang_silu_and_mul")
class SiluAndMul(CustomOp): class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
return out return out
@register_custom_op("sglang_gelu_and_mul")
class GeluAndMul(CustomOp): class GeluAndMul(CustomOp):
def __init__(self, approximate="tanh"): def __init__(self, approximate="tanh"):
super().__init__() super().__init__()

View File

@@ -0,0 +1,26 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from vllm.model_executor.custom_op import CustomOp
def register_custom_op(op_name):
    """Return a class decorator that registers the class with vllm's CustomOp.

    Registers the decorated class under ``op_name`` when the installed vllm
    exposes ``CustomOp.register``; on older vllm releases that lack it, the
    decorator is a no-op and the class is returned unchanged.
    """

    def decorator(cls):
        register = getattr(CustomOp, "register", None)
        if register is None:
            # Older vllm: no registry available, leave the class untouched.
            return cls
        return register(op_name)(cls)

    return decorator

View File

@@ -33,9 +33,12 @@ if is_flashinfer_available():
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.custom_op_util import register_custom_op
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@register_custom_op("sglang_rmsnorm")
class RMSNorm(CustomOp): class RMSNorm(CustomOp):
def __init__( def __init__(
self, self,
@@ -78,6 +81,7 @@ class RMSNorm(CustomOp):
return x, residual return x, residual
@register_custom_op("sglang_gemma_rmsnorm")
class GemmaRMSNorm(CustomOp): class GemmaRMSNorm(CustomOp):
def __init__( def __init__(
self, self,

View File

@@ -90,6 +90,8 @@ def set_torch_compile_config():
# FIXME: tmp workaround # FIXME: tmp workaround
torch._dynamo.config.accumulated_cache_size_limit = 1024 torch._dynamo.config.accumulated_cache_size_limit = 1024
if hasattr(torch._dynamo.config, "cache_size_limit"):
torch._dynamo.config.cache_size_limit = 1024
@maybe_torch_compile(dynamic=True) @maybe_torch_compile(dynamic=True)

View File

@@ -18,9 +18,9 @@ limitations under the License.
import gc import gc
import importlib import importlib
import importlib.resources import importlib.resources
import inspect
import json import json
import logging import logging
import os
import pkgutil import pkgutil
from functools import lru_cache from functools import lru_cache
from typing import Optional, Type from typing import Optional, Type
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
crash_on_warnings, crash_on_warnings,
enable_show_time_cost, enable_show_time_cost,
get_available_gpu_memory, get_available_gpu_memory,
monkey_patch_vllm_model_config,
monkey_patch_vllm_p2p_access_check, monkey_patch_vllm_p2p_access_check,
) )
@@ -226,6 +227,47 @@ class ModelRunner:
return min_per_gpu_memory return min_per_gpu_memory
def setup_model(self):
    """Instantiate the model through vllm's loader, supporting both loader APIs.

    Newer vllm releases bundle every sub-config into a single aggregated
    ``VllmConfig`` object; older releases take the individual config objects
    as keyword arguments.  The new-style path is attempted first and the
    old-style call is used when ``VllmConfig`` cannot be imported.
    """
    try:
        # New-style API: aggregate all sub-configs into one VllmConfig.
        from vllm.config import VllmConfig

        cfg = VllmConfig()
        cfg.model_config = self.vllm_model_config
        cfg.load_config = self.load_config
        cfg.device_config = DeviceConfig(self.device)
        # The quantization config is derived from the model + load configs.
        cfg.quant_config = VllmConfig._get_quantization_config(
            cfg.model_config, cfg.load_config
        )
        return get_model(vllm_config=cfg)
    except ImportError:
        # Old-style API: hand each config over individually.
        return get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
def get_model_config_params(self):
    """Assemble the keyword arguments for vllm's ``ModelConfig`` constructor.

    The ``task`` argument only exists in newer vllm releases, so it is added
    conditionally based on the constructor's actual signature.
    """
    kwargs = {
        "model": self.server_args.model_path,
        "quantization": self.server_args.quantization,
        "tokenizer": None,
        "tokenizer_mode": None,
        "trust_remote_code": self.server_args.trust_remote_code,
        "dtype": self.server_args.dtype,
        "seed": self.server_args.random_seed,
        "skip_tokenizer_init": True,
    }
    init_params = inspect.signature(VllmModelConfig.__init__).parameters
    if "task" in init_params:
        kwargs["task"] = ""
    return kwargs
def load_model(self): def load_model(self):
logger.info( logger.info(
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -247,31 +289,15 @@ class ModelRunner:
load_format=self.server_args.load_format, load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir, download_dir=self.server_args.download_dir,
) )
self.vllm_model_config = VllmModelConfig( monkey_patch_vllm_model_config()
model=self.server_args.model_path, self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
quantization=self.server_args.quantization,
tokenizer=None,
tokenizer_mode=None,
trust_remote_code=self.server_args.trust_remote_code,
dtype=self.server_args.dtype,
seed=self.server_args.random_seed,
skip_tokenizer_init=True,
)
if self.model_config.model_override_args is not None: if self.model_config.model_override_args is not None:
self.vllm_model_config.hf_config.update( self.vllm_model_config.hf_config.update(
self.model_config.model_override_args self.model_config.model_override_args
) )
# Load the model self.model = self.setup_model()
self.model = get_model(
model_config=self.vllm_model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
parallel_config=None,
scheduler_config=None,
lora_config=None,
cache_config=None,
)
self.sliding_window_size = ( self.sliding_window_size = (
self.model.get_attention_sliding_window_size() self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size") if hasattr(self.model, "get_attention_sliding_window_size")
@@ -303,17 +329,9 @@ class ModelRunner:
target_device = torch.device(self.device) target_device = torch.device(self.device)
try: try:
# TODO: Use a better method to check this model_config_params = self.get_model_config_params()
vllm_model_config = VllmModelConfig( model_config_params["model"] = model_path
model=model_path, vllm_model_config = VllmModelConfig(**model_config_params)
quantization=self.server_args.quantization,
tokenizer=None,
tokenizer_mode=None,
trust_remote_code=self.server_args.trust_remote_code,
dtype=self.server_args.dtype,
seed=self.server_args.random_seed,
skip_tokenizer_init=True,
)
except Exception as e: except Exception as e:
message = f"Failed to load model config: {e}." message = f"Failed to load model config: {e}."
return False, message return False, message

View File

@@ -332,6 +332,7 @@ def suppress_other_loggers():
) )
logging.getLogger("vllm.selector").setLevel(logging.WARN) logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.ERROR) logging.getLogger("vllm.utils").setLevel(logging.ERROR)
logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
warnings.filterwarnings( warnings.filterwarnings(
"ignore", category=UserWarning, message="The given NumPy array is not writable" "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -396,6 +397,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
pass pass
def monkey_patch_vllm_model_config():
    """Pin vllm's ModelConfig task resolution to the 'generate' task.

    Newer vllm versions resolve a model's task at config time; sglang always
    runs generation, so the resolver is replaced with one that reports only
    'generate' as supported.  On older vllm releases without ``_resolve_task``
    this is a no-op.
    """
    from vllm.config import ModelConfig

    # Nothing to patch on vllm versions that predate task resolution.
    if not hasattr(ModelConfig, "_resolve_task"):
        return

    def _resolve_task(self, task_option, hf_config):
        # Ignore the requested task entirely: generation only.
        supported = {"generate": True, "embedding": False}
        return supported, "generate"

    ModelConfig._resolve_task = _resolve_task
def monkey_patch_vllm_p2p_access_check(gpu_id: int): def monkey_patch_vllm_p2p_access_check(gpu_id: int):
""" """
Monkey patch the slow p2p access check in vllm. Monkey patch the slow p2p access check in vllm.

View File

@@ -2,6 +2,7 @@
import argparse import argparse
import asyncio import asyncio
import copy
import os import os
import random import random
import subprocess import subprocess
@@ -529,6 +530,7 @@ def run_bench_serving(
random_input_len=4096, random_input_len=4096,
random_output_len=2048, random_output_len=2048,
disable_stream=False, disable_stream=False,
need_warmup=False,
): ):
# Launch the server # Launch the server
base_url = DEFAULT_URL_FOR_TEST base_url = DEFAULT_URL_FOR_TEST
@@ -565,6 +567,10 @@ def run_bench_serving(
) )
try: try:
if need_warmup:
warmup_args = copy.deepcopy(args)
warmup_args.num_prompts = 16
run_benchmark(warmup_args)
res = run_benchmark(args) res = run_benchmark(args)
finally: finally:
kill_child_process(process.pid, include_self=True) kill_child_process(process.pid, include_self=True)

View File

@@ -32,6 +32,7 @@ class TestBenchServing(unittest.TestCase):
random_input_len=None, random_input_len=None,
random_output_len=None, random_output_len=None,
disable_stream=True, disable_stream=True,
need_warmup=True,
) )
if is_in_ci(): if is_in_ci():