diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 748069fc2..d3ed96fe0 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -38,6 +38,7 @@ from vllm.distributed import ( init_distributed_environment, initialize_model_parallel, ) +from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config @@ -168,15 +169,6 @@ class ModelRunner: if self.model_config.model_overide_args is not None: vllm_model_config.hf_config.update(self.model_config.model_overide_args) - if ( - self.server_args.efficient_weight_load - and "llama" in self.server_args.model_path.lower() - and self.server_args.quantization == "fp8" - ): - from sglang.srt.model_loader.model_loader import get_model - else: - from vllm.model_executor.model_loader import get_model - self.model = get_model( model_config=vllm_model_config, device_config=device_config, diff --git a/python/sglang/srt/model_loader/model_loader.py b/python/sglang/srt/model_loader/model_loader.py deleted file mode 100644 index 4b7e32b6e..000000000 --- a/python/sglang/srt/model_loader/model_loader.py +++ /dev/null @@ -1,292 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -# temporarily adapted from https://github.com/vllm-project/vllm/blob/10383887e03412196a2689b9398290719c4797bf/vllm/model_executor/model_loader/loader.py -# FIXME: in progress of refactoring the model loader - -import glob -import os -import re -from typing import Any, Dict, Generator, List, Optional, Tuple, Type - -import torch -from torch import nn -from tqdm import tqdm -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoadConfig, - LoadFormat, - LoRAConfig, - ModelConfig, - MultiModalConfig, - ParallelConfig, - SchedulerConfig, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.model_loader.utils import ( - get_model_architecture, - set_default_torch_dtype, -) -from vllm.platforms import current_platform - -from sglang.srt.model_loader.utils import ( - download_safetensors_index_file_from_hf, - download_weights_from_hf, - filter_duplicate_safetensors_files, - get_quant_config, - safetensors_weights_iterator, -) - - -def _get_quantization_config( - model_config: ModelConfig, load_config: LoadConfig -) -> Optional[QuantizationConfig]: - """Get the quantization config.""" - if model_config.quantization is not None: - quant_config = get_quant_config(model_config, load_config) - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}." - ) - supported_dtypes = quant_config.get_supported_act_dtypes() - if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}" - ) - return quant_config - return None - - -def _get_model_initialization_kwargs( - model_class: Type[nn.Module], - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], -) -> Dict[str, Any]: - """Get extra kwargs for model initialization.""" - extra_kwargs: Dict[str, Any] = {} - - assert lora_config is None - assert multimodal_config is None - - return extra_kwargs - - -def _initialize_model( - model_config: ModelConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - cache_config: CacheConfig, -) -> nn.Module: - """Initialize a model with the given configurations.""" - model_class = get_model_architecture(model_config)[0] - quant_config = _get_quantization_config(model_config, load_config) - - return model_class( - config=model_config.hf_config, - cache_config=cache_config, - quant_config=quant_config, - efficient_weight_load=True, - **_get_model_initialization_kwargs(model_class, lora_config, multimodal_config), - ) - - -class ModelLoader: - """Model loader that can load different file types from disk.""" - - def __init__(self, load_config: LoadConfig): - self.load_config = load_config - - def _prepare_weights( - self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool - ) -> Tuple[str, List[str], bool]: - """Prepare weights for the model. - - If the model is not local, it will be downloaded.""" - - is_local = os.path.isdir(model_name_or_path) - load_format = self.load_config.load_format - use_safetensors = False - # Some quantized models use .pt files for storing the weights. - if load_format == LoadFormat.AUTO: - allow_patterns = ["*.safetensors", "*.bin"] - elif load_format == LoadFormat.SAFETENSORS: - use_safetensors = True - allow_patterns = ["*.safetensors"] - elif load_format == LoadFormat.PT: - allow_patterns = ["*.pt"] - elif load_format == LoadFormat.NPCACHE: - allow_patterns = ["*.bin"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - if not is_local: - hf_folder = download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - allow_patterns, - revision, - ) - else: - hf_folder = model_name_or_path - - hf_weights_files: List[str] = [] - for pattern in allow_patterns: - hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True - break - - if use_safetensors: - # For models like Mistral-7B-Instruct-v0.3 - # there are both sharded safetensors files and a consolidated - # safetensors file. Using both breaks. - # Here, we download the `model.safetensors.index.json` and filter - # any files not found in the index. - if not is_local: - download_safetensors_index_file_from_hf( - model_name_or_path, self.load_config.download_dir, revision - ) - hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder - ) - else: - hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`" - ) - - return hf_folder, hf_weights_files, use_safetensors - - def _get_weights_iterator( - self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool - ) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Get an iterator for the model weights based on the load format.""" - hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( - model_name_or_path, revision, fall_back_to_pt - ) - if self.load_config.load_format == LoadFormat.NPCACHE: - # Currently np_cache only support *.bin checkpoints - assert use_safetensors is False - weights_iterator = np_cache_weights_iterator( - model_name_or_path, - self.load_config.download_dir, - hf_folder, - hf_weights_files, - ) - elif use_safetensors: - weights_iterator = safetensors_weights_iterator(hf_weights_files) - else: - weights_iterator = pt_weights_iterator(hf_weights_files) - - return weights_iterator - - def load_model( - self, - *, - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model( - model_config, - self.load_config, - lora_config, - multimodal_config, - cache_config, - ) - weights = self._get_weights_iterator( - model_config.model, - model_config.revision, - fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), - ) - - modules = {} - for name, module in model.named_modules(): - modules[name] = module - - def apply_quant_method(module): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - # print("before apply quant", module.weight, module.weight.dtype) - quant_method.process_weights_after_loading(module) - # print("after apply quant", module.weight, module.weight.dtype) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() - - if torch.cuda.current_device() == 0: - weights = tqdm( - weights, total=model.get_num_params() * 1.5, desc="load model" - ) - - num_shard = {} - num_loaded = {} - for name, loaded_weight in weights: - model.load_weights(None, name, loaded_weight) - module_name, shard_num = model.get_module_name(name) - num_shard[module_name] = shard_num - if module_name not in num_loaded: - num_loaded[module_name] = 1 - else: - num_loaded[module_name] += 1 - if num_loaded[module_name] == num_shard[module_name]: - apply_quant_method(modules[module_name]) - - return model.eval() - - -def get_model( - *, - model_config: ModelConfig, - load_config: LoadConfig, - device_config: DeviceConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - cache_config: CacheConfig, -) -> nn.Module: - loader = ModelLoader(load_config) - return loader.load_model( - model_config=model_config, - device_config=device_config, - lora_config=lora_config, - multimodal_config=multimodal_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - cache_config=cache_config, - ) diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py deleted file mode 100644 index 9d6520e2a..000000000 --- a/python/sglang/srt/model_loader/utils.py +++ /dev/null @@ -1,275 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -# temporarily adapted from vLLM -# FIXME: in progress of refactoring the model loader -"""Utilities for selecting and loading models.""" -import contextlib -import fnmatch -import hashlib -import json -import logging -import os -import tempfile -from typing import Any, Generator, Iterable, List, Optional, Tuple, Type - -import filelock -import huggingface_hub.constants -import torch -from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download -from safetensors.torch import load_file, safe_open, save_file -from torch import nn -from tqdm.auto import tqdm -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig - -from sglang.srt.layers.quantization import get_quantization_config - -logger = logging.getLogger(__name__) -temp_dir = tempfile.gettempdir() - - -@contextlib.contextmanager -def set_default_torch_dtype(dtype: torch.dtype): - """Sets the default torch dtype to the given dtype.""" - old_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - yield - torch.set_default_dtype(old_dtype) - - -def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: - architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. - # FIXME(woosuk): This is a temporary hack. - if ( - model_config.quantization is not None - and model_config.quantization != "fp8" - and "MixtralForCausalLM" in architectures - ): - architectures = ["QuantMixtralForCausalLM"] - - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}" - ) - - -class DisabledTqdm(tqdm): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, disable=True) - - -def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): - lock_dir = cache_dir or temp_dir - os.makedirs(os.path.dirname(lock_dir), exist_ok=True) - model_name = model_name_or_path.replace("/", "-") - hash_name = hashlib.sha256(model_name.encode()).hexdigest() - # add hash to avoid conflict with old users' lock files - lock_file_name = hash_name + model_name + ".lock" - # mode 0o666 is required for the filelock to be shared across users - lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) - return lock - - -def download_weights_from_hf( - model_name_or_path: str, - cache_dir: Optional[str], - allow_patterns: List[str], - revision: Optional[str] = None, -) -> str: - """Download model weights from Hugging Face Hub. - - Args: - model_name_or_path (str): The model name or path. - cache_dir (Optional[str]): The cache directory to store the model - weights. If None, will use HF defaults. - allow_patterns (List[str]): The allowed patterns for the - weight files. Files matched by any of the patterns will be - downloaded. - revision (Optional[str]): The revision of the model. - - Returns: - str: The path to the downloaded model weights. - """ - if not huggingface_hub.constants.HF_HUB_OFFLINE: - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] - break - - logger.info("Using model weights format %s", allow_patterns) - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model_name_or_path, cache_dir): - hf_folder = snapshot_download( - model_name_or_path, - allow_patterns=allow_patterns, - cache_dir=cache_dir, - tqdm_class=DisabledTqdm, - revision=revision, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ) - return hf_folder - - -def download_safetensors_index_file_from_hf( - model_name_or_path: str, - cache_dir: Optional[str], - revision: Optional[str] = None, -) -> None: - """Download hf safetensors index file from Hugging Face Hub. - - Args: - model_name_or_path (str): The model name or path. - cache_dir (Optional[str]): The cache directory to store the model - weights. If None, will use HF defaults. - revision (Optional[str]): The revision of the model. - """ - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model_name_or_path, cache_dir): - try: - # Download the safetensors index file. - hf_hub_download( - repo_id=model_name_or_path, - filename=SAFE_WEIGHTS_INDEX_NAME, - cache_dir=cache_dir, - revision=revision, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ) - # If file not found on remote or locally, we should not fail since - # only some models will have SAFE_WEIGHTS_INDEX_NAME. - except huggingface_hub.utils.EntryNotFoundError: - logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME) - except huggingface_hub.utils.LocalEntryNotFoundError: - logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME) - - -# For models like Mistral-7B-v0.3, there are both sharded -# safetensors files and a consolidated safetensors file. -# Passing both of these to the weight loader functionality breaks. -# So, we use the SAFE_WEIGHTS_INDEX_NAME to -# look up which safetensors files should be used. -def filter_duplicate_safetensors_files( - hf_weights_files: List[str], hf_folder: str -) -> List[str]: - # model.safetensors.index.json is a mapping from keys in the - # torch state_dict to safetensors file holding that weight. - index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME) - if not os.path.isfile(index_file_name): - return hf_weights_files - - # Iterate through the weight_map (weight_name: safetensors files) - # to identify weights that we should use. - with open(index_file_name) as index_file: - weight_map = json.load(index_file)["weight_map"] - weight_files_in_index = set() - for weight_name in weight_map: - weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name])) - # Filter out any fields that are not found in the index file. - hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index] - return hf_weights_files - - -def safetensors_weights_iterator( - hf_weights_files: List[str], -) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Iterate over the weights in the model safetensor files.""" - for st_file in hf_weights_files: - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param - - -def get_quant_config( - model_config: ModelConfig, load_config: LoadConfig -) -> QuantizationConfig: - quant_cls = get_quantization_config(model_config.quantization) - # Read the quantization config from the HF model config, if available. - hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) - if hf_quant_config is None: - # compressed-tensors uses a compressions_config - hf_quant_config = getattr(model_config.hf_config, "compression_config", None) - if hf_quant_config is not None: - return quant_cls.from_config(hf_quant_config) - # In case of bitsandbytes/QLoRA, get quant config from the adapter model. - if model_config.quantization == "bitsandbytes": - if ( - not load_config.model_loader_extra_config - or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config - ): - return quant_cls.from_config({"adapter_name_or_path": ""}) - model_name_or_path = load_config.model_loader_extra_config[ - "qlora_adapter_name_or_path" - ] - - else: - model_name_or_path = model_config.model - is_local = os.path.isdir(model_name_or_path) - if not is_local: - # Download the config files. - with get_lock(model_name_or_path, load_config.download_dir): - hf_folder = snapshot_download( - model_name_or_path, - revision=model_config.revision, - allow_patterns="*.json", - cache_dir=load_config.download_dir, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - tqdm_class=DisabledTqdm, - ) - else: - hf_folder = model_name_or_path - - possible_config_filenames = quant_cls.get_config_filenames() - - # If the quantization config is not found, use the default config. - if not possible_config_filenames: - return quant_cls() - - config_files = glob.glob(os.path.join(hf_folder, "*.json")) - - quant_config_files = [ - f for f in config_files if any(f.endswith(x) for x in possible_config_filenames) - ] - if len(quant_config_files) == 0: - raise ValueError(f"Cannot find the config file for {model_config.quantization}") - if len(quant_config_files) > 1: - raise ValueError( - f"Found multiple config files for {model_config.quantization}: " - f"{quant_config_files}" - ) - - quant_config_file = quant_config_files[0] - with open(quant_config_file, "r") as f: - config = json.load(f) - - if model_config.quantization == "bitsandbytes": - config["adapter_name_or_path"] = model_name_or_path - - return quant_cls.from_config(config)