Co-authored-by: HandH1998 <1335248067@qq.com>
@@ -14,22 +14,12 @@
 """ModelRunner runs the forward passes of the models."""

 import gc
-import importlib
-import importlib.resources
-import inspect
 import json
 import logging
-import pkgutil
 import time
-from functools import lru_cache
-from tokenize import tabsize
-from typing import Any, Optional, Type, Union
+from typing import Optional

 import torch
 import torch.distributed as dist
-import torch.nn as nn
-from vllm.config import DeviceConfig, LoadConfig
-from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
@@ -37,9 +27,9 @@ from vllm.distributed import (
     set_custom_all_reduce,
 )
 from vllm.distributed.parallel_state import in_the_same_node_as
-from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models import ModelRegistry

+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
@@ -56,16 +46,15 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader import get_model
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
     is_hip,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
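The import changes above are the heart of the commit: every `vllm.config` and `vllm.model_executor` dependency is replaced by an sglang-local module, while the distributed primitives still come from vllm. The call sites in this file only exercise a couple of fields on the new config objects; a minimal sketch of that shape follows, with the caveat that the field list is an assumption for illustration and the real classes in `sglang.srt.configs` may carry more fields and validation.

    # Hypothetical stand-ins for sglang.srt.configs.device_config/load_config.
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class DeviceConfig:
        device: str = "cuda"  # as constructed by DeviceConfig(self.device)


    @dataclass
    class LoadConfig:
        load_format: str = "auto"          # e.g. "auto" or "gguf"
        download_dir: Optional[str] = None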
@@ -228,49 +217,6 @@ class ModelRunner:

         return min_per_gpu_memory

-    def setup_model(self):
-        try:
-            from vllm.config import VllmConfig
-
-            vllm_config = VllmConfig()
-            vllm_config.model_config = self.vllm_model_config
-            vllm_config.load_config = self.load_config
-            vllm_config.device_config = DeviceConfig(self.device)
-            vllm_config.quant_config = VllmConfig._get_quantization_config(
-                vllm_config.model_config, vllm_config.load_config
-            )
-            return get_model(vllm_config=vllm_config)
-        except ImportError:
-            pass
-
-        return get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
-
-    def get_model_config_params(self):
-        sig = inspect.signature(VllmModelConfig.__init__)
-        params = {
-            "model": self.server_args.model_path,
-            "quantization": self.server_args.quantization,
-            "tokenizer": None,
-            "tokenizer_mode": None,
-            "trust_remote_code": self.server_args.trust_remote_code,
-            "dtype": self.server_args.dtype,
-            "seed": self.server_args.random_seed,
-            "skip_tokenizer_init": True,
-        }
-
-        if "task" in sig.parameters:
-            params["task"] = ""
-
-        return params
-
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
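The deleted `get_model_config_params` used a version-bridging trick worth noting: probe the installed constructor with `inspect.signature` and pass newer keywords (here `task`) only when they exist. The same pattern in isolation, with an illustrative helper name that is not from the codebase:

    import inspect


    def kwargs_accepted_by(cls, base: dict, optional: dict) -> dict:
        # Forward an optional keyword only when the installed version of
        # cls.__init__ actually declares a parameter with that name.
        sig = inspect.signature(cls.__init__)
        return {**base, **{k: v for k, v in optional.items() if k in sig.parameters}}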
@@ -284,6 +230,7 @@ class ModelRunner:
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
+            self.model_config.dtype = torch.float16
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")

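Note that the capability check only runs inside the existing `major < 8` branch, so testing the minor version against 5 is what separates sm75 (supported, but falling back to float16 for lack of bfloat16) from sm70 and older (rejected). Compressed to its essentials:

    import torch

    major, minor = torch.cuda.get_device_capability()  # (7, 5) on a T4, (8, 0) on an A100
    if major < 8:       # pre-Ampere: no native bfloat16, use float16
        if minor < 5:   # below sm75, e.g. a V100 at sm70: unsupported
            raise RuntimeError("SGLang only supports sm75 and above.")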
@@ -292,23 +239,21 @@
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        monkey_patch_vllm_model_config()

         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
-        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
-        if self.model_config.model_override_args is not None:
-            self.vllm_model_config.hf_config.update(
-                self.model_config.model_override_args
-            )
-
-        self.model = self.setup_model()
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+        )
+
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.dtype = self.vllm_model_config.dtype
+        self.dtype = self.model_config.dtype

         logger.info(
             f"Load weight end. "
@@ -319,12 +264,12 @@

     def update_weights_from_disk(self, model_path: str, load_format: str):
         """Update engine weights online from disk."""
-        from vllm.model_executor.model_loader.loader import (
+        from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
             get_model_loader,
         )
-        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+        from sglang.srt.model_loader.utils import set_default_torch_dtype

         logger.info(
             f"Update engine weights online from disk begin. "
@@ -332,15 +277,7 @@
         )

         target_device = torch.device(self.device)
-
-        try:
-            model_config_params = self.get_model_config_params()
-            model_config_params["model"] = model_path
-            vllm_model_config = VllmModelConfig(**model_config_params)
-        except Exception as e:
-            message = f"Failed to load model config: {e}."
-            return False, message
-
+        self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)

         # Only support vllm DefaultModelLoader for now
@@ -352,7 +289,7 @@
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
                 DefaultModelLoader.Source(
-                    config.model,
+                    config.model_path,
                     revision=config.revision,
                     fall_back_to_pt=getattr(
                         self.model, "fall_back_to_pt_during_load", True
@@ -370,9 +307,9 @@
                 quant_method.process_weights_after_loading(module)
             return model

-        with set_default_torch_dtype(vllm_model_config.dtype):
+        with set_default_torch_dtype(self.model_config.dtype):
             try:
-                iter = get_weight_iter(vllm_model_config)
+                iter = get_weight_iter(self.model_config)
             except Exception as e:
                 message = f"Failed to get weights iterator: {e}."
                 return False, message
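`set_default_torch_dtype` now comes from `sglang.srt.model_loader.utils` instead of vllm, with the same contract: temporarily switch PyTorch's process-wide default dtype while weights are materialized, then restore it. Assuming it mirrors the vllm helper it replaces, a minimal sketch:

    from contextlib import contextmanager

    import torch


    @contextmanager
    def set_default_torch_dtype(dtype: torch.dtype):
        # Tensors created inside the block default to `dtype`; the previous
        # default is restored even if weight loading raises.
        old_dtype = torch.get_default_dtype()
        torch.set_default_dtype(dtype)
        try:
            yield
        finally:
            torch.set_default_dtype(old_dtype)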
@@ -384,16 +321,14 @@
             )
             del iter
             gc.collect()
-            iter = get_weight_iter(self.vllm_model_config)
+            iter = get_weight_iter(self.model_config)
             self.model = model_load_weights(self.model, iter)
             return False, message

         self.model = model
         self.server_args.model_path = model_path
         self.server_args.load_format = load_format
-        self.vllm_model_config = vllm_model_config
         self.load_config = load_config
-        self.model_config.path = model_path

         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."
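`update_weights_from_disk` hands callers a `(success, message)` tuple, and the error path above reloads weights before returning the failure message. A hypothetical call site, where `runner` stands for a live ModelRunner and the path is illustrative:

    ok, message = runner.update_weights_from_disk(
        model_path="/checkpoints/step-2000",  # illustrative path
        load_format="auto",
    )
    if not ok:
        print(f"Weight update failed: {message}")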
@@ -794,55 +729,3 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
-
-
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            try:
-                module = importlib.import_module(name)
-            except Exception as e:
-                logger.warning(f"Ignore import error when loading {name}. {e}")
-                if crash_on_warnings():
-                    raise ValueError(f"Ignore import error when loading {name}. {e}")
-                continue
-            if hasattr(module, "EntryClass"):
-                entry = module.EntryClass
-                if isinstance(
-                    entry, list
-                ):  # To support multiple model classes in one module
-                    for tmp in entry:
-                        assert (
-                            tmp.__name__ not in model_arch_name_to_cls
-                        ), f"Duplicated model implementation for {tmp.__name__}"
-                        model_arch_name_to_cls[tmp.__name__] = tmp
-                else:
-                    assert (
-                        entry.__name__ not in model_arch_name_to_cls
-                    ), f"Duplicated model implementation for {entry.__name__}"
-                    model_arch_name_to_cls[entry.__name__] = entry
-
-    return model_arch_name_to_cls
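The deleted scan codified a simple convention: every module under `sglang.srt.models` exports `EntryClass`, either one model class or a list of them, and each class registers under its own `__name__`. A conforming module looked roughly like this (illustrative module and class names):

    # sglang/srt/models/my_model.py
    import torch.nn as nn


    class MyModelForCausalLM(nn.Module):
        # The class name must match an entry in the HF config's `architectures`.
        ...


    # Picked up by the scan above; exporting a list registers several
    # architectures from one module.
    EntryClass = MyModelForCausalLM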
-
-
-def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
-    model_arch_name_to_cls = import_model_classes()
-
-    if model_arch not in model_arch_name_to_cls:
-        raise ValueError(
-            f"Unsupported architectures: {model_arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_arch_name_to_cls[model_arch]
-
-
-# Monkey patch model loader
-setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
-setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
-setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
-setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
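These `setattr` calls were the glue binding sglang to vllm's loader: they pointed `ModelRegistry._try_load_model_cls` at the local registry built above and pinned the registry's multimodal/embedding probes to `False`. With model loading now owned by `sglang.srt.model_loader` (imported at the top of this diff), the registry no longer needs to be patched, which is why the whole block can go.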