feat: update torch 2.5.1 (#2069)

Yineng Zhang
2024-11-18 21:29:13 +08:00
committed by GitHub
parent 2a3992b6f1
commit 766192610e
10 changed files with 127 additions and 33 deletions


@@ -90,6 +90,8 @@ def set_torch_compile_config():
     # FIXME: tmp workaround
     torch._dynamo.config.accumulated_cache_size_limit = 1024
+    if hasattr(torch._dynamo.config, "cache_size_limit"):
+        torch._dynamo.config.cache_size_limit = 1024
 
 @maybe_torch_compile(dynamic=True)
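
The hasattr guard added above is a small compatibility pattern: write the tuning knob only when the installed torch build actually exposes it. A minimal runnable sketch of the same idea (only the limit value of 1024 comes from the diff; the function name is illustrative):

import torch

def set_dynamo_cache_limits(limit: int = 1024) -> None:
    # Set unconditionally, as the diff does for this attribute.
    torch._dynamo.config.accumulated_cache_size_limit = limit
    # Probe before writing: builds that lack the attribute should
    # degrade gracefully rather than raise AttributeError.
    if hasattr(torch._dynamo.config, "cache_size_limit"):
        torch._dynamo.config.cache_size_limit = limit

set_dynamo_cache_limits()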


@@ -18,9 +18,9 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import inspect
 import json
 import logging
 import os
 import pkgutil
 from functools import lru_cache
 from typing import Optional, Type
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -226,6 +227,47 @@ class ModelRunner:
         return min_per_gpu_memory
 
+    def setup_model(self):
+        try:
+            from vllm.config import VllmConfig
+
+            vllm_config = VllmConfig()
+            vllm_config.model_config = self.vllm_model_config
+            vllm_config.load_config = self.load_config
+            vllm_config.device_config = DeviceConfig(self.device)
+            vllm_config.quant_config = VllmConfig._get_quantization_config(
+                vllm_config.model_config, vllm_config.load_config
+            )
+            return get_model(vllm_config=vllm_config)
+        except ImportError:
+            return get_model(
+                model_config=self.vllm_model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+                parallel_config=None,
+                scheduler_config=None,
+                lora_config=None,
+                cache_config=None,
+            )
+
+    def get_model_config_params(self):
+        sig = inspect.signature(VllmModelConfig.__init__)
+        params = {
+            "model": self.server_args.model_path,
+            "quantization": self.server_args.quantization,
+            "tokenizer": None,
+            "tokenizer_mode": None,
+            "trust_remote_code": self.server_args.trust_remote_code,
+            "dtype": self.server_args.dtype,
+            "seed": self.server_args.random_seed,
+            "skip_tokenizer_init": True,
+        }
+        if "task" in sig.parameters:
+            params["task"] = ""
+        return params
+
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -247,31 +289,15 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        self.vllm_model_config = VllmModelConfig(
-            model=self.server_args.model_path,
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
-        )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
 
         # Load the model
-        self.model = get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()
 
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
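
The sliding_window_size assignment in the trailing context uses an optional-capability probe: call the accessor only on models that implement it, and fall back otherwise. A tiny runnable sketch (class and function names here are illustrative):

class WindowedModel:
    def get_attention_sliding_window_size(self):
        return 4096

class PlainModel:
    pass

def sliding_window_size(model):
    # Duck-typed probe: no base class or capability registry needed.
    if hasattr(model, "get_attention_sliding_window_size"):
        return model.get_attention_sliding_window_size()
    return None

assert sliding_window_size(WindowedModel()) == 4096
assert sliding_window_size(PlainModel()) is None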
@@ -303,17 +329,9 @@ class ModelRunner:
         target_device = torch.device(self.device)
 
         try:
             # TODO: Use a better method to check this
-            vllm_model_config = VllmModelConfig(
-                model=model_path,
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
         except Exception as e:
             message = f"Failed to load model config: {e}."
             return False, message
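
This last hunk reuses get_model_config_params and overrides only the model path, validating a candidate checkpoint by constructing a config before any runner state is touched; failure is reported as a (False, message) pair rather than a raised exception. A minimal sketch of that shape; check_model_path and build_config are hypothetical stand-ins for the method and VllmModelConfig:

def check_model_path(base_params: dict, model_path: str, build_config):
    # Copy the shared params so the caller's dict stays untouched,
    # then swap in the candidate path.
    params = dict(base_params)
    params["model"] = model_path
    try:
        build_config(**params)  # probe-construct; the result is discarded
    except Exception as e:
        return False, f"Failed to load model config: {e}."
    return True, "success"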