# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Generator

import gguf
import regex as re
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText

from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
    initialize_model,
    process_weights_after_loading,
)
from vllm.model_executor.model_loader.weight_utils import (
    download_gguf,
    get_gguf_extra_tensor_names,
    get_gguf_weight_type_map,
    gguf_quant_weights_iterator,
)
from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal
from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)


def _has_vision_config(hf_config) -> bool:
    """Return True when ``hf_config`` has a non-None ``vision_config``.

    Single source of truth for the "is this model multimodal?" decision so
    that name mapping, weight-type lookup and weight iteration always agree.
    """
    return getattr(hf_config, "vision_config", None) is not None


def _require_mmproj_file(model_name_or_path: str):
    """Return the mmproj file reported by ``detect_gguf_multimodal``.

    Raises:
        RuntimeError: if ``detect_gguf_multimodal`` returns None. An explicit
            raise is used instead of ``assert`` so the check is not stripped
            when Python runs with ``-O``.
    """
    mmproj_file = detect_gguf_multimodal(model_name_or_path)
    if mmproj_file is None:
        raise RuntimeError("Could not find mm_proj file for multimodal GGUF model")
    return mmproj_file


def _register_merged_expert_mappings(
    gguf_to_hf_name_map: dict[str, str],
    sideload_params: list[re.Pattern],
    num_layers: int,
    *,
    include_score_correction_bias: bool = False,
) -> None:
    """Add manual GGUF -> HF mappings for merged MoE expert weights.

    The GGUF layer map assumes merged expert weights, so each merged GGUF
    tensor is mapped manually onto the HF name of expert 0, and a regex that
    matches every per-expert projection weight of the layer is appended to
    ``sideload_params``. Both containers are mutated in place.

    Args:
        gguf_to_hf_name_map: mapping to extend (GGUF name -> HF name).
        sideload_params: list of compiled patterns to extend.
        num_layers: number of layers to register.
        include_score_correction_bias: also map ``exp_probs_b.bias`` onto
            ``mlp.gate.e_score_correction_bias`` for each layer.
    """
    for idx in range(num_layers):
        if include_score_correction_bias:
            gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
            )
        gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
            f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
        )
        gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
            f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
        )
        gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
            f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
        )
        sideload_params.append(
            re.compile(
                f"model\\.layers\\.{idx}"
                r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
            )
        )


class GGUFModelLoader(BaseModelLoader):
    """
    Model loader that can load GGUF files. This is useful for loading models
    that are quantized with GGUF and saved in the GGUF format. This loader
    supports loading both full models and sharded models.
    """

    def __init__(self, load_config: LoadConfig):
        """Create the loader.

        Raises:
            ValueError: if ``load_config.model_loader_extra_config`` is set;
                this loader does not accept any extra configuration.
        """
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(
                f"Model loader extra config is not supported for "
                f"load format {load_config.load_format}"
            )

    def _prepare_weights(self, model_config: ModelConfig):
        """Resolve ``model_config.model`` to a GGUF file, downloading if needed.

        Accepted reference forms, checked in this order:

        1. a path to an existing local file (returned unchanged),
        2. a raw ``http(s)://...*.gguf`` URL,
        3. ``<repo_id>/<filename>.gguf``,
        4. ``<repo_id>:<quant_type>``.

        Returns:
            The local path for form 1, otherwise whatever ``hf_hub_download``
            or ``download_gguf`` returns for the reference.

        Raises:
            ValueError: if the reference matches none of the forms above.
        """
        model_name_or_path = model_config.model
        if os.path.isfile(model_name_or_path):
            return model_name_or_path

        is_gguf_name = model_name_or_path.endswith(".gguf")

        # for raw HTTPS link
        # NOTE(review): the huggingface_hub documentation describes
        # `hf_hub_download` in terms of `repo_id`/`filename`; confirm that the
        # pinned version accepts the `url=` keyword used here.
        if is_gguf_name and model_name_or_path.startswith(("http://", "https://")):
            return hf_hub_download(url=model_name_or_path)

        if "/" in model_name_or_path:
            # repo id/filename.gguf
            if is_gguf_name:
                repo_id, filename = model_name_or_path.rsplit("/", 1)
                return hf_hub_download(repo_id=repo_id, filename=filename)
            # repo_id:quant_type
            if ":" in model_name_or_path:
                repo_id, quant_type = model_name_or_path.rsplit(":", 1)
                return download_gguf(
                    repo_id,
                    quant_type,
                    cache_dir=self.load_config.download_dir,
                    revision=model_config.revision,
                    ignore_patterns=self.load_config.ignore_patterns,
                )

        raise ValueError(
            f"Unrecognised GGUF reference: {model_name_or_path} "
            "(expected local file, raw URL, <repo_id>/<filename>.gguf, "
            "or <repo_id>:<quant_type>)"
        )

    def _get_gguf_weights_map(self, model_config: ModelConfig) -> dict[str, str]:
        """
        GGUF uses this naming convention for their tensors from HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

        Builds the mapping by instantiating the HF model on the ``meta``
        device, walking its state dict and translating every parameter name
        through the gguf-py tensor name maps.

        Returns:
            Mapping from GGUF tensor name to HF parameter name.

        Raises:
            RuntimeError: if the model type has no GGUF architecture, or if
                some HF parameter can neither be mapped nor matched by a
                sideload pattern.
        """
        config = model_config.hf_config
        # Get text config to handle both nested (multimodal) and flat
        # (text-only) config structures. For multimodal models like
        # Gemma3Config, this returns config.text_config. For text-only
        # models, this returns config itself.
        text_config = config.get_text_config()
        model_type = config.model_type
        is_multimodal = _has_vision_config(config)
        gguf_to_hf_name_map: dict[str, str] = {}
        sideload_params: list[re.Pattern] = []

        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"
        if model_type == "gemma3_text":
            # Gemma3 models use "gemma3_text" in HuggingFace but
            # "gemma3" in GGUF architecture naming
            model_type = "gemma3"
        if model_type in ("deepseek_v3", "deepseek_v2"):
            model_type = "deepseek2"
            _register_merged_expert_mappings(
                gguf_to_hf_name_map,
                sideload_params,
                config.num_hidden_layers,
                include_score_correction_bias=True,
            )
        if model_type in ("qwen2_moe", "qwen3_moe"):
            model_type = model_type.replace("_", "")
            _register_merged_expert_mappings(
                gguf_to_hf_name_map,
                sideload_params,
                config.num_hidden_layers,
            )

        # Reverse lookup: gguf-py maps arch enum -> name, we have the name.
        arch = next(
            (
                key
                for key, value in gguf.MODEL_ARCH_NAMES.items()
                if value == model_type
            ),
            None,
        )
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")

        text_num_layers = text_config.num_hidden_layers
        text_name_map = gguf.get_tensor_name_map(arch, text_num_layers)

        if is_multimodal:
            mm_proj_arch = gguf.MODEL_ARCH.MMPROJ
            vision_num_layers = config.vision_config.num_hidden_layers
            vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers)
        else:
            vision_name_map = None

        # Create dummy model to extract parameter names
        # For multimodal: use AutoModelForImageTextToText to get
        # language + vision + projector params
        # For text-only: use AutoModelForCausalLM to get language model params
        auto_cls = (
            AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
        )
        with torch.device("meta"):
            dummy_model = auto_cls.from_config(
                config, trust_remote_code=model_config.trust_remote_code
            )

        state_dict = dummy_model.state_dict()
        if hf_checkpoint_map := getattr(
            dummy_model, "_checkpoint_conversion_mapping", None
        ):

            def revert_hf_rename(name: str) -> str:
                """Undo the HF checkpoint-conversion renames found in ``name``."""
                for original_name, hf_name in hf_checkpoint_map.items():
                    if hf_name in name:
                        name = name.replace(hf_name, original_name).lstrip("^")
                return name

            state_dict = {
                revert_hf_rename(name): tensor for name, tensor in state_dict.items()
            }

        def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
            """
            Map HuggingFace parameter name to GGUF tensor name.

            Steps, in order:
            1. Strip a leading 'language_model.' prefix (the 'model.' prefix
               is kept).
            2. Split off a trailing '.weight' / '.bias' suffix.
            3. Treat a trailing '_weight' on the base name as a '.weight'
               suffix (Gemma3 naming).
            4. Look the base name up in vision_name_map first (multimodal
               only), then fall back to text_name_map.

            Args:
                hf_name: Full HuggingFace parameter name.

            Returns:
                GGUF tensor name joined to the suffix with '.', or None if
                neither name map knows the parameter.
            """
            # Strip 'language_model.' prefix for multimodal models - gguf-py
            # tensor mappings expect parameter names without this prefix.
            # Note: 'model.' prefix should be KEPT for text-only models as
            # gguf-py expects it.
            hf_name = hf_name.removeprefix("language_model.")

            # Parse parameter name and suffix
            if hf_name.endswith((".weight", ".bias")):
                base_name, suffix = hf_name.rsplit(".", 1)
            else:
                base_name, suffix = hf_name, ""
                # Handle '_weight' suffix (Gemma3 naming: parameter ends with
                # '_weight' instead of '.weight')
                if base_name.endswith("_weight"):
                    base_name = base_name.removesuffix("_weight")
                    suffix = "weight"

            gguf_name = None
            # Priority 1: Search vision/projector parameters for multimodal models
            if vision_name_map is not None:
                gguf_name = vision_name_map.get_name(base_name)

            # Priority 2: Search text backbone parameters
            if gguf_name is None:
                gguf_name = text_name_map.get_name(base_name)

            if gguf_name is None:
                return None

            # NOTE(review): when no suffix was recognised, `suffix` is empty
            # and the returned name ends with a trailing "." - confirm that
            # this is intended.
            return gguf_name + "." + suffix

        # Build mapping and track unmapped parameters
        unmapped_params = []
        for hf_name in state_dict:
            gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)

            # Track mapping success
            if gguf_name_with_suffix is not None:
                gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
                logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
            elif hf_name not in gguf_to_hf_name_map.values():
                # Parameter not in manual overrides either
                unmapped_params.append(hf_name)

        # All parameters (except those initialized by other means) must be mapped:
        # both vision/projector and backbone
        if unmapped_params:
            unmapped_params = [
                name
                for name in unmapped_params
                if not any(pattern.fullmatch(name) for pattern in sideload_params)
            ]
        if unmapped_params:
            raise RuntimeError(
                f"Failed to map GGUF parameters "
                f"({len(unmapped_params)}): "
                f"{unmapped_params}"
            )
        return gguf_to_hf_name_map

    def _get_gguf_weight_type(
        self,
        model_config: ModelConfig,
        model_name_or_path: str,
        gguf_to_hf_name_map: dict[str, str],
    ) -> dict[str, str]:
        """Collect the weight-type map for the main GGUF file.

        For multimodal configs (non-None ``vision_config``) the entries read
        from the companion mmproj file are merged in on top.

        Raises:
            RuntimeError: if the config is multimodal but no mmproj file is
                found for ``model_name_or_path``.
        """
        weight_type_map = get_gguf_weight_type_map(
            model_name_or_path, gguf_to_hf_name_map
        )
        if _has_vision_config(model_config.hf_config):
            mmproj_file = _require_mmproj_file(model_name_or_path)
            logger.info("Loading extra mm_proj weights from %s...", mmproj_file)
            mm_proj_weight_type_map = get_gguf_weight_type_map(
                mmproj_file, gguf_to_hf_name_map
            )
            weight_type_map.update(mm_proj_weight_type_map)
        return weight_type_map

    def _get_weights_iterator(
        self,
        model_config: ModelConfig,
        model_name_or_path: str,
        gguf_to_hf_name_map: dict[str, str],
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """
        Iterate over GGUF model weights, loading from both main model file and
        mmproj.gguf for multimodal Gemma3 models.

        For Gemma3 multimodal GGUF models:
        - Main file (gemma-3-*.gguf): Language model weights (model.*)
        - mmproj file (mmproj*.gguf): Vision tower + projector weights (v.*, mm.*)

        The mmproj weights, when present, are yielded before the main file's
        weights.

        Yields:
            Tuples of (parameter_name, tensor) for all model weights

        Raises:
            RuntimeError: if the config is multimodal but no mmproj file is
                found for ``model_name_or_path``.
        """
        if _has_vision_config(model_config.hf_config):
            # Load mm_proj (mm_encoder + projector) for multimodal weights
            mmproj_file = _require_mmproj_file(model_name_or_path)
            yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)

        yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)

    def download_model(self, model_config: ModelConfig) -> None:
        """Resolve (and download if necessary) the GGUF file for the model."""
        self._prepare_weights(model_config)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Resolve the GGUF file, build the name map and feed the weights
        iterator to ``model.load_weights``."""
        local_model_path = self._prepare_weights(model_config)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        model.load_weights(
            self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
        )

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig
    ) -> nn.Module:
        """Initialize the model and load its GGUF weights.

        Before the model is built, the HF config is flagged with
        ``tie_word_embeddings`` when ``lm_head.weight`` is reported by
        ``get_gguf_extra_tensor_names``, and every module whose ``.weight``
        is stored as F32/F16/BF16 is appended to
        ``vllm_config.quant_config.unquantized_modules``.

        Returns:
            The initialized model after weight loading and post-processing.
        """
        device_config = vllm_config.device_config
        local_model_path = self._prepare_weights(model_config)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        # we can only know if tie word embeddings after mapping weights
        if "lm_head.weight" in get_gguf_extra_tensor_names(
            local_model_path, gguf_weights_map
        ):
            model_config.hf_config.update({"tie_word_embeddings": True})

        weight_type_map = self._get_gguf_weight_type(
            model_config, local_model_path, gguf_weights_map
        )
        # filter out unquantized modules to skip
        unquant_names = [
            name.removesuffix(".weight")
            for name, weight_type in weight_type_map.items()
            if weight_type in ("F32", "F16", "BF16") and name.endswith(".weight")
        ]
        logger.debug(
            "GGUF unquantized modules: %s",
            unquant_names,
        )
        vllm_config.quant_config.unquantized_modules.extend(unquant_names)

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config)
            self.load_weights(model, model_config)

            process_weights_after_loading(model, model_config, target_device)
        return model