feat: update torch 2.5.1 (#2069)
@@ -90,6 +90,8 @@ def set_torch_compile_config():
    # FIXME: tmp workaround
    torch._dynamo.config.accumulated_cache_size_limit = 1024
    if hasattr(torch._dynamo.config, "cache_size_limit"):
        torch._dynamo.config.cache_size_limit = 1024


@maybe_torch_compile(dynamic=True)
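The hunk above raises torch._dynamo's compile-cache limits, but only touches cache_size_limit when the attribute exists, since not every torch release exposes it. A minimal standalone sketch of that guarded-config pattern (the limit value 1024 mirrors the hunk; the helper name is illustrative, not from the commit):

import torch._dynamo.config

def raise_dynamo_cache_limits(limit: int = 1024) -> None:
    # Allow many more graph recompilations before torch.compile falls back to eager.
    torch._dynamo.config.accumulated_cache_size_limit = limit
    # cache_size_limit is not exposed by every torch release, so guard the assignment.
    if hasattr(torch._dynamo.config, "cache_size_limit"):
        torch._dynamo.config.cache_size_limit = limit

raise_dynamo_cache_limits()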
@@ -18,9 +18,9 @@ limitations under the License.
import gc
import importlib
import importlib.resources
import inspect
import json
import logging
import os
import pkgutil
from functools import lru_cache
from typing import Optional, Type
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
    crash_on_warnings,
    enable_show_time_cost,
    get_available_gpu_memory,
    monkey_patch_vllm_model_config,
    monkey_patch_vllm_p2p_access_check,
)
@@ -226,6 +227,47 @@ class ModelRunner:
        return min_per_gpu_memory

    def setup_model(self):
        try:
            from vllm.config import VllmConfig

            vllm_config = VllmConfig()
            vllm_config.model_config = self.vllm_model_config
            vllm_config.load_config = self.load_config
            vllm_config.device_config = DeviceConfig(self.device)
            vllm_config.quant_config = VllmConfig._get_quantization_config(
                vllm_config.model_config, vllm_config.load_config
            )
            return get_model(vllm_config=vllm_config)
        except ImportError:
            return get_model(
                model_config=self.vllm_model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
                parallel_config=None,
                scheduler_config=None,
                lora_config=None,
                cache_config=None,
            )

    def get_model_config_params(self):
        sig = inspect.signature(VllmModelConfig.__init__)
        params = {
            "model": self.server_args.model_path,
            "quantization": self.server_args.quantization,
            "tokenizer": None,
            "tokenizer_mode": None,
            "trust_remote_code": self.server_args.trust_remote_code,
            "dtype": self.server_args.dtype,
            "seed": self.server_args.random_seed,
            "skip_tokenizer_init": True,
        }

        if "task" in sig.parameters:
            params["task"] = ""

        return params

    def load_model(self):
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
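The setup_model() added above keeps one call site working across two vllm generations: newer releases bundle everything into a single VllmConfig object, while older ones take individual keyword arguments, and the ImportError raised by "from vllm.config import VllmConfig" selects the branch. A rough self-contained sketch of that dispatch pattern follows; the stand-in loader functions are illustrative only and are not part of the commit:

# Stand-ins for the two loader generations (the real commit calls vllm's get_model in both branches).
def get_model_new_style(*, vllm_config):
    return f"loaded via consolidated config: {vllm_config}"

def get_model_old_style(*, model_config, load_config):
    return f"loaded via keyword args: {model_config}, {load_config}"

def setup_model(model_config, load_config):
    try:
        # Newer vllm: every sub-config is carried by one VllmConfig object.
        from vllm.config import VllmConfig

        vllm_config = VllmConfig()
        vllm_config.model_config = model_config
        vllm_config.load_config = load_config
        return get_model_new_style(vllm_config=vllm_config)
    except ImportError:
        # Older vllm: each config is passed as its own keyword argument.
        return get_model_old_style(model_config=model_config, load_config=load_config)

print(setup_model("model-config", "load-config"))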
@@ -247,31 +289,15 @@ class ModelRunner:
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=self.server_args.random_seed,
            skip_tokenizer_init=True,
        )
        monkey_patch_vllm_model_config()
        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )

        # Load the model
        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.model = self.setup_model()

        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
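The rewritten hunk above now builds VllmModelConfig from get_model_config_params(), which probes the constructor's signature so that optional keywords such as "task" are only passed when the installed vllm actually accepts them. A small sketch of that signature-probing idea against a local stand-in class (the class and field names below are made up for illustration):

import inspect

class ExampleConfig:
    # Stand-in for a constructor whose accepted keywords vary between releases.
    def __init__(self, model, dtype="auto", task=None):
        self.model, self.dtype, self.task = model, dtype, task

def build_config_params(model_path):
    sig = inspect.signature(ExampleConfig.__init__)
    params = {"model": model_path, "dtype": "auto"}
    # Only include "task" if this version of the constructor knows about it.
    if "task" in sig.parameters:
        params["task"] = ""
    return params

config = ExampleConfig(**build_config_params("my-model"))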
@@ -303,17 +329,9 @@ class ModelRunner:
        target_device = torch.device(self.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=self.server_args.random_seed,
                skip_tokenizer_init=True,
            )
            model_config_params = self.get_model_config_params()
            model_config_params["model"] = model_path
            vllm_model_config = VllmModelConfig(**model_config_params)
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message