commit 85e1a6f3aa
parent 33deca81b5
Author: Yineng Zhang
Date:   2024-12-02 23:22:13 +08:00
Committed by: GitHub

    Update model_loader deps and qqq quantization deps (#2220) (#2318)

    Co-authored-by: HandH1998 <1335248067@qq.com>

58 changed files with 2363 additions and 366 deletions
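This commit cuts ModelRunner's remaining dependence on vLLM's model loader: `get_model` now comes from `sglang.srt.model_loader`, `DeviceConfig` and `LoadConfig` from `sglang.srt.configs`, and the `VllmModelConfig` translation layer (`get_model_config_params`, `setup_model`) is deleted along with the module-level `ModelRegistry` monkey patching. A minimal sketch of the new load path; the `get_model` keyword arguments mirror the diff below, while the `ModelConfig` constructor signature and the model path are illustrative assumptions:

```python
# Sketch only: loading a model through the in-tree loader after this change.
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig  # assumed ctor signature
from sglang.srt.model_loader import get_model

model = get_model(
    model_config=ModelConfig("meta-llama/Llama-3.1-8B-Instruct"),
    load_config=LoadConfig(load_format="auto"),
    device_config=DeviceConfig("cuda"),
)
```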

python/sglang/srt/model_executor/model_runner.py

@@ -14,22 +14,12 @@
"""ModelRunner runs the forward passes of the models."""
import gc
import importlib
import importlib.resources
import inspect
import json
import logging
import pkgutil
import time
from functools import lru_cache
from tokenize import tabsize
from typing import Any, Optional, Type, Union
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
get_tp_group,
init_distributed_environment,
@@ -37,9 +27,9 @@ from vllm.distributed import (
     set_custom_all_reduce,
 )
 from vllm.distributed.parallel_state import in_the_same_node_as
-from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models import ModelRegistry

+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
@@ -56,16 +46,15 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader import get_model
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
     is_hip,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
@@ -228,49 +217,6 @@ class ModelRunner:
         return min_per_gpu_memory

-    def setup_model(self):
-        try:
-            from vllm.config import VllmConfig
-
-            vllm_config = VllmConfig()
-            vllm_config.model_config = self.vllm_model_config
-            vllm_config.load_config = self.load_config
-            vllm_config.device_config = DeviceConfig(self.device)
-            vllm_config.quant_config = VllmConfig._get_quantization_config(
-                vllm_config.model_config, vllm_config.load_config
-            )
-            return get_model(vllm_config=vllm_config)
-        except ImportError:
-            pass
-        return get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
-
-    def get_model_config_params(self):
-        sig = inspect.signature(VllmModelConfig.__init__)
-        params = {
-            "model": self.server_args.model_path,
-            "quantization": self.server_args.quantization,
-            "tokenizer": None,
-            "tokenizer_mode": None,
-            "trust_remote_code": self.server_args.trust_remote_code,
-            "dtype": self.server_args.dtype,
-            "seed": self.server_args.random_seed,
-            "skip_tokenizer_init": True,
-        }
-        if "task" in sig.parameters:
-            params["task"] = ""
-        return params
-
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -284,6 +230,7 @@ class ModelRunner:
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
)
self.server_args.dtype = "float16"
self.model_config.dtype = torch.float16
if torch.cuda.get_device_capability()[1] < 5:
raise RuntimeError("SGLang only supports sm75 and above.")
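The added `self.model_config.dtype = torch.float16` keeps the in-tree config in sync with the coerced server dtype: `get_model` now reads `self.model_config` directly, so a stale `dtype` there could leak bfloat16 onto pre-Ampere GPUs. A standalone sketch of the underlying capability rule (native bfloat16 requires compute capability 8.0, i.e. sm80/Ampere or newer):

```python
import torch

def pick_dtype(requested: torch.dtype) -> torch.dtype:
    # bfloat16 has no native support before compute capability (8, 0).
    major, _minor = torch.cuda.get_device_capability()
    if requested is torch.bfloat16 and major < 8:
        return torch.float16  # same downgrade the diff applies below sm80
    return requested
```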
@@ -292,23 +239,21 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        monkey_patch_vllm_model_config()
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
-        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
-        if self.model_config.model_override_args is not None:
-            self.vllm_model_config.hf_config.update(
-                self.model_config.model_override_args
-            )
-        self.model = self.setup_model()
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+        )
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.dtype = self.vllm_model_config.dtype
+        self.dtype = self.model_config.dtype

         logger.info(
             f"Load weight end. "
@@ -319,12 +264,12 @@ class ModelRunner:
     def update_weights_from_disk(self, model_path: str, load_format: str):
         """Update engine weights online from disk."""
-        from vllm.model_executor.model_loader.loader import (
+        from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
             get_model_loader,
         )
-        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+        from sglang.srt.model_loader.utils import set_default_torch_dtype

         logger.info(
             f"Update engine weights online from disk begin. "
@@ -332,15 +277,7 @@ class ModelRunner:
         )
         target_device = torch.device(self.device)

-        try:
-            model_config_params = self.get_model_config_params()
-            model_config_params["model"] = model_path
-            vllm_model_config = VllmModelConfig(**model_config_params)
-        except Exception as e:
-            message = f"Failed to load model config: {e}."
-            return False, message
-
+        self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)

         # Only support vllm DefaultModelLoader for now
@@ -352,7 +289,7 @@ class ModelRunner:
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
                 DefaultModelLoader.Source(
-                    config.model,
+                    config.model_path,
                     revision=config.revision,
                     fall_back_to_pt=getattr(
                         self.model, "fall_back_to_pt_during_load", True
@@ -370,9 +307,9 @@ class ModelRunner:
                 quant_method.process_weights_after_loading(module)
             return model

-        with set_default_torch_dtype(vllm_model_config.dtype):
+        with set_default_torch_dtype(self.model_config.dtype):
             try:
-                iter = get_weight_iter(vllm_model_config)
+                iter = get_weight_iter(self.model_config)
             except Exception as e:
                 message = f"Failed to get weights iterator: {e}."
                 return False, message
@@ -384,16 +321,14 @@ class ModelRunner:
             )
             del iter
             gc.collect()
-            iter = get_weight_iter(self.vllm_model_config)
+            iter = get_weight_iter(self.model_config)
             self.model = model_load_weights(self.model, iter)
             return False, message

         self.model = model
         self.server_args.model_path = model_path
         self.server_args.load_format = load_format
-        self.vllm_model_config = vllm_model_config
         self.load_config = load_config
-        self.model_config.path = model_path

         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."
@@ -794,55 +729,3 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
-
-
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            try:
-                module = importlib.import_module(name)
-            except Exception as e:
-                logger.warning(f"Ignore import error when loading {name}. {e}")
-                if crash_on_warnings():
-                    raise ValueError(f"Ignore import error when loading {name}. {e}")
-                continue
-            if hasattr(module, "EntryClass"):
-                entry = module.EntryClass
-                if isinstance(
-                    entry, list
-                ):  # To support multiple model classes in one module
-                    for tmp in entry:
-                        assert (
-                            tmp.__name__ not in model_arch_name_to_cls
-                        ), f"Duplicated model implementation for {tmp.__name__}"
-                        model_arch_name_to_cls[tmp.__name__] = tmp
-                else:
-                    assert (
-                        entry.__name__ not in model_arch_name_to_cls
-                    ), f"Duplicated model implementation for {entry.__name__}"
-                    model_arch_name_to_cls[entry.__name__] = entry
-    return model_arch_name_to_cls
-
-
-def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
-    model_arch_name_to_cls = import_model_classes()
-    if model_arch not in model_arch_name_to_cls:
-        raise ValueError(
-            f"Unsupported architectures: {model_arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_arch_name_to_cls[model_arch]
-
-
-# Monkey patch model loader
-setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
-setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
-setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
-setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)