misc: rm unused model_loader (#1110)
This commit is contained in:
@@ -38,6 +38,7 @@ from vllm.distributed import (
|
|||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
@@ -168,15 +169,6 @@ class ModelRunner:
|
|||||||
if self.model_config.model_overide_args is not None:
|
if self.model_config.model_overide_args is not None:
|
||||||
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
|
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
|
||||||
|
|
||||||
if (
|
|
||||||
self.server_args.efficient_weight_load
|
|
||||||
and "llama" in self.server_args.model_path.lower()
|
|
||||||
and self.server_args.quantization == "fp8"
|
|
||||||
):
|
|
||||||
from sglang.srt.model_loader.model_loader import get_model
|
|
||||||
else:
|
|
||||||
from vllm.model_executor.model_loader import get_model
|
|
||||||
|
|
||||||
self.model = get_model(
|
self.model = get_model(
|
||||||
model_config=vllm_model_config,
|
model_config=vllm_model_config,
|
||||||
device_config=device_config,
|
device_config=device_config,
|
||||||
|
|||||||
@@ -1,292 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# temporarily adapted from https://github.com/vllm-project/vllm/blob/10383887e03412196a2689b9398290719c4797bf/vllm/model_executor/model_loader/loader.py
|
|
||||||
# FIXME: in progress of refactoring the model loader
|
|
||||||
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from tqdm import tqdm
|
|
||||||
from vllm.config import (
|
|
||||||
CacheConfig,
|
|
||||||
DeviceConfig,
|
|
||||||
LoadConfig,
|
|
||||||
LoadFormat,
|
|
||||||
LoRAConfig,
|
|
||||||
ModelConfig,
|
|
||||||
MultiModalConfig,
|
|
||||||
ParallelConfig,
|
|
||||||
SchedulerConfig,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
|
||||||
from vllm.model_executor.model_loader.utils import (
|
|
||||||
get_model_architecture,
|
|
||||||
set_default_torch_dtype,
|
|
||||||
)
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
from sglang.srt.model_loader.utils import (
|
|
||||||
download_safetensors_index_file_from_hf,
|
|
||||||
download_weights_from_hf,
|
|
||||||
filter_duplicate_safetensors_files,
|
|
||||||
get_quant_config,
|
|
||||||
safetensors_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_quantization_config(
    model_config: ModelConfig, load_config: LoadConfig
) -> Optional[QuantizationConfig]:
    """Return the validated quantization config, or None when not quantized.

    Raises:
        ValueError: if the current device capability is below the minimum the
            quantization method reports, or if ``model_config.dtype`` is not
            among the activation dtypes the method reports as supported.
    """
    # Nothing to resolve for non-quantized models.
    if model_config.quantization is None:
        return None

    quant_config = get_quant_config(model_config, load_config)

    # Fold the (major, minor) pair into a single number, e.g. (8, 9) -> 89.
    device_capability = current_platform.get_device_capability()
    device_capability = device_capability[0] * 10 + device_capability[1]
    if device_capability < quant_config.get_min_capability():
        raise ValueError(
            f"The quantization method {model_config.quantization} is not "
            "supported for the current GPU. "
            f"Minimum capability: {quant_config.get_min_capability()}. "
            f"Current capability: {device_capability}."
        )

    act_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in act_dtypes:
        raise ValueError(
            f"{model_config.dtype} is not supported for quantization "
            f"method {model_config.quantization}. Supported dtypes: "
            f"{act_dtypes}"
        )

    return quant_config
|
|
||||||
|
|
||||||
|
|
||||||
def _get_model_initialization_kwargs(
    model_class: Type[nn.Module],
    lora_config: Optional[LoRAConfig],
    multimodal_config: Optional[MultiModalConfig],
) -> Dict[str, Any]:
    """Build the extra keyword arguments passed to the model constructor.

    Both ``lora_config`` and ``multimodal_config`` are asserted to be None,
    and the returned mapping is always empty. ``model_class`` is accepted
    but not inspected.
    """
    kwargs: Dict[str, Any] = {}

    # This loader path only accepts calls without LoRA / multimodal configs.
    assert lora_config is None
    assert multimodal_config is None

    return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
def _initialize_model(
    model_config: ModelConfig,
    load_config: LoadConfig,
    lora_config: Optional[LoRAConfig],
    multimodal_config: Optional[MultiModalConfig],
    cache_config: CacheConfig,
) -> nn.Module:
    """Instantiate the model class selected for ``model_config``.

    The class is resolved from the model config, the quantization config is
    resolved and validated, and the class is then constructed with
    ``efficient_weight_load=True`` plus any extra initialization kwargs.
    """
    architecture_cls = get_model_architecture(model_config)[0]
    quantization = _get_quantization_config(model_config, load_config)

    # Fixed constructor arguments; the extra kwargs are merged in at call time
    # so a duplicated key still raises TypeError.
    base_kwargs = dict(
        config=model_config.hf_config,
        cache_config=cache_config,
        quant_config=quantization,
        efficient_weight_load=True,
    )
    extra_kwargs = _get_model_initialization_kwargs(
        architecture_cls, lora_config, multimodal_config
    )
    return architecture_cls(**base_kwargs, **extra_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelLoader:
    """Model loader that can load different file types from disk."""

    def __init__(self, load_config: LoadConfig):
        """Store the load configuration used by every later step."""
        self.load_config = load_config

    def _prepare_weights(
        self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
    ) -> Tuple[str, List[str], bool]:
        """Locate (and download if necessary) the weight files of a model.

        Args:
            model_name_or_path: Local directory or Hugging Face repo id.
            revision: Revision passed through to the download helpers.
            fall_back_to_pt: When true, ``*.pt`` is appended to the accepted
                file patterns.

        Returns:
            ``(hf_folder, hf_weights_files, use_safetensors)``.

        Raises:
            ValueError: if the configured load format is not handled here.
            RuntimeError: if no weight file is found.
        """
        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False

        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
            )
        else:
            hf_folder = model_name_or_path

        # Patterns are tried in priority order; the first one that matches
        # any file decides the format.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if hf_weights_files:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3 there are both sharded
            # safetensors files and a consolidated safetensors file. Using
            # both breaks. Here, we download `model.safetensors.index.json`
            # and filter out any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, self.load_config.download_dir, revision
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder
            )
        else:
            # This helper is not among the module's top-level imports, so it
            # is imported here to avoid a NameError on this branch.
            # NOTE(review): assumed to live in vLLM's weight_utils for the
            # pinned vLLM version -- confirm.
            from vllm.model_executor.model_loader.weight_utils import (
                filter_files_not_needed_for_inference,
            )

            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision, fall_back_to_pt
        )

        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            # Not among the module's top-level imports; imported locally.
            # NOTE(review): assumed vLLM weight_utils location -- confirm.
            from vllm.model_executor.model_loader.weight_utils import (
                np_cache_weights_iterator,
            )

            return np_cache_weights_iterator(
                model_name_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
            )

        if use_safetensors:
            return safetensors_weights_iterator(hf_weights_files)

        # Not among the module's top-level imports; imported locally.
        # NOTE(review): assumed vLLM weight_utils location -- confirm.
        from vllm.model_executor.model_loader.weight_utils import pt_weights_iterator

        return pt_weights_iterator(hf_weights_files)

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Build the model, stream its weights in, and return it in eval mode.

        Weights are loaded one tensor at a time through
        ``model.load_weights(None, name, tensor)``. After each tensor,
        ``model.get_module_name(name)`` reports the owning module and its
        shard count; once that many tensors have been seen for a module, its
        post-loading weight processing is run.

        ``parallel_config`` and ``scheduler_config`` are accepted but not
        used by this method.
        """
        # NOTE(review): the nesting of the two context managers below was
        # reconstructed from a listing without indentation -- confirm.
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(
                    model_config,
                    self.load_config,
                    lora_config,
                    multimodal_config,
                    cache_config,
                )
            weights = self._get_weights_iterator(
                model_config.model,
                model_config.revision,
                fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
            )

            # Name -> module lookup used when a module becomes fully loaded.
            modules = dict(model.named_modules())

            def apply_quant_method(module):
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)
                # FIXME: Remove this after Mixtral is updated
                # to use quant_method.
                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()

            # Only the process whose current CUDA device is 0 shows progress.
            if torch.cuda.current_device() == 0:
                weights = tqdm(
                    weights, total=model.get_num_params() * 1.5, desc="load model"
                )

            num_shard = {}
            num_loaded = {}
            for name, loaded_weight in weights:
                model.load_weights(None, name, loaded_weight)
                module_name, shard_num = model.get_module_name(name)
                num_shard[module_name] = shard_num
                num_loaded[module_name] = num_loaded.get(module_name, 0) + 1
                if num_loaded[module_name] == num_shard[module_name]:
                    apply_quant_method(modules[module_name])

        return model.eval()
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(
    *,
    model_config: ModelConfig,
    load_config: LoadConfig,
    device_config: DeviceConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    lora_config: Optional[LoRAConfig],
    multimodal_config: Optional[MultiModalConfig],
    cache_config: CacheConfig,
) -> nn.Module:
    """Create a ``ModelLoader`` for ``load_config`` and load the model with it.

    All remaining keyword arguments are forwarded unchanged to
    ``ModelLoader.load_model``.
    """
    return ModelLoader(load_config).load_model(
        model_config=model_config,
        device_config=device_config,
        lora_config=lora_config,
        multimodal_config=multimodal_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        cache_config=cache_config,
    )
|
|
||||||
@@ -1,275 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# temporarily adapted from vLLM
|
|
||||||
# FIXME: in progress of refactoring the model loader
|
|
||||||
"""Utilities for selecting and loading models."""
|
|
||||||
import contextlib
|
|
||||||
import fnmatch
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from typing import Any, Generator, Iterable, List, Optional, Tuple, Type
|
|
||||||
|
|
||||||
import filelock
|
|
||||||
import huggingface_hub.constants
|
|
||||||
import torch
|
|
||||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
|
||||||
from safetensors.torch import load_file, safe_open, save_file
|
|
||||||
from torch import nn
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
|
||||||
from vllm.config import LoadConfig, ModelConfig
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
|
||||||
|
|
||||||
from sglang.srt.layers.quantization import get_quantization_config
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
temp_dir = tempfile.gettempdir()
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    The previous default is restored when the ``with`` body exits, including
    when it exits by raising an exception.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Without the finally, an exception in the body would leave the
        # process-wide default dtype changed.
        torch.set_default_dtype(old_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class for the architectures named in the HF config.

    Returns:
        ``(model_cls, arch)`` for the first architecture the registry can
        load.

    Raises:
        ValueError: if none of the listed architectures can be loaded (this
            includes a config with no architectures at all).
    """
    # ModelRegistry is not among this module's top-level imports; importing
    # it here avoids a NameError when the function is called.
    from vllm.model_executor.models import ModelRegistry

    # `architectures` may be missing or explicitly None in the config.
    architectures = getattr(model_config.hf_config, "architectures", None) or []

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    if (
        model_config.quantization is not None
        and model_config.quantization != "fp8"
        and "MixtralForCausalLM" in architectures
    ):
        architectures = ["QuantMixtralForCausalLM"]

    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return (model_cls, arch)

    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}"
    )
|
|
||||||
|
|
||||||
|
|
||||||
class DisabledTqdm(tqdm):
    """A tqdm progress bar that is always constructed with ``disable=True``."""

    def __init__(self, *args, **kwargs):
        # Override through the dict instead of passing `disable=True` next to
        # **kwargs: a caller that already supplies `disable` would otherwise
        # trigger "got multiple values for keyword argument 'disable'".
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Return a file lock dedicated to ``model_name_or_path``.

    The lock file is placed in ``cache_dir`` when given, otherwise in the
    module-level temporary directory. The directory is created if missing.
    """
    lock_dir = cache_dir or temp_dir
    # Create the lock directory itself. Creating only its parent (dirname)
    # leaves a fresh cache_dir missing and fails outright with an empty path
    # when cache_dir is a single relative component.
    os.makedirs(lock_dir, exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + ".lock"
    # mode 0o666 is required for the filelock to be shared across users
    return filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
|
|
||||||
|
|
||||||
|
|
||||||
def download_weights_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
) -> str:
    """Download model weights from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        allow_patterns (List[str]): The allowed patterns for the
            weight files. When online, only the first pattern that matches
            a file in the repo is used for the download.
        revision (Optional[str]): The revision of the model.

    Returns:
        str: The path to the downloaded model weights.
    """
    if not huggingface_hub.constants.HF_HUB_OFFLINE:
        # Before downloading, look at what is available in the repo.
        repo_files = HfFileSystem().ls(
            model_name_or_path, detail=False, revision=revision
        )

        # Narrow the download to the first pattern that has any match.
        first_hit = next(
            (p for p in allow_patterns if fnmatch.filter(repo_files, p)), None
        )
        if first_hit is not None:
            allow_patterns = [first_hit]

    logger.info("Using model weights format %s", allow_patterns)

    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        return snapshot_download(
            model_name_or_path,
            allow_patterns=allow_patterns,
            cache_dir=cache_dir,
            tqdm_class=DisabledTqdm,
            revision=revision,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        )
|
|
||||||
|
|
||||||
|
|
||||||
def download_safetensors_index_file_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    revision: Optional[str] = None,
) -> None:
    """Download hf safetensors index file from Hugging Face Hub.

    A missing index file is not an error: it is logged and ignored, since
    only some models ship one.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        revision (Optional[str]): The revision of the model.
    """
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        try:
            # Download the safetensors index file.
            hf_hub_download(
                repo_id=model_name_or_path,
                filename=SAFE_WEIGHTS_INDEX_NAME,
                cache_dir=cache_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            )
        # If file not found on remote or locally, we should not fail since
        # only some models will have SAFE_WEIGHTS_INDEX_NAME.
        # LocalEntryNotFoundError derives from EntryNotFoundError, so it must
        # be handled first; otherwise this branch can never run and local
        # cache misses are reported as remote ones.
        except huggingface_hub.utils.LocalEntryNotFoundError:
            logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
        except huggingface_hub.utils.EntryNotFoundError:
            logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
# For models like Mistral-7B-v0.3, there are both sharded
|
|
||||||
# safetensors files and a consolidated safetensors file.
|
|
||||||
# Passing both of these to the weight loader functionality breaks.
|
|
||||||
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
|
|
||||||
# look up which safetensors files should be used.
|
|
||||||
def filter_duplicate_safetensors_files(
    hf_weights_files: List[str], hf_folder: str
) -> List[str]:
    """Keep only the weight files referenced by the safetensors index.

    If ``hf_folder`` contains no index file, the input list is returned
    unchanged.
    """
    # model.safetensors.index.json is a mapping from keys in the
    # torch state_dict to the safetensors file holding that weight.
    index_path = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
    if not os.path.isfile(index_path):
        return hf_weights_files

    with open(index_path) as index_file:
        weight_map = json.load(index_file)["weight_map"]

    # Full paths of every file the index points at.
    indexed_paths = {
        os.path.join(hf_folder, shard_name) for shard_name in weight_map.values()
    }

    # Drop any candidate file that the index does not mention.
    return [path for path in hf_weights_files if path in indexed_paths]
|
|
||||||
|
|
||||||
|
|
||||||
def safetensors_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs from each safetensors file in turn."""
    for path in hf_weights_files:
        # The file stays open while its tensors are being yielded.
        with safe_open(path, framework="pt") as handle:
            yield from (
                (key, handle.get_tensor(key))
                for key in handle.keys()  # noqa: SIM118
            )
|
|
||||||
|
|
||||||
|
|
||||||
def get_quant_config(
    model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig:
    """Build the quantization config object for ``model_config``.

    Lookup order:
      1. ``quantization_config`` (then ``compression_config``) on the HF
         model config.
      2. For bitsandbytes without a ``qlora_adapter_name_or_path`` entry in
         ``load_config.model_loader_extra_config``: an empty adapter config.
      3. A JSON config file in the model (or adapter) folder, downloaded
         first when the path is not a local directory.

    Raises:
        ValueError: if the quantization class names config files but none,
            or more than one, is found in the folder.
    """
    # `glob` is not among this module's top-level imports; importing it here
    # avoids a NameError when the config-file lookup below is reached.
    import glob

    quant_cls = get_quantization_config(model_config.quantization)

    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
    if hf_quant_config is None:
        # compressed-tensors uses a compression_config
        hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)

    # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
    if model_config.quantization == "bitsandbytes":
        if (
            not load_config.model_loader_extra_config
            or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config
        ):
            return quant_cls.from_config({"adapter_name_or_path": ""})
        model_name_or_path = load_config.model_loader_extra_config[
            "qlora_adapter_name_or_path"
        ]
    else:
        model_name_or_path = model_config.model

    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, load_config.download_dir):
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )
    else:
        hf_folder = model_name_or_path

    possible_config_filenames = quant_cls.get_config_filenames()

    # If the quantization method names no config files, use its default config.
    if not possible_config_filenames:
        return quant_cls()

    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    quant_config_files = [
        f for f in config_files if any(f.endswith(x) for x in possible_config_filenames)
    ]
    if len(quant_config_files) == 0:
        raise ValueError(f"Cannot find the config file for {model_config.quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}"
        )

    quant_config_file = quant_config_files[0]
    with open(quant_config_file, "r") as f:
        config = json.load(f)

    if model_config.quantization == "bitsandbytes":
        config["adapter_name_or_path"] = model_name_or_path

    return quant_cls.from_config(config)
|
|
||||||
Reference in New Issue
Block a user