From 85e1a6f3aa5a2288ca85fe3fe922c733b6533fa7 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 2 Dec 2024 23:22:13 +0800 Subject: [PATCH] Update model_loader deps and qqq quantization deps (#2220) (#2318) Co-authored-by: HandH1998 <1335248067@qq.com> --- python/sglang/bench_one_batch.py | 4 + python/sglang/srt/configs/device_config.py | 17 + python/sglang/srt/configs/load_config.py | 84 ++ python/sglang/srt/configs/model_config.py | 165 ++- python/sglang/srt/configs/qwen2vl.py | 13 +- python/sglang/srt/hf_transformers_utils.py | 2 + python/sglang/srt/layers/linear.py | 1 + python/sglang/srt/lora/lora.py | 2 +- python/sglang/srt/managers/scheduler.py | 3 + .../sglang/srt/managers/tokenizer_manager.py | 3 + python/sglang/srt/managers/tp_worker.py | 3 + .../sglang/srt/model_executor/model_runner.py | 155 +-- python/sglang/srt/model_loader/__init__.py | 34 + python/sglang/srt/model_loader/loader.py | 1139 +++++++++++++++++ python/sglang/srt/model_loader/utils.py | 41 + .../sglang/srt/model_loader/weight_utils.py | 640 +++++++++ python/sglang/srt/models/baichuan.py | 8 +- python/sglang/srt/models/chatglm.py | 19 +- python/sglang/srt/models/commandr.py | 3 +- python/sglang/srt/models/dbrx.py | 3 +- python/sglang/srt/models/deepseek.py | 13 +- python/sglang/srt/models/deepseek_v2.py | 12 +- python/sglang/srt/models/exaone.py | 3 +- python/sglang/srt/models/gemma.py | 6 +- python/sglang/srt/models/gemma2.py | 15 +- python/sglang/srt/models/gemma2_reward.py | 1 - python/sglang/srt/models/gpt2.py | 15 +- python/sglang/srt/models/gpt_bigcode.py | 26 +- python/sglang/srt/models/grok.py | 4 +- python/sglang/srt/models/internlm2.py | 3 +- python/sglang/srt/models/internlm2_reward.py | 1 - python/sglang/srt/models/llama.py | 3 +- .../sglang/srt/models/llama_classification.py | 3 +- python/sglang/srt/models/llama_embedding.py | 3 +- python/sglang/srt/models/llama_reward.py | 5 +- python/sglang/srt/models/llava.py | 5 +- python/sglang/srt/models/llavavid.py | 3 +- python/sglang/srt/models/minicpm.py | 3 +- python/sglang/srt/models/minicpm3.py | 17 +- python/sglang/srt/models/mixtral.py | 3 +- python/sglang/srt/models/mixtral_quant.py | 3 +- python/sglang/srt/models/mllama.py | 8 +- python/sglang/srt/models/olmo.py | 3 +- python/sglang/srt/models/olmo2.py | 1 - python/sglang/srt/models/olmoe.py | 6 +- python/sglang/srt/models/phi3_small.py | 9 +- python/sglang/srt/models/qwen.py | 3 +- python/sglang/srt/models/qwen2.py | 3 +- python/sglang/srt/models/qwen2_moe.py | 13 +- python/sglang/srt/models/qwen2_vl.py | 4 +- python/sglang/srt/models/registry.py | 99 ++ python/sglang/srt/models/stablelm.py | 3 +- .../sglang/srt/models/torch_native_llama.py | 3 +- python/sglang/srt/models/xverse.py | 4 +- python/sglang/srt/models/xverse_moe.py | 13 +- python/sglang/srt/models/yivl.py | 5 +- python/sglang/srt/server_args.py | 9 + python/sglang/srt/utils.py | 57 +- 58 files changed, 2363 insertions(+), 366 deletions(-) create mode 100644 python/sglang/srt/configs/device_config.py create mode 100644 python/sglang/srt/configs/load_config.py create mode 100644 python/sglang/srt/model_loader/__init__.py create mode 100644 python/sglang/srt/model_loader/loader.py create mode 100644 python/sglang/srt/model_loader/utils.py create mode 100644 python/sglang/srt/model_loader/weight_utils.py create mode 100644 python/sglang/srt/models/registry.py diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 9bbe9b0f1..e7a831399 100644 --- a/python/sglang/bench_one_batch.py +++ 
b/python/sglang/bench_one_batch.py @@ -111,8 +111,12 @@ def load_model(server_args, port_args, tp_rank): model_config = ModelConfig( server_args.model_path, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, context_length=server_args.context_length, model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, + dtype=server_args.dtype, + quantization=server_args.quantization, ) model_runner = ModelRunner( model_config=model_config, diff --git a/python/sglang/srt/configs/device_config.py b/python/sglang/srt/configs/device_config.py new file mode 100644 index 000000000..74deb8919 --- /dev/null +++ b/python/sglang/srt/configs/device_config.py @@ -0,0 +1,17 @@ +import logging +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +class DeviceConfig: + device: Optional[torch.device] + + def __init__(self, device: str = "cuda") -> None: + if device in ["cuda", "xpu", "hpu"]: + self.device_type = device + else: + raise RuntimeError(f"Not supported device type: {device}") + self.device = torch.device(self.device_type) diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py new file mode 100644 index 000000000..2b2b341fa --- /dev/null +++ b/python/sglang/srt/configs/load_config.py @@ -0,0 +1,84 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py +import enum +import json +import logging +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from sglang.srt.utils import is_hip + +logger = logging.getLogger(__name__) + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + SHARDED_STATE = "sharded_state" + GGUF = "gguf" + BITSANDBYTES = "bitsandbytes" + MISTRAL = "mistral" + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. 
+ + """ + + load_format: Union[str, LoadFormat] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads(model_loader_extra_config) + self._verify_load_format() + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns, + ) + else: + self.ignore_patterns = ["original/**/*"] + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f + for f in LoadFormat.__members__ + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}" + ) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 7517657b4..596afb83e 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -15,12 +15,14 @@ import json import logging from enum import IntEnum, auto -from typing import List, Optional +from typing import List, Optional, Union +import torch from transformers import PretrainedConfig from sglang.srt.hf_transformers_utils import get_config, get_context_length -from sglang.srt.utils import get_bool_env_var +from sglang.srt.layers.quantization import QUANTIZATION_METHODS +from sglang.srt.utils import get_bool_env_var, is_hip logger = logging.getLogger(__name__) @@ -33,17 +35,22 @@ class AttentionArch(IntEnum): class ModelConfig: def __init__( self, - path: str, + model_path: str, trust_remote_code: bool = True, revision: Optional[str] = None, context_length: Optional[int] = None, model_override_args: Optional[dict] = None, is_embedding: Optional[bool] = None, + dtype: str = "auto", + quantization: Optional[str] = None, ) -> None: + self.model_path = model_path + self.revision = revision + self.quantization = quantization # Parse args self.model_override_args = json.loads(model_override_args) self.hf_config = get_config( - path, + model_path, trust_remote_code=trust_remote_code, revision=revision, model_override_args=self.model_override_args, @@ -56,6 +63,7 @@ class ModelConfig: ) self.is_multimodal = is_multimodal_model(self.hf_config.architectures) self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) # Derive context length derived_context_len = get_context_length(self.hf_text_config) @@ -116,6 +124,8 @@ class ModelConfig: self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.vocab_size = self.hf_text_config.vocab_size + self._verify_quantization() + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" @@ -174,6 +184,86 @@ class ModelConfig: # parallel size so each GPU has at least one KV head. 
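# Worked example (illustrative numbers): with 8 total KV heads, tp_size=4 gives
# 8 // 4 = 2 KV heads per GPU, while tp_size=32 exceeds the head count and the
# max(1, ...) below replicates heads so each GPU still ends up with 1.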
return max(1, total_num_kv_heads // tensor_parallel_size) + # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(self.hf_config, "compression_config", None) + return quant_cfg + + # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py + def _verify_quantization(self) -> None: + supported_quantization = [*QUANTIZATION_METHODS] + rocm_supported_quantization = [ + "awq", + "gptq", + "fp8", + "compressed_tensors", + "compressed-tensors", + "fbgemm_fp8", + ] + optimized_quantization_methods = [ + "fp8", + "marlin", + "modelopt", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "fbgemm_fp8", + "compressed_tensors", + "compressed-tensors", + "experts_int8", + ] + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF model config, if available. + quant_cfg = self._parse_quant_hf_config() + + if quant_cfg is not None: + quant_method = quant_cfg.get("quant_method", "").lower() + + # Detect which checkpoint is it + for _, method in QUANTIZATION_METHODS.items(): + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization + ) + if quantization_override: + quant_method = quantization_override + self.quantization = quantization_override + break + + # Verify quantization configurations. + if self.quantization is None: + self.quantization = quant_method + elif self.quantization != quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization})." + ) + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}." + ) + if is_hip() and self.quantization not in rocm_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in ROCm." + ) + if self.quantization not in optimized_quantization_methods: + logger.warning( + "%s quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models.", + self.quantization, + ) + def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. @@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig): if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"): # We support non-hf version of llava models, so we do not want to # read the wrong values from the unused default text_config. + # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as + # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`. 
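Before the diff moves on to get_hf_text_config, a brief, non-authoritative sketch of how _verify_quantization above resolves the method; the repo id is only a placeholder for any GPTQ checkpoint whose config.json carries a quantization_config with quant_method "gptq".

from sglang.srt.configs.model_config import ModelConfig

cfg = ModelConfig(
    "TheBloke/Llama-2-7B-GPTQ",   # placeholder GPTQ checkpoint
    model_override_args="{}",     # the constructor json.loads() this field
    quantization=None,            # left unset -> inferred from the checkpoint
)
# cfg.quantization is now "gptq" (or a marlin-optimized variant if an
# override_quantization_method hook upgrades it on the current GPU).

# An explicit value that contradicts the checkpoint fails fast:
#   ModelConfig("TheBloke/Llama-2-7B-GPTQ", model_override_args="{}",
#               quantization="awq")   # -> ValueError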
+ setattr(config, "torch_dtype", torch.float16) return config if hasattr(config, "text_config"): @@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig): return config +# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + + +# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + config_dtype = getattr(config, "torch_dtype", None) + if config_dtype is None: + config_dtype = torch.float32 + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + if config.model_type == "gemma2": + logger.info( + "For Gemma 2, we downcast float32 to bfloat16 instead " + "of float16 by default. Please specify `dtype` if you " + "want to use float16." + ) + torch_dtype = torch.bfloat16 + else: + # Following the common practice, we use float16 for float32 + # models. + torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning("Casting %s to %s.", config_dtype, torch_dtype) + + return torch_dtype + + def is_generation_model(model_architectures: List[str], is_embedding: bool = False): # We have two ways to determine whether a model is a generative model. # 1. Check the model architectue diff --git a/python/sglang/srt/configs/qwen2vl.py b/python/sglang/srt/configs/qwen2vl.py index 4d30c741e..d4141234a 100644 --- a/python/sglang/srt/configs/qwen2vl.py +++ b/python/sglang/srt/configs/qwen2vl.py @@ -121,13 +121,10 @@ class Qwen2VLConfig(PretrainedConfig): self.attention_dropout = attention_dropout self.rope_scaling = rope_scaling - # NOTE: the following section from original transformers config - # for Qwen2-VL is commented out to address rope config loading issue - # - # if self.rope_scaling is not None and "type" in self.rope_scaling: - # if self.rope_scaling["type"] == "mrope": - # self.rope_scaling["type"] = "default" - # self.rope_scaling["rope_type"] = self.rope_scaling["type"] - # rope_config_validation(self) + # NOTE(HandH1998): This is necessary for configuring the `rope_type`` of qwen2vl models after removing dependencies on vllm. 
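To make the restored rope_scaling handling below concrete, a standalone sketch of the transformation; the mrope_section values merely resemble a typical Qwen2-VL config.json entry.

rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}

if rope_scaling is not None and "type" in rope_scaling:
    if rope_scaling["type"] == "mrope":
        rope_scaling["type"] = "default"
    rope_scaling["rope_type"] = rope_scaling["type"]

print(rope_scaling)
# {'type': 'default', 'mrope_section': [16, 24, 24], 'rope_type': 'default'}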
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index ac475cf34..92b01d452 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -75,6 +75,8 @@ def get_config( if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) + # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`. + setattr(config, "_name_or_path", model) if model_override_args: config.update(model_override_args) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 095164e1a..f69058ff3 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -42,6 +42,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "Fp8LinearMethod", "MarlinLinearMethod", "GPTQLinearMethod", + "QQQLinearMethod", ] diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 9f21df778..839d10222 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.loader import DefaultModelLoader from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -40,6 +39,7 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_loader.loader import DefaultModelLoader class BaseLayerWithLoRA(nn.Module): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 153d6f6f5..3714f19b6 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -147,9 +147,12 @@ class Scheduler: self.model_config = ModelConfig( server_args.model_path, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, context_length=server_args.context_length, model_override_args=server_args.json_model_override_args, is_embedding=server_args.is_embedding, + dtype=server_args.dtype, + quantization=server_args.quantization, ) self.is_generation = self.model_config.is_generation diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d1b5fa37a..56e01528a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -109,9 +109,12 @@ class TokenizerManager: self.model_config = ModelConfig( server_args.model_path, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, context_length=server_args.context_length, model_override_args=server_args.json_model_override_args, is_embedding=server_args.is_embedding, + dtype=server_args.dtype, + quantization=server_args.quantization, ) self.is_generation = self.model_config.is_generation diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 43d82c1a0..3aa06b4b8 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -52,9 +52,12 @@ class TpModelWorker: 
self.model_config = ModelConfig( server_args.model_path, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, context_length=server_args.context_length, model_override_args=server_args.json_model_override_args, is_embedding=server_args.is_embedding, + dtype=server_args.dtype, + quantization=server_args.quantization, ) self.model_runner = ModelRunner( model_config=self.model_config, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3fffa2047..73dec4a9c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -14,22 +14,12 @@ """ModelRunner runs the forward passes of the models.""" import gc -import importlib -import importlib.resources -import inspect import json import logging -import pkgutil -import time -from functools import lru_cache -from tokenize import tabsize -from typing import Any, Optional, Type, Union +from typing import Optional import torch import torch.distributed as dist -import torch.nn as nn -from vllm.config import DeviceConfig, LoadConfig -from vllm.config import ModelConfig as VllmModelConfig from vllm.distributed import ( get_tp_group, init_distributed_environment, @@ -37,9 +27,9 @@ from vllm.distributed import ( set_custom_all_reduce, ) from vllm.distributed.parallel_state import in_the_same_node_as -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models import ModelRegistry +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend @@ -56,16 +46,15 @@ from sglang.srt.mem_cache.memory_pool import ( ReqToTokenPool, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader import get_model from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( - crash_on_warnings, enable_show_time_cost, get_available_gpu_memory, init_custom_process_group, is_hip, monkey_patch_vllm_gguf_config, - monkey_patch_vllm_model_config, monkey_patch_vllm_p2p_access_check, set_cpu_offload_max_bytes, ) @@ -228,49 +217,6 @@ class ModelRunner: return min_per_gpu_memory - def setup_model(self): - try: - from vllm.config import VllmConfig - - vllm_config = VllmConfig() - vllm_config.model_config = self.vllm_model_config - vllm_config.load_config = self.load_config - vllm_config.device_config = DeviceConfig(self.device) - vllm_config.quant_config = VllmConfig._get_quantization_config( - vllm_config.model_config, vllm_config.load_config - ) - return get_model(vllm_config=vllm_config) - except ImportError: - pass - - return get_model( - model_config=self.vllm_model_config, - load_config=self.load_config, - device_config=DeviceConfig(self.device), - parallel_config=None, - scheduler_config=None, - lora_config=None, - cache_config=None, - ) - - def get_model_config_params(self): - sig = inspect.signature(VllmModelConfig.__init__) - params = { - "model": self.server_args.model_path, - "quantization": self.server_args.quantization, - "tokenizer": None, - "tokenizer_mode": None, - "trust_remote_code": self.server_args.trust_remote_code, - "dtype": self.server_args.dtype, - "seed": 
self.server_args.random_seed, - "skip_tokenizer_init": True, - } - - if "task" in sig.parameters: - params["task"] = "" - - return params - def load_model(self): logger.info( f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" @@ -284,6 +230,7 @@ class ModelRunner: "Compute capability below sm80. Use float16 due to lack of bfloat16 support." ) self.server_args.dtype = "float16" + self.model_config.dtype = torch.float16 if torch.cuda.get_device_capability()[1] < 5: raise RuntimeError("SGLang only supports sm75 and above.") @@ -292,23 +239,21 @@ class ModelRunner: load_format=self.server_args.load_format, download_dir=self.server_args.download_dir, ) - monkey_patch_vllm_model_config() + if self.server_args.load_format == "gguf": monkey_patch_vllm_gguf_config() - self.vllm_model_config = VllmModelConfig(**self.get_model_config_params()) - if self.model_config.model_override_args is not None: - self.vllm_model_config.hf_config.update( - self.model_config.model_override_args - ) - - self.model = self.setup_model() + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=DeviceConfig(self.device), + ) self.sliding_window_size = ( self.model.get_attention_sliding_window_size() if hasattr(self.model, "get_attention_sliding_window_size") else None ) - self.dtype = self.vllm_model_config.dtype + self.dtype = self.model_config.dtype logger.info( f"Load weight end. " @@ -319,12 +264,12 @@ class ModelRunner: def update_weights_from_disk(self, model_path: str, load_format: str): """Update engine weights online from disk.""" - from vllm.model_executor.model_loader.loader import ( + from sglang.srt.model_loader.loader import ( DefaultModelLoader, device_loading_context, get_model_loader, ) - from vllm.model_executor.model_loader.utils import set_default_torch_dtype + from sglang.srt.model_loader.utils import set_default_torch_dtype logger.info( f"Update engine weights online from disk begin. " @@ -332,15 +277,7 @@ class ModelRunner: ) target_device = torch.device(self.device) - - try: - model_config_params = self.get_model_config_params() - model_config_params["model"] = model_path - vllm_model_config = VllmModelConfig(**model_config_params) - except Exception as e: - message = f"Failed to load model config: {e}." - return False, message - + self.model_config.model_path = model_path load_config = LoadConfig(load_format=load_format) # Only support vllm DefaultModelLoader for now @@ -352,7 +289,7 @@ class ModelRunner: def get_weight_iter(config): iter = loader._get_weights_iterator( DefaultModelLoader.Source( - config.model, + config.model_path, revision=config.revision, fall_back_to_pt=getattr( self.model, "fall_back_to_pt_during_load", True @@ -370,9 +307,9 @@ class ModelRunner: quant_method.process_weights_after_loading(module) return model - with set_default_torch_dtype(vllm_model_config.dtype): + with set_default_torch_dtype(self.model_config.dtype): try: - iter = get_weight_iter(vllm_model_config) + iter = get_weight_iter(self.model_config) except Exception as e: message = f"Failed to get weights iterator: {e}." 
return False, message @@ -384,16 +321,14 @@ class ModelRunner: ) del iter gc.collect() - iter = get_weight_iter(self.vllm_model_config) + iter = get_weight_iter(self.model_config) self.model = model_load_weights(self.model, iter) return False, message self.model = model self.server_args.model_path = model_path self.server_args.load_format = load_format - self.vllm_model_config = vllm_model_config self.load_config = load_config - self.model_config.path = model_path logger.info("Update weights end.") return True, "Succeeded to update model weights." @@ -794,55 +729,3 @@ class ModelRunner: if rope_scaling is None: return False return rope_scaling.get("type", None) == "mrope" - - -@lru_cache() -def import_model_classes(): - model_arch_name_to_cls = {} - package_name = "sglang.srt.models" - package = importlib.import_module(package_name) - for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): - if not ispkg: - try: - module = importlib.import_module(name) - except Exception as e: - logger.warning(f"Ignore import error when loading {name}. {e}") - if crash_on_warnings(): - raise ValueError(f"Ignore import error when loading {name}. {e}") - continue - if hasattr(module, "EntryClass"): - entry = module.EntryClass - if isinstance( - entry, list - ): # To support multiple model classes in one module - for tmp in entry: - assert ( - tmp.__name__ not in model_arch_name_to_cls - ), f"Duplicated model implementation for {tmp.__name__}" - model_arch_name_to_cls[tmp.__name__] = tmp - else: - assert ( - entry.__name__ not in model_arch_name_to_cls - ), f"Duplicated model implementation for {entry.__name__}" - model_arch_name_to_cls[entry.__name__] = entry - - return model_arch_name_to_cls - - -def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: - model_arch_name_to_cls = import_model_classes() - - if model_arch not in model_arch_name_to_cls: - raise ValueError( - f"Unsupported architectures: {model_arch}. 
" - f"Supported list: {list(model_arch_name_to_cls.keys())}" - ) - return model_arch_name_to_cls[model_arch] - - -# Monkey patch model loader -setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt) -setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False) -setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False) -setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False) -setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False) diff --git a/python/sglang/srt/model_loader/__init__.py b/python/sglang/srt/model_loader/__init__.py new file mode 100644 index 000000000..fa2386e3a --- /dev/null +++ b/python/sglang/srt/model_loader/__init__.py @@ -0,0 +1,34 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py + +from torch import nn + +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader +from sglang.srt.model_loader.utils import ( + get_architecture_class_name, + get_model_architecture, +) + + +def get_model( + *, + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, +) -> nn.Module: + loader = get_model_loader(load_config) + return loader.load_model( + model_config=model_config, + device_config=device_config, + ) + + +__all__ = [ + "get_model", + "get_model_loader", + "BaseModelLoader", + "get_architecture_class_name", + "get_model_architecture", +] diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py new file mode 100644 index 000000000..e0b03d771 --- /dev/null +++ b/python/sglang/srt/model_loader/loader.py @@ -0,0 +1,1139 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py + +# ruff: noqa: SIM117 +import collections +import dataclasses +import fnmatch +import glob +import json +import logging +import math +import os +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, cast + +import gguf +import huggingface_hub +import numpy as np +import torch +from huggingface_hub import HfApi, hf_hub_download +from torch import nn +from transformers import AutoModelForCausalLM, PretrainedConfig +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.load_config import LoadConfig, LoadFormat +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_loader.utils import ( + get_model_architecture, + set_default_torch_dtype, +) +from sglang.srt.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, + download_weights_from_hf, + filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, + get_gguf_extra_tensor_names, + get_quant_config, + gguf_quant_weights_iterator, + initialize_dummy_weights, + np_cache_weights_iterator, + pt_weights_iterator, + safetensors_weights_iterator, +) +from sglang.srt.utils import ( + get_device_capability, + is_pin_memory_available, + 
set_weight_attrs, +) + + +@contextmanager +def device_loading_context(module: torch.nn.Module, target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided( + size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory, + ) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + + +logger = logging.getLogger(__name__) + + +def _get_quantization_config( + model_config: ModelConfig, load_config: LoadConfig +) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + quant_config = get_quant_config(model_config, load_config) + major, minor = get_device_capability() + + if major is not None and minor is not None: + assert 0 <= minor < 10 + capability = major * 10 + minor + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}." + ) + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. 
Supported dtypes: " + f"{supported_dtypes}" + ) + return quant_config + return None + + +def _initialize_model( + model_config: ModelConfig, + load_config: LoadConfig, +) -> nn.Module: + """Initialize a model with the given configurations.""" + model_class, _ = get_model_architecture(model_config) + quant_config = _get_quantization_config(model_config, load_config) + return model_class( + config=model_config.hf_config, + quant_config=quant_config, + ) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def download_model(self, model_config: ModelConfig) -> None: + """Download a model so that it can be immediately loaded.""" + raise NotImplementedError + + @abstractmethod + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + """Load a model with the given configurations.""" + raise NotImplementedError + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str] + ) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if "SGLANG_USE_MODELSCOPE" in os.environ: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool + ) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = ( + self._maybe_download_from_modelscope(model_name_or_path, revision) + or model_name_or_path + ) + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. 
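As a quick aside before the format dispatch below, a minimal sketch of the LoadConfig that drives this method; the values are illustrative only.

from sglang.srt.configs.load_config import LoadConfig, LoadFormat

# A plain string is accepted; __post_init__ normalizes it to the enum and
# fills in the default ignore pattern.
cfg = LoadConfig(load_format="safetensors")
assert cfg.load_format is LoadFormat.SAFETENSORS
assert cfg.ignore_patterns == ["original/**/*"]

# With LoadFormat.AUTO, the code below first looks for "*.safetensors" and
# falls back to "*.bin" (plus "*.pt" when fall_back_to_pt is set).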
+ if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file + ) + else: + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`" + ) + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, source: "Source" + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, source.revision, source.fall_back_to_pt + ) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + source.model_or_path, + self.load_config.download_dir, + hf_folder, + hf_weights_files, + ) + elif use_safetensors: + weights_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) + + # Apply the prefix. 
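# Example: a Source with prefix="language_model." turns a checkpoint key such
# as "model.layers.0.self_attn.q_proj.weight" into
# "language_model.model.layers.0.self_attn.q_proj.weight" before it reaches
# load_weights(); an empty prefix leaves names untouched.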
+ return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) + + def _get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + + primary_weights = DefaultModelLoader.Source( + model_config.model_path, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + ) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast( + Iterable[DefaultModelLoader.Source], getattr(model, "secondary_weights", ()) + ) + for source in secondary_weights: + yield from self._get_weights_iterator(source) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights( + model_config.model_path, model_config.revision, fall_back_to_pt=True + ) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = _initialize_model( + model_config, + self.load_config, + ) + + model.load_weights(self._get_all_weights(model_config, model)) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + return model.eval() + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, + self.load_config, + ) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + return model.eval() + + +class ShardedStateLoader(BaseModelLoader): + """ + Model loader that directly loads each worker's model state dict, which + enables a fast load path for large tensor-parallel models where each worker + only needs to read its own shard rather than the entire checkpoint. See + `examples/save_sharded_state.py` for creating a sharded checkpoint. 
+ """ + + DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + extra_config = ( + {} + if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy() + ) + self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) + if extra_config: + raise ValueError( + f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}" + ) + + @staticmethod + def _filter_subtensors(tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. + """ + same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = ( + collections.defaultdict(list) + ) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result: Dict[str, torch.Tensor] = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. + break + else: + result[k] = t + return result + + def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]): + if os.path.isdir(model_name_or_path): + return model_name_or_path + else: + allow_patterns = ["*.safetensors"] + return download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model_path, model_config.revision) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + from safetensors.torch import safe_open + from vllm.distributed import get_tensor_model_parallel_rank + + local_model_path = self._prepare_weights( + model_config.model_path, model_config.revision + ) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + local_model_path, + self.pattern.format(rank=rank, part="*"), + ) + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!" + ) + state_dict = self._filter_subtensors(model.state_dict()) + for path in filepaths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. 
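# Example (hypothetical shapes): a LoRA-padded embedding parameter of shape
# (32064, 4096) receiving a checkpoint tensor of shape (32000, 4096) is
# narrowed on dim 0 to 32000 rows before the copy_ below.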
+ param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", + tensor.shape, + key, + param_shape, + ) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!") + return model.eval() + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from safetensors.torch import save_file + from vllm.distributed import get_tensor_model_parallel_rank + + if pattern is None: + pattern = ShardedStateLoader.DEFAULT_PATTERN + rank = get_tensor_model_parallel_rank() + part_idx = 0 + total_size = 0 + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: Dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): + param_size = tensor.nelement() * tensor.element_size() + if max_size is not None and total_size + param_size > max_size: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + part_idx += 1 + total_size = 0 + state_dict_part = {} + state_dict_part[key] = tensor + total_size += param_size + if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + + +class BitsAndBytesModelLoader(BaseModelLoader): + """Model loader to load model weights with BitAndBytes quantization.""" + + possible_config_file_names = ["adapter_config.json"] + + default_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ".fc1.", + ".fc2.", + ".dense.", + ".query_key_value.", + ".qkv_proj.", + ".dense_h_to_4h.", + ".dense_4h_to_h.", + ".out_proj.", + ] + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + # we don't need to quantize the whole model, only the target modules + # that are specified in the adapter config file. If the adapter config + # file is not provided, we will quantize the default modules. 
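For context, a hedged sketch of the extra config that activates the adapter-driven path handled below; the adapter repo id is purely a placeholder.

from sglang.srt.configs.load_config import LoadConfig

load_config = LoadConfig(
    load_format="bitsandbytes",
    model_loader_extra_config={
        "qlora_adapter_name_or_path": "my-org/llama-qlora-adapter",  # placeholder
    },
)
# The adapter's adapter_config.json then supplies the modules to quantize, e.g.
#   {"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj",
#                       "gate_proj", "up_proj", "down_proj"], ...}
# Without this extra config, default_target_modules above is used instead.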
+ if ( + not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config + ): + self.target_modules = [] + return + + qlora_adapter = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path" + ] + + config_file_path = self._get_config_file(qlora_adapter) + + with open(config_file_path, "r") as f: + config = json.load(f) + self.target_modules = config["target_modules"] + + def _get_config_file(self, qlora_adapter: str) -> str: + is_local = os.path.isdir(qlora_adapter) + config_file_path = None + if is_local: + for file in self.possible_config_file_names: + config_file_path = os.path.join(qlora_adapter, file) + if os.path.exists(config_file_path): + break + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=qlora_adapter) + for file in self.possible_config_file_names: + if file in repo_files: + config_file_path = hf_hub_download( + repo_id=qlora_adapter, filename=file + ) + break + + if not config_file_path: + raise ValueError(f"Cannot find adapter config file in {qlora_adapter}") + + return config_file_path + + def _get_weight_files( + self, + model_name_or_path: str, + allowed_patterns: List[str], + revision: Optional[str] = None, + ) -> Tuple[List[str], str]: + """Retrieve weight files. Download the files if necessary. + + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob(os.path.join(model_name_or_path, pattern)) + if weight_files: + return weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + return glob.glob(os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError(f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str] + ) -> Tuple[List[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision + ) + + if matched_pattern != "*.safetensors": + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`" + ) + + return hf_weights_files, matched_pattern == "*.safetensors" + + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + if use_safetensors: + return safetensors_weights_iterator(hf_weights_files) + else: + return pt_weights_iterator(hf_weights_files) + + def _get_quantized_weights_iterator( + self, + model_name_or_path: str, + revision: Optional[str], + pre_quant: bool, + load_8bit: bool, + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + + if bitsandbytes.__version__ < "0.44.0": + raise ImportError( + "bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.44.0." 
+ ) + except ImportError as err: + raise ImportError( + "Please install bitsandbytes>=0.44.0 via " + "`pip install bitsandbytes>=0.44.0` to use " + "bitsandbytes quantizer." + ) from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision + ) + + quant_state_dict: Dict[str, Any] = {} + + if pre_quant: + if load_8bit: + return ( + self._quantized_8bit_generator( + hf_weights_files, use_safetensors, quant_state_dict + ), + quant_state_dict, + ) + else: + return ( + self._quantized_4bit_generator( + hf_weights_files, use_safetensors, quant_state_dict + ), + quant_state_dict, + ) + + return ( + self._unquantized_generator( + hf_weights_files, use_safetensors, quant_state_dict + ), + quant_state_dict, + ) + + def _quantized_8bit_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors + ): + if not weight_name.lower().endswith(".scb"): + continue + + weight_key = weight_name.lower().replace(".scb", ".qweight") + quant_state_dict[weight_key] = weight_tensor + + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors + ): + + if not weight_name.endswith((".weight", ".bias")): + continue + + qweight_name = weight_name.replace(".weight", ".qweight") + + if qweight_name in quant_state_dict: + set_weight_attrs(weight_tensor, {"load_in_8bit": True}) + yield qweight_name, weight_tensor + else: + yield weight_name, weight_tensor + + def _quantized_4bit_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: + from bitsandbytes.functional import QuantState + + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors) + temp_state_dict = {} + for weight_name, weight_tensor in weight_iterator: + if weight_name.endswith((".weight", ".bias")): + continue + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__* in CPU + if "quant_state.bitsandbytes" in weight_name: + temp_state_dict[weight_name] = weight_tensor.cpu().data + else: + temp_state_dict[weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." 
in k: + quant_state[k] = temp_state_dict[k] + + return QuantState.from_dict(quant_state, device="cuda") + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors + ): + + if not weight_name.endswith((".weight", ".bias")): + continue + + if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or ( + f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict + ): + quant_state = _parse_quant_state(weight_name, temp_state_dict) + weight_name = weight_name.replace(".weight", ".qweight") + quant_state_dict[weight_name] = quant_state + yield weight_name.replace(".weight", ".qweight"), weight_tensor + else: + yield weight_name, weight_tensor + + def _unquantized_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: + from bitsandbytes.functional import quantize_4bit + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors + ): + + if any( + target_module in weight_name for target_module in self.target_modules + ) and weight_name.endswith(".weight"): + weight_name = weight_name.replace(".weight", ".qweight") + + if any( + module in weight_name + for module in self.column_parallel_weights_modules + ): + + total_size = weight_tensor.size(-1) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[..., start_index:end_index] + + else: + total_size = weight_tensor.size(0) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[start_index:end_index, ...] + + # bitsandbytes requires data in GPU + if weight_sub_tensor.is_cuda: + loaded_weight = weight_sub_tensor + else: + loaded_weight = weight_sub_tensor.cuda() + + # remove the following after the issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 + if loaded_weight.is_contiguous() is False: + loaded_weight = loaded_weight.contiguous() + + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, compress_statistics=True, quant_type="nf4" + ) + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight + + def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None: + if not hasattr(model, "load_weights"): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(model).__name__}." + ) + + if not hasattr(model, "bitsandbytes_stacked_params_mapping"): + raise AttributeError( + f"Model {type(model).__name__} does not support BitsAndBytes " + "quantization yet." + ) + + if len(self.target_modules) == 0: + if hasattr(model, "default_bitsandbytes_target_modules"): + self.target_modules = model.default_bitsandbytes_target_modules + else: + self.target_modules = self.default_target_modules + + if hasattr(model, "column_parallel_weights_modules"): + self.column_parallel_weights_modules = model.column_parallel_weights_modules + else: + self.column_parallel_weights_modules = [] + + self.model_type = type(model).__name__ + + logger.info( + "Loading weights with BitsAndBytes quantization. " " May take a while ..." 
+ ) + + quant_config = getattr(model_config.hf_config, "quantization_config", None) + + pre_quant = False + if quant_config is not None: + quant_method = quant_config.get("quant_method") + if quant_method == "bitsandbytes": + pre_quant = True + else: + raise ValueError( + f"BitsAndBytes loader does not support {quant_method} " + "quantization" + ) + + # The quant_states in pre_quantized models cannot work with a split + # weight tensor. So TP does not work with pre_quantized bnb models. + if pre_quant and get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "Prequant BitsAndBytes models with TP is not supported." + "Please try with PP." + ) + + load_8bit = False + if pre_quant: + load_8bit = quant_config.get("load_in_8bit", False) + + qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator( + model_config.model_path, model_config.revision, pre_quant, load_8bit + ) + + model.load_weights(qweight_iterator) + + torch.cuda.empty_cache() + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + for quant_param_name in quant_state_dict: + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, + index, + ) in model.bitsandbytes_stacked_params_mapping.items(): + if shard_name in quant_param_name: + shard_index = index + quant_param_name = quant_param_name.replace(shard_name, weight_name) + break + + if quant_param_name not in param_dict: + raise ValueError( + f"Parameter {quant_param_name} not found in the model." + ) + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[ + non_stacked_param_name + ] + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError(f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in quant_states.items(): + num_elements[seq] = math.prod(quant_state.shape) // pack_ratio + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + if load_8bit: + set_weight_attrs( + param, {"matmul_state": [None] * len(quant_states)} + ) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model_path, model_config.revision) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, + self.load_config, + ) + + self._load_weights(model_config, model) + + return model.eval() + + +class GGUFModelLoader(BaseModelLoader): + """ + Model loader that can load GGUF files. This is useful for loading models + that are quantized with GGUF and saved in the GGUF format. This loader + supports loading both full models and sharded models. 
+ """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) + + def _prepare_weights(self, model_name_or_path: str): + if os.path.isfile(model_name_or_path): + return model_name_or_path + else: + raise ValueError(f"{model_name_or_path} is not a file.") + + def _get_gguf_weights_map(self, model_config: ModelConfig): + """ + GGUF uses this naming convention for their tensors from HF checkpoint: + `blk.N.BB.weight` and `blk.N.BB.bias` + where N signifies the block number of a layer, and BB signifies the + attention/mlp layer components. + See "Standardized tensor names" in + https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. + """ + config = model_config.hf_config + model_type = config.model_type + # hack: ggufs have a different name than transformers + if model_type == "cohere": + model_type = "command-r" + arch = None + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break + if arch is None: + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + num_layers = config.num_hidden_layers + name_map = gguf.get_tensor_name_map(arch, num_layers) + with torch.device("meta"): + dummy_model = AutoModelForCausalLM.from_config(config) + state_dict = dummy_model.state_dict() + + gguf_to_hf_name_map = {} + for hf_name in state_dict: + name, suffix = hf_name.rsplit(".", 1) + gguf_name = name_map.get_name(name) + gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name + return gguf_to_hf_name_map + + def _get_weights_iterator( + self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model_path) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + + local_model_path = self._prepare_weights(model_config.model_path) + gguf_weights_map = self._get_gguf_weights_map(model_config) + # we can only know if tie word embeddings after mapping weights + if "lm_head.weight" in get_gguf_extra_tensor_names( + local_model_path, gguf_weights_map + ): + model_config.hf_config.update({"tie_word_embeddings": True}) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config) + model.load_weights( + self._get_weights_iterator(local_model_path, gguf_weights_map) + ) + return model + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.SHARDED_STATE: + return ShardedStateLoader(load_config) + + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + + if load_config.load_format == LoadFormat.GGUF: + return GGUFModelLoader(load_config) + + return DefaultModelLoader(load_config) diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py new file mode 100644 index 000000000..daad1e67f --- /dev/null +++ 
b/python/sglang/srt/model_loader/utils.py @@ -0,0 +1,41 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/utils.py + +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Tuple, Type + +import torch +from torch import nn + +from sglang.srt.configs.model_config import ModelConfig + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + from sglang.srt.models.registry import ModelRegistry + + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"] + + if ( + model_config.quantization is not None + and model_config.quantization not in mixtral_supported + and "MixtralForCausalLM" in architectures + ): + architectures = ["QuantMixtralForCausalLM"] + + return ModelRegistry.resolve_model_cls(architectures) + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py new file mode 100644 index 000000000..13b323b5d --- /dev/null +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -0,0 +1,640 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py + +"""Utilities for downloading and initializing model weights.""" +import fnmatch +import glob +import hashlib +import json +import logging +import os +import tempfile +from collections import defaultdict +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union + +import filelock +import gguf +import huggingface_hub.constants +import numpy as np +import torch +from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from safetensors.torch import load_file, safe_open, save_file +from tqdm.auto import tqdm +from vllm.distributed import get_tensor_model_parallel_rank + +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config +from sglang.srt.utils import print_warning_once + +logger = logging.getLogger(__name__) + +# use system-level temp directory for file locks, so that multiple users +# can share the same lock without error. 
+# lock files in the temp directory will be automatically deleted when the +# system reboots, so users will not complain about annoying lock files +temp_dir = tempfile.gettempdir() + + +def enable_hf_transfer(): + """automatically activates hf_transfer""" + if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass + + +enable_hf_transfer() + + +class DisabledTqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) + return lock + + +def _shared_pointers(tensors): + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + failing = [] + for _, names in ptrs.items(): + if len(names) > 1: + failing.append(names) + return failing + + +def convert_bin_to_safetensor_file( + pt_filename: str, + sf_filename: str, +) -> None: + loaded = torch.load(pt_filename, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + shared = _shared_pointers(loaded) + for shared_weights in shared: + for name in shared_weights[1:]: + loaded.pop(name) + + # For tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata={"format": "pt"}) + + # check file size + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError( + f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """ + ) + + # check if the tensors are the same + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +# TODO(woosuk): Move this to other place. +def get_quant_config( + model_config: ModelConfig, load_config: LoadConfig +) -> QuantizationConfig: + + quant_cls = get_quantization_config(model_config.quantization) + + # GGUF doesn't have config file + if model_config.quantization == "gguf": + return quant_cls.from_config({}) + + # Read the quantization config from the HF model config, if available. 
+ hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) + # some vision model may keep quantization_config in their text_config + hf_text_config = getattr(model_config.hf_config, "text_config", None) + if hf_quant_config is None and hf_text_config is not None: + hf_quant_config = getattr(hf_text_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config.hf_config, "compression_config", None) + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. + if model_config.quantization == "bitsandbytes": + if ( + not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config + ): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path" + ] + + else: + model_name_or_path = model_config.model_path + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any(f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError(f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}" + ) + + quant_config_file = quant_config_files[0] + with open(quant_config_file) as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}." + ) + + return quant_cls.from_config(config) + + +def download_weights_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, + ignore_patterns: Optional[Union[str, List[str]]] = None, +) -> str: + """Download model weights from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (List[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + ignore_patterns (Optional[Union[str, List[str]]]): The patterns to + filter out the weight files. Files matched by any of the patterns + will be ignored. + + Returns: + str: The path to the downloaded model weights. 
+ """ + if not huggingface_hub.constants.HF_HUB_OFFLINE: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + logger.info("Using model weights format %s", allow_patterns) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + return hf_folder + + +def download_safetensors_index_file_from_hf( + model_name_or_path: str, + index_file: str, + cache_dir: Optional[str], + revision: Optional[str] = None, +) -> None: + """Download hf safetensors index file from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + revision (Optional[str]): The revision of the model. + """ + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + try: + # Download the safetensors index file. + hf_hub_download( + repo_id=model_name_or_path, + filename=index_file, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + # If file not found on remote or locally, we should not fail since + # only some models will have index_file. + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", index_file) + except huggingface_hub.utils.LocalEntryNotFoundError: + logger.info("No %s found in local cache.", index_file) + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the index_file to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files( + hf_weights_files: List[str], hf_folder: str, index_file: str +) -> List[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, index_file) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name) as f: + weight_map = json.load(f)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. + hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index] + return hf_weights_files + + +def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[str]: + """ + Exclude files that are not needed for inference. 
+ + See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + """ + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist) + ] + return hf_weights_files + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +def np_cache_weights_iterator( + model_name_or_path: str, + cache_dir: Optional[str], + hf_folder: str, + hf_weights_files: List[str], +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model np files. + + Will dump the model weights to numpy files if they are not already dumped. + """ + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. + np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. + with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names: List[str] = [] + for bin_file in tqdm( + hf_weights_files, + desc="Loading np_cache checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file) as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + + +def safetensors_weights_iterator( + hf_weights_files: List[str], +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def pt_weights_iterator( + hf_weights_files: List[str], +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model bin/pt files.""" + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + for bin_file in tqdm( + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, map_location="cpu") + yield from state.items() + del state + torch.cuda.empty_cache() + + +def get_gguf_extra_tensor_names( + gguf_file: str, gguf_to_hf_name_map: Dict[str, str] +) -> List[str]: + 
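    # Returns the HF-side tensor names that the architecture's name map expects
    # but that are absent from the GGUF file itself (e.g. lm_head.weight when
    # word embeddings are tied); GGUFModelLoader uses this to detect tied embeddings.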
reader = gguf.GGUFReader(gguf_file) + expected_gguf_keys = set(gguf_to_hf_name_map.keys()) + exact_gguf_keys = set([tensor.name for tensor in reader.tensors]) + extra_keys = expected_gguf_keys - exact_gguf_keys + return [gguf_to_hf_name_map[key] for key in extra_keys] + + +def gguf_quant_weights_iterator( + gguf_file: str, gguf_to_hf_name_map: Dict[str, str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """ + Iterate over the quant weights in the model gguf files and convert + them to torch tensors + """ + + reader = gguf.GGUFReader(gguf_file) + + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + + if weight_type.name != "F32": + weight_type_name = name.replace("weight", "qweight_type") + weight_type = torch.tensor(weight_type) + yield weight_type_name, weight_type + + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight = tensor.data + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + + if weight_type.name != "F32": + name = name.replace("weight", "qweight") + param = torch.tensor(weight) + yield name, param + + +def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: + """convert PySafeSlice object from safetensors to torch.Tensor + + PySafeSlice object supports indexing, which is done before loading the + actual tensor and can reduce the amount of memory being read into the + memory. However, it does not support more advanced functionalities + like `.view()` or `.t()`. Therefore, if we need to modify the loaded + tensor with these more complicated operators, we need to convert to + tensor first. + """ + if not isinstance(x, torch.Tensor): + x = x[:] + return x + + +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + try: + if param.numel() == 1 and loaded_weight.numel() == 1: + # Sometimes scalar values aren't considered tensors with shapes + # so if both param and loaded_weight are a scalar, + # "broadcast" instead of copy + param.data.fill_(loaded_weight.item()) + else: + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) " + f"into parameter ({param.size()})" + ) + + param.data.copy_(loaded_weight) + except Exception: + # NOTE: This exception is added for the purpose of setting breakpoint to + # debug weight loading issues. 
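        # Re-raise so the caller still sees the original copy/size-mismatch error.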
+ raise + + +def row_parallel_weight_loader( + param: torch.Tensor, loaded_weight: torch.Tensor +) -> None: + """Load weights that are row-parallelized.""" + tp_rank = get_tensor_model_parallel_rank() + shard_dim = 0 if param.dim() != 1 else None + + if shard_dim is not None: + shard_size = param.data.shape[shard_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + +LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + +def sharded_weight_loader(shard_axis: int) -> LoaderFunction: + """Create a weight loader that shards the weights along the given axis""" + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tensor_model_parallel_rank() + + shard_size = param.data.shape[shard_axis] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + return loader + + +def composed_weight_loader( + loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor] +) -> LoaderFunction: + """Create a weight loader that post-processes the weights after loading""" + + def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + loader(param, loaded_weight) + param.data.copy_(fn(param)) + return + + return composed_loader + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 1234, +) -> None: + """Initialize model weights with random values. + + The model weights must be randomly initialized for accurate performance + measurements. Additionally, the model weights should not cause NaNs in the + forward pass. We empirically found that initializing the weights with + values between -1e-3 and 1e-3 works well for most models. + + We use per-parameter random seed, so that dummy weights are consistent, + even if the model is partitioned across multiple devices. When the seed + is fixed, the random values generated by this function only depends on + the parameter's number of elements and its data type. + """ + for param in model.state_dict().values(): + if torch.is_floating_point(param): + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) + if torch.finfo(param.data.dtype).bits < 16: + # uniform_ doesn't support < 16-bit datatypes (FP8) + dtype = param.data.dtype + tmp_param = param.data.to(torch.float16) + tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype) + param.data.copy_(tmp_param) + else: + param.uniform_(low, high, generator=generator) + + +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: + """Remap the name of FP8 k/v_scale parameters. + + This function handles the remapping of FP8 k/v_scale parameter names. + It detects if the given name ends with a suffix and attempts to remap + it to the expected name format in the model. If the remapped name is not + found in the params_dict, a warning is printed and None is returned. + + Args: + name (str): The original loaded checkpoint parameter name. + params_dict (dict): Dictionary containing the model's named parameters. + + Returns: + str: The remapped parameter name if successful, or the original name + if no remapping is needed. + None: If the remapped name is not found in params_dict. + """ + if name.endswith(".kv_scale"): + print_warning_once( + "DEPRECATED. Found kv_scale in the checkpoint. 
" + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. " + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale" + ) + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: + print_warning_once( + f"Found kv_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). kv_scale is " + "not loaded." + ) + return None + return remapped_name + + possible_scale_names = [".k_scale", ".v_scale"] + for scale_name in possible_scale_names: + if name.endswith(scale_name): + remapped_name = name.replace(scale_name, f".attn{scale_name}") + if remapped_name not in params_dict: + print_warning_once( + f"Found {scale_name} in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). {scale_name} is " + "not loaded." + ) + return None + return remapped_name + + # If there were no matches, return the untouched param name + return name diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index d3b0fd9ae..3bd60c25d 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -329,7 +329,6 @@ class BaiChuanBaseForCausalLM(nn.Module): self, config: PretrainedConfig, position_embedding: str, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -404,13 +403,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", cache_config, quant_config) + super().__init__(config, "ROPE", quant_config) else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", cache_config, quant_config) + super().__init__(config, "ALIBI", quant_config) EntryClass = [BaichuanForCausalLM] diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index ced6859c7..9c3bc2ee9 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -23,7 +23,6 @@ from torch import nn from torch.nn import LayerNorm from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs import ChatGLMConfig from sglang.srt.layers.activation import SiluAndMul @@ -41,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import 
default_weight_loader LoraConfig = None @@ -50,7 +50,6 @@ class GLMAttention(nn.Module): self, config, layer_id: int = 0, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -186,7 +185,6 @@ class GLMBlock(nn.Module): self, config, layer_id: int, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -203,7 +201,7 @@ class GLMBlock(nn.Module): ) # Self attention. - self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config) + self.self_attention = GLMAttention(config, layer_id, quant_config) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -258,7 +256,6 @@ class GLMTransformer(nn.Module): def __init__( self, config, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -269,10 +266,7 @@ class GLMTransformer(nn.Module): # Transformer layers. self.layers = nn.ModuleList( - [ - GLMBlock(config, i, cache_config, quant_config) - for i in range(self.num_layers) - ] + [GLMBlock(config, i, quant_config) for i in range(self.num_layers)] ) if self.post_layer_norm: @@ -306,7 +300,6 @@ class ChatGLMM(nn.Module): def __init__( self, config, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -318,7 +311,7 @@ class ChatGLMM(nn.Module): self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, cache_config, quant_config) + self.encoder = GLMTransformer(config, quant_config) self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) @@ -357,15 +350,13 @@ class ChatGLMForCausalLM(nn.Module): def __init__( self, config: ChatGLMConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoraConfig] = None, ): super().__init__() self.config: ChatGLMConfig = config self.quant_config = quant_config self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) - self.transformer = ChatGLMM(config, cache_config, quant_config) + self.transformer = ChatGLMM(config, quant_config) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 8769d49db..a758e4f56 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -49,7 +49,6 @@ from vllm.distributed import ( get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( @@ -62,6 +61,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import set_weight_attrs @@ -318,7 +318,6 @@ class CohereForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 
e9b4ff141..45561d1db 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -25,7 +25,6 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs.dbrx import DbrxConfig from sglang.srt.layers.fused_moe_triton import fused_moe @@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import set_weight_attrs @@ -366,7 +366,6 @@ class DbrxForCausalLM(nn.Module): self, config: DbrxConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ): super().__init__() self.config = config diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 43dfc50a4..ce1b152fb 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -27,7 +27,6 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.fused_moe_triton import fused_moe @@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class DeepseekMLP(nn.Module): @@ -184,7 +184,6 @@ class DeepseekAttention(nn.Module): rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -261,7 +260,6 @@ class DeepseekDecoderLayer(nn.Module): self, config: PretrainedConfig, layer_id: int, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -277,7 +275,6 @@ class DeepseekDecoderLayer(nn.Module): rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - cache_config=cache_config, quant_config=quant_config, ) if ( @@ -330,7 +327,6 @@ class DeepseekModel(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -343,9 +339,7 @@ class DeepseekModel(nn.Module): ) self.layers = nn.ModuleList( [ - DeepseekDecoderLayer( - config, layer_id, cache_config, quant_config=quant_config - ) + DeepseekDecoderLayer(config, layer_id, quant_config=quant_config) for layer_id in range(config.num_hidden_layers) ] ) @@ -373,13 +367,12 @@ class DeepseekForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = DeepseekModel(config, cache_config, quant_config) + self.model = DeepseekModel(config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 55a458c20..424f86aec 
100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -28,7 +28,6 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.fused_moe_triton import FusedMoE @@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import is_flashinfer_available if is_flashinfer_available(): @@ -189,7 +189,6 @@ class DeepseekV2Attention(nn.Module): rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, layer_id=None, ) -> None: @@ -337,7 +336,6 @@ class DeepseekV2AttentionMLA(nn.Module): rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, layer_id=None, use_dp=False, @@ -568,7 +566,6 @@ class DeepseekV2DecoderLayer(nn.Module): self, config: PretrainedConfig, layer_id: int, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -599,7 +596,6 @@ class DeepseekV2DecoderLayer(nn.Module): rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - cache_config=cache_config, quant_config=quant_config, layer_id=layer_id, use_dp=self.enable_dp_attention, @@ -619,7 +615,6 @@ class DeepseekV2DecoderLayer(nn.Module): rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - cache_config=cache_config, quant_config=quant_config, layer_id=layer_id, ) @@ -685,7 +680,6 @@ class DeepseekV2Model(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -702,7 +696,6 @@ class DeepseekV2Model(nn.Module): DeepseekV2DecoderLayer( config, layer_id, - cache_config=cache_config, quant_config=quant_config, ) for layer_id in range(config.num_hidden_layers) @@ -733,13 +726,12 @@ class DeepseekV2ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = DeepseekV2Model(config, cache_config, quant_config) + self.model = DeepseekV2Model(config, quant_config) if global_server_args_dict["enable_dp_attention"]: self.lm_head = ReplicatedLinear( config.hidden_size, diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 8c244419f..536c253c3 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -22,7 +22,6 @@ import torch from torch import nn from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -39,6 +38,7 @@ from 
sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class ExaoneGatedMLP(nn.Module): @@ -293,7 +293,6 @@ class ExaoneForCausalLM(nn.Module): self, config, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index f6d301546..10949a2f5 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -21,10 +21,8 @@ from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig -from vllm.config import LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -38,6 +36,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class GemmaMLP(nn.Module): @@ -278,10 +277,7 @@ class GemmaForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - cache_config=None, ) -> None: - del lora_config # Unused. super().__init__() self.config = config self.quant_config = quant_config diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 104205648..dbca72688 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -20,12 +20,8 @@ from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.config import LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader - from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.linear import ( @@ -38,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers @@ -106,7 +103,6 @@ class Gemma2Attention(nn.Module): head_dim: int, max_position_embeddings: int, rope_theta: float, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -191,7 +187,6 @@ class Gemma2DecoderLayer(nn.Module): self, layer_id: int, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -205,7 +200,6 @@ class Gemma2DecoderLayer(nn.Module): head_dim=config.head_dim, 
max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, - cache_config=cache_config, quant_config=quant_config, ) self.hidden_size = config.hidden_size @@ -258,7 +252,6 @@ class Gemma2Model(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -273,7 +266,6 @@ class Gemma2Model(nn.Module): lambda idx, prefix: Gemma2DecoderLayer( layer_id=idx, config=config, - cache_config=cache_config, quant_config=quant_config, ), prefix="", @@ -342,15 +334,12 @@ class Gemma2ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, ) -> None: - del lora_config # Unused. super().__init__() self.config = config self.quant_config = quant_config - self.model = Gemma2Model(config, cache_config, quant_config) + self.model = Gemma2Model(config, quant_config) self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/gemma2_reward.py b/python/sglang/srt/models/gemma2_reward.py index 848eb2c02..e5c2fc07a 100644 --- a/python/sglang/srt/models/gemma2_reward.py +++ b/python/sglang/srt/models/gemma2_reward.py @@ -29,7 +29,6 @@ class Gemma2ForSequenceClassification(nn.Module): self, config: Gemma2Config, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index 6fbfe9edd..144ad8bbf 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -22,11 +22,9 @@ from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import GPT2Config -from vllm.config import CacheConfig from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader # from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.linear import ( @@ -39,6 +37,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class GPT2Attention(nn.Module): @@ -47,7 +46,6 @@ class GPT2Attention(nn.Module): self, layer_id: int, config: GPT2Config, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -140,7 +138,6 @@ class GPT2Block(nn.Module): self, layer_id: int, config: GPT2Config, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -150,7 +147,7 @@ class GPT2Block(nn.Module): self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = GPT2Attention( - layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn" + layer_id, config, quant_config, prefix=f"{prefix}.attn" ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp") @@ -182,7 +179,6 @@ class GPT2Model(nn.Module): def __init__( 
self, config: GPT2Config, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -196,7 +192,7 @@ class GPT2Model(nn.Module): self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList( [ - GPT2Block(i, config, cache_config, quant_config) + GPT2Block(i, config, quant_config) for i in range(config.num_hidden_layers) ] ) @@ -226,15 +222,12 @@ class GPT2LMHeadModel(nn.Module): def __init__( self, config: GPT2Config, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPT2Model( - config, cache_config, quant_config, prefix="transformer" - ) + self.transformer = GPT2Model(config, quant_config, prefix="transformer") self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 5af127320..f2f5ebd52 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import GPTBigCodeConfig -from vllm.config import LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.linear import ( @@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class GPTBigCodeAttention(nn.Module): @@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module): self, layer_id: int, config: GPTBigCodeConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module): self, layer_id: int, config: GPTBigCodeConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module): inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config) + self.attn = GPTBigCodeAttention(layer_id, config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigMLP(inner_dim, config, quant_config) @@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module): def __init__( self, config: GPTBigCodeConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config assert not config.add_cross_attention self.embed_dim = config.hidden_size - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) + lora_vocab = 0 self.vocab_size = config.vocab_size + lora_vocab self.wte = VocabParallelEmbedding( self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size @@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module): self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) 
self.h = nn.ModuleList( [ - GPTBigCodeBlock(i, config, cache_config, quant_config) + GPTBigCodeBlock(i, config, quant_config) for i in range(config.num_hidden_layers) ] ) @@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module): def __init__( self, config: GPTBigCodeConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config - self.lora_config = lora_config self.quant_config = quant_config - self.transformer = GPTBigCodeModel( - config, cache_config, quant_config, lora_config - ) + self.transformer = GPTBigCodeModel(config, quant_config) self.lm_head = self.transformer.wte self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index d5c303d13..956f73b14 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -24,7 +24,6 @@ from torch import nn from transformers import PretrainedConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm @@ -43,6 +42,8 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.loader import DefaultModelLoader +from sglang.srt.model_loader.weight_utils import default_weight_loader class Grok1MoE(nn.Module): @@ -285,7 +286,6 @@ class Grok1ForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index d217fd71f..0a737c138 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -21,7 +21,6 @@ from torch import nn from transformers import PretrainedConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class InternLM2MLP(nn.Module): @@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/internlm2_reward.py b/python/sglang/srt/models/internlm2_reward.py index 78831599d..d5fe9c059 100644 --- a/python/sglang/srt/models/internlm2_reward.py +++ b/python/sglang/srt/models/internlm2_reward.py @@ -29,7 +29,6 @@ class InternLM2ForRewardModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - 
cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 62ad0d2a0..61409a9ea 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -24,7 +24,6 @@ from torch import nn from transformers import LlamaConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -44,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers from sglang.utils import get_exception_traceback @@ -300,7 +300,6 @@ class LlamaForCausalLM(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index c22b68d11..038732476 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import LlamaConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel @@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py index da43d03fc..34b316dda 100644 --- a/python/sglang/srt/models/llama_embedding.py +++ b/python/sglang/srt/models/llama_embedding.py @@ -3,10 +3,10 @@ from typing import Iterable, Tuple import torch from torch import nn from transformers import LlamaConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.model_executor.model_runner import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaModel @@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module): self, config: LlamaConfig, quant_config=None, - cache_config=None, ) -> None: super().__init__() self.model = LlamaModel(config, quant_config=quant_config) diff --git a/python/sglang/srt/models/llama_reward.py b/python/sglang/srt/models/llama_reward.py index 5eb2daae6..dcde8b468 100644 --- a/python/sglang/srt/models/llama_reward.py +++ b/python/sglang/srt/models/llama_reward.py @@ -21,6 +21,7 @@ from transformers import LlamaConfig from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from 
sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel @@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config @@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: - super().__init__(config, quant_config, cache_config) + super().__init__(config, quant_config) self.weights = self.Weights(config.hidden_size, self.num_labels) @torch.no_grad() diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index eb1784145..4c62dbb25 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -29,7 +29,6 @@ from transformers import ( SiglipVisionModel, ) from transformers.models.llava.modeling_llava import LlavaMultiModalProjector -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.schedule_batch import ImageInputs @@ -39,6 +38,7 @@ from sglang.srt.mm_utils import ( unpad_image_shape, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM @@ -451,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM): self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() @@ -473,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM): self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() @@ -506,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM): self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index c06ef8769..7b5f236a5 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -20,11 +20,11 @@ import torch from torch import nn from transformers import CLIPVisionModel, LlavaConfig from transformers.models.llava.modeling_llava import LlavaMultiModalProjector -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM @@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module): self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 0d668fe5d..3482a8281 100644 --- 
a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -20,7 +20,6 @@ import torch from torch import nn from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class MiniCPMMLP(nn.Module): @@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module): self, config, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index e6bf118ed..b0c93274e 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import is_flashinfer_available if is_flashinfer_available(): @@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module): rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, layer_id=None, ) -> None: @@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module): rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, layer_id=None, ) -> None: @@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module): self, config: PretrainedConfig, layer_id: int, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -430,7 +427,6 @@ class MiniCPM3DecoderLayer(nn.Module): rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - cache_config=cache_config, quant_config=quant_config, layer_id=layer_id, ) @@ -449,7 +445,6 @@ class MiniCPM3DecoderLayer(nn.Module): rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - cache_config=cache_config, quant_config=quant_config, layer_id=layer_id, ) @@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -512,9 +506,7 @@ class MiniCPM3Model(nn.Module): ) self.layers = nn.ModuleList( [ - MiniCPM3DecoderLayer( - config, i, cache_config=cache_config, quant_config=quant_config - ) + MiniCPM3DecoderLayer(config, i, quant_config=quant_config) for i in range(config.num_hidden_layers) ] ) @@ -549,7 +541,6 @@ class 
MiniCPM3ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -557,9 +548,7 @@ class MiniCPM3ForCausalLM(nn.Module): self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config - self.model = MiniCPM3Model( - config, cache_config=cache_config, quant_config=quant_config - ) + self.model = MiniCPM3Model(config, quant_config=quant_config) # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) if not self.config.tie_word_embeddings: self.lm_head = ParallelLMHead( diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index b2e895f56..b222387a7 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -23,7 +23,6 @@ from torch import nn from transformers import MixtralConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm @@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class MixtralMoE(nn.Module): @@ -291,7 +291,6 @@ class MixtralForCausalLM(nn.Module): self, config: MixtralConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index 8dba2b722..e5f49f566 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -29,7 +29,6 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -45,6 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class MixtralMLP(nn.Module): @@ -324,7 +324,6 @@ class QuantMixtralForCausalLM(nn.Module): self, config: MixtralConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 2a0cf4ea3..019d21c20 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -15,7 +15,6 @@ from transformers.models.mllama.modeling_mllama import ( _prepare_aspect_ratio_attention_mask, ) from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.layernorm import RMSNorm @@ -34,6 +33,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import 
ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP @@ -654,7 +654,6 @@ class MllamaTextModel(nn.Module): self, config: config_mllama.MllamaTextConfig, quant_config: Optional[QuantizationConfig], - cache_config=None, ): super().__init__() self.padding_id = config.pad_token_id @@ -732,11 +731,10 @@ class MllamaForCausalLM(nn.Module): self, config: config_mllama.MllamaTextConfig, quant_config: Optional[QuantizationConfig], - cache_config=None, ): super().__init__() self.vocab_size = config.vocab_size - self.model = MllamaTextModel(config, cache_config, quant_config) + self.model = MllamaTextModel(config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, @@ -772,7 +770,6 @@ class MllamaForConditionalGeneration(nn.Module): self, config: config_mllama.MllamaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ): super().__init__() self.vocab_size = config.text_config.vocab_size @@ -787,7 +784,6 @@ class MllamaForConditionalGeneration(nn.Module): self.vision_model = MllamaVisionModel(config.vision_config) self.language_model = MllamaForCausalLM( config.text_config, - cache_config=cache_config, quant_config=quant_config, ) self.multi_modal_projector = nn.Linear( diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 2ef6532ce..1cfa27309 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -22,7 +22,6 @@ from torch import nn from transformers import OlmoConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( @@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers @@ -274,7 +274,6 @@ class OlmoForCausalLM(nn.Module): def __init__( self, config: OlmoConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py index d73a6d5a3..0944b5720 100755 --- a/python/sglang/srt/models/olmo2.py +++ b/python/sglang/srt/models/olmo2.py @@ -312,7 +312,6 @@ class Olmo2ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 549e2d032..859f4135c 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -34,8 +34,6 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.utils import print_warning_once from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.fused_moe_triton import FusedMoE @@ -48,7 +46,8 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.utils import 
make_layers +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import make_layers, print_warning_once class OlmoeMoE(nn.Module): @@ -300,7 +299,6 @@ class OlmoeForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index e310dfcea..634033077 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -7,8 +7,6 @@ from transformers import Phi3Config from transformers.configuration_utils import PretrainedConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.utils import make_layers from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -27,6 +25,8 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import make_layers @torch.jit.script @@ -235,7 +235,6 @@ class Phi3SmallDecoderLayer(nn.Module): self, config: PretrainedConfig, layer_id: int, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -286,7 +285,6 @@ class Phi3SmallModel(nn.Module): super().__init__() self.config = config - cache_config = None self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size ) @@ -294,7 +292,7 @@ class Phi3SmallModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Phi3SmallDecoderLayer( - config, int(prefix.split(".")[-1]), cache_config, quant_config + config, int(prefix.split(".")[-1]), quant_config ), prefix=f"{prefix}.layers", ) @@ -339,7 +337,6 @@ class Phi3SmallForCausalLM(nn.Module): self, config: Phi3Config, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ): super().__init__() diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index fb4b67ff5..5492a3e12 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -22,7 +22,6 @@ from torch import nn from transformers import PretrainedConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class QWenMLP(nn.Module): @@ -242,7 +242,6 @@ class QWenLMHeadModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ): super().__init__() self.config = config diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 4c8ddd4b9..9383fde4d 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ 
-22,7 +22,6 @@ import torch from torch import nn from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers Qwen2Config = None @@ -271,7 +271,6 @@ class Qwen2ForCausalLM(nn.Module): self, config: Qwen2Config, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 256993269..0094cb8c3 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -27,7 +27,6 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.fused_moe_triton import FusedMoE @@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class Qwen2MoeMLP(nn.Module): @@ -158,7 +158,6 @@ class Qwen2MoeAttention(nn.Module): rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -234,7 +233,6 @@ class Qwen2MoeDecoderLayer(nn.Module): self, config: PretrainedConfig, layer_id: int, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -250,7 +248,6 @@ class Qwen2MoeDecoderLayer(nn.Module): rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - cache_config=cache_config, quant_config=quant_config, ) @@ -304,7 +301,6 @@ class Qwen2MoeModel(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -317,9 +313,7 @@ class Qwen2MoeModel(nn.Module): ) self.layers = nn.ModuleList( [ - Qwen2MoeDecoderLayer( - config, layer_id, cache_config, quant_config=quant_config - ) + Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config) for layer_id in range(config.num_hidden_layers) ] ) @@ -353,14 +347,13 @@ class Qwen2MoeForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config self.torchao_config = global_server_args_dict["torchao_config"] - self.model = Qwen2MoeModel(config, cache_config, quant_config) + self.model = Qwen2MoeModel(config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config ) diff --git a/python/sglang/srt/models/qwen2_vl.py 
b/python/sglang/srt/models/qwen2_vl.py index 155bde015..2e9ec9d8f 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -30,12 +30,10 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig from sglang.srt.hf_transformers_utils import get_processor @@ -49,6 +47,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2Model logger = init_logger(__name__) @@ -536,7 +535,6 @@ class Qwen2VLForConditionalGeneration(nn.Module): def __init__( self, config: Qwen2VLConfig, - cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() diff --git a/python/sglang/srt/models/registry.py b/python/sglang/srt/models/registry.py new file mode 100644 index 000000000..fc63bf125 --- /dev/null +++ b/python/sglang/srt/models/registry.py @@ -0,0 +1,99 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/models/registry.py + +import importlib +import logging +import pkgutil +from dataclasses import dataclass, field +from functools import lru_cache +from typing import AbstractSet, Dict, List, Optional, Tuple, Type, Union + +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +@dataclass +class _ModelRegistry: + # Keyed by model_arch + models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict) + + def get_supported_archs(self) -> AbstractSet[str]: + return self.models.keys() + + def _raise_for_unsupported(self, architectures: List[str]): + all_supported_archs = self.get_supported_archs() + + if any(arch in all_supported_archs for arch in architectures): + raise ValueError( + f"Model architectures {architectures} failed " + "to be inspected. Please check the logs for more details." + ) + + raise ValueError( + f"Model architectures {architectures} are not supported for now. 
" + f"Supported architectures: {all_supported_archs}" + ) + + def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in self.models: + return None + + return self.models[model_arch] + + def _normalize_archs( + self, + architectures: Union[str, List[str]], + ) -> List[str]: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + return architectures + + def resolve_model_cls( + self, + architectures: Union[str, List[str]], + ) -> Tuple[Type[nn.Module], str]: + architectures = self._normalize_archs(architectures) + + for arch in architectures: + model_cls = self._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + + return self._raise_for_unsupported(architectures) + + +@lru_cache() +def import_model_classes(): + model_arch_name_to_cls = {} + package_name = "sglang.srt.models" + package = importlib.import_module(package_name) + for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): + if not ispkg: + try: + module = importlib.import_module(name) + except Exception as e: + logger.warning(f"Ignore import error when loading {name}. " f"{e}") + continue + if hasattr(module, "EntryClass"): + entry = module.EntryClass + if isinstance( + entry, list + ): # To support multiple model classes in one module + for tmp in entry: + assert ( + tmp.__name__ not in model_arch_name_to_cls + ), f"Duplicated model implementation for {tmp.__name__}" + model_arch_name_to_cls[tmp.__name__] = tmp + else: + assert ( + entry.__name__ not in model_arch_name_to_cls + ), f"Duplicated model implementation for {entry.__name__}" + model_arch_name_to_cls[entry.__name__] = entry + + return model_arch_name_to_cls + + +ModelRegistry = _ModelRegistry(import_model_classes()) diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 38f2be13a..079d54e3c 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -26,7 +26,6 @@ from torch import nn from transformers import PretrainedConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( @@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class StablelmMLP(nn.Module): @@ -242,7 +242,6 @@ class StableLmForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 68982eebf..25e555484 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -52,7 +52,6 @@ from vllm.distributed import ( get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -66,6 +65,7 @@ from 
sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() @@ -388,7 +388,6 @@ class TorchNativeLlamaForCausalLM(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index 42f51a7fa..e65514215 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.model_runner import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class XverseMLP(nn.Module): @@ -295,8 +295,6 @@ class XverseForCausalLM(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, - efficient_weight_load=False, ) -> None: super().__init__() self.config = config diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index 3a8b9a9e4..e1f328875 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.fused_moe_triton import fused_moe from sglang.srt.layers.logits_processor import LogitsProcessor @@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class XverseMLP(nn.Module): @@ -181,7 +181,6 @@ class XverseAttention(nn.Module): rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -258,7 +257,6 @@ class XverseDecoderLayer(nn.Module): self, config: PretrainedConfig, layer_id: int, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -277,7 +275,6 @@ class XverseDecoderLayer(nn.Module): rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - cache_config=cache_config, quant_config=quant_config, ) if config.num_experts is not None: @@ -326,7 +323,6 @@ class XverseModel(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -339,9 +335,7 @@ class XverseModel(nn.Module): ) self.layers = nn.ModuleList( [ - XverseDecoderLayer( - config, layer_id, cache_config, 
quant_config=quant_config - ) + XverseDecoderLayer(config, layer_id, quant_config=quant_config) for layer_id in range(config.num_hidden_layers) ] ) @@ -369,13 +363,12 @@ class XverseMoeForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config=None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = XverseModel(config, cache_config, quant_config) + self.model = XverseModel(config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config ) diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py index 6f1610e52..97ee5946c 100644 --- a/python/sglang/srt/models/yivl.py +++ b/python/sglang/srt/models/yivl.py @@ -18,9 +18,9 @@ from typing import Iterable, Optional, Tuple import torch import torch.nn as nn from transformers import CLIPVisionModel, LlavaConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llava import LlavaLlamaForCausalLM @@ -29,9 +29,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM): self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, - cache_config=None, ) -> None: - super().__init__(config, quant_config, cache_config) + super().__init__(config, quant_config) self.multi_modal_projector = YiVLMultiModalProjector(self.config) self.vision_tower_subfolder = self.config.mm_vision_tower.replace( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 283936f05..37ad6cfc5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -50,6 +50,7 @@ class ServerArgs: served_model_name: Optional[str] = None chat_template: Optional[str] = None is_embedding: bool = False + revision: Optional[str] = None # Port host: str = "127.0.0.1" @@ -341,6 +342,14 @@ class ServerArgs: action="store_true", help="Whether to use a CausalLM as an embedding model.", ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="The specific model version to use. It can be a branch " + "name, a tag name, or a commit id. 
If unspecified, will use " + "the default version.", + ) # Memory and scheduling parser.add_argument( diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d0bb767d7..c19d521a0 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -430,16 +430,12 @@ def suppress_other_loggers(): from vllm.logger import logger as vllm_default_logger vllm_default_logger.setLevel(logging.WARN) - logging.getLogger("vllm.config").setLevel(logging.ERROR) logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel( logging.WARN ) logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel( logging.WARN ) - logging.getLogger("vllm.selector").setLevel(logging.WARN) - logging.getLogger("vllm.utils").setLevel(logging.ERROR) - logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR) warnings.filterwarnings( "ignore", category=UserWarning, message="The given NumPy array is not writable" @@ -492,27 +488,6 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N pass -def monkey_patch_vllm_model_config(): - from vllm.config import ModelConfig - - if not hasattr(ModelConfig, "_resolve_task"): - return - - def _resolve_task( - self, - task_option, - hf_config, - ): - supported_tasks = { - "generate": True, - "embedding": False, - } - selected_task = "generate" - return supported_tasks, selected_task - - setattr(ModelConfig, "_resolve_task", _resolve_task) - - def monkey_patch_vllm_p2p_access_check(gpu_id: int): """ Monkey patch the slow p2p access check in vllm. @@ -1041,6 +1016,11 @@ def crash_on_warnings(): return get_bool_env_var("SGLANG_IS_IN_CI") +def print_warning_once(msg: str) -> None: + # Set the stacklevel to 2 to print the caller's line info + logger.warning(msg, stacklevel=2) + + def get_device_name(device_id: int = 0) -> str: if hasattr(torch, "cuda") and torch.cuda.is_available(): return torch.cuda.get_device_name(device_id) @@ -1055,6 +1035,33 @@ def get_device_name(device_id: int = 0) -> str: return torch.hpu.get_device_name(device_id) +def get_device_capability(device_id: int = 0) -> Tuple[int, int]: + major, minor = None, None + if hasattr(torch, "cuda") and torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability(device_id) + + if hasattr(torch, "hip") and torch.hip.is_available(): + major, minor = torch.cuda.get_device_capability(device_id) + + if hasattr(torch, "xpu") and torch.xpu.is_available(): + major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split( + "." + ) + major, minor = int(major), int(minor) + + # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now. + # Update this once the support is available. + if hasattr(torch, "hpu") and torch.hpu.is_available(): + try: + major, minor = torch.hpu.get_device_capability(device_id) + except Exception as e: + raise RuntimeError( + f"An error occurred while getting device capability of hpu: {e}." + ) from e + + return major, minor + + sglang_lib = Library("sglang", "FRAGMENT") # noqa
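
The new python/sglang/srt/models/registry.py above auto-discovers model implementations by scanning sglang.srt.models for an EntryClass attribute and maps architecture names to classes. A minimal usage sketch, not part of the patch, assuming "LlamaForCausalLM" is among the registered EntryClass names:

from sglang.srt.models.registry import ModelRegistry

# Architectures discovered by import_model_classes() at import time.
print(sorted(ModelRegistry.get_supported_archs()))

# Resolve the architecture string from a Hugging Face config to a model class;
# raises ValueError if none of the given architectures is registered.
model_cls, arch = ModelRegistry.resolve_model_cls(["LlamaForCausalLM"])
print(arch, model_cls)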
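The server_args.py hunk adds a revision field and the matching --revision flag. A hedged sketch of setting it programmatically, assuming ServerArgs keeps model_path as its required first field elsewhere in the file; the model path below is only a placeholder:

from sglang.srt.server_args import ServerArgs

# Equivalent to passing "--revision main" on the command line.
args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    revision="main",
)
print(args.revision)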
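The utils.py hunk adds print_warning_once and get_device_capability in place of the vllm helpers dropped earlier in the patch (e.g. in olmoe.py). A minimal sketch of using them together; gating on compute capability 8.0 is an illustrative assumption, not something the diff does:

from sglang.srt.utils import get_device_capability, print_warning_once

major, minor = get_device_capability(0)
if major is None:
    print_warning_once("No supported accelerator detected; skipping capability check.")
elif (major, minor) < (8, 0):
    print_warning_once("Compute capability below 8.0; some quantized kernels may be unavailable.")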