misc: rm unused model_loader (#1110)
This commit is contained in:
@@ -38,6 +38,7 @@ from vllm.distributed import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
from sglang.global_config import global_config
|
||||
@@ -168,15 +169,6 @@ class ModelRunner:
|
||||
if self.model_config.model_overide_args is not None:
|
||||
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
|
||||
|
||||
if (
|
||||
self.server_args.efficient_weight_load
|
||||
and "llama" in self.server_args.model_path.lower()
|
||||
and self.server_args.quantization == "fp8"
|
||||
):
|
||||
from sglang.srt.model_loader.model_loader import get_model
|
||||
else:
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
|
||||
self.model = get_model(
|
||||
model_config=vllm_model_config,
|
||||
device_config=device_config,
|
||||
|
||||
@@ -1,292 +0,0 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# temporarily adapted from https://github.com/vllm-project/vllm/blob/10383887e03412196a2689b9398290719c4797bf/vllm/model_executor/model_loader/loader.py
|
||||
# FIXME: in progress of refactoring the model loader
|
||||
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
LoadConfig,
|
||||
LoadFormat,
|
||||
LoRAConfig,
|
||||
ModelConfig,
|
||||
MultiModalConfig,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_model_architecture,
|
||||
set_default_torch_dtype,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from sglang.srt.model_loader.utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
download_weights_from_hf,
|
||||
filter_duplicate_safetensors_files,
|
||||
get_quant_config,
|
||||
safetensors_weights_iterator,
|
||||
)
|
||||
|
||||
|
||||
def _get_quantization_config(
|
||||
model_config: ModelConfig, load_config: LoadConfig
|
||||
) -> Optional[QuantizationConfig]:
|
||||
"""Get the quantization config."""
|
||||
if model_config.quantization is not None:
|
||||
quant_config = get_quant_config(model_config, load_config)
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
if capability < quant_config.get_min_capability():
|
||||
raise ValueError(
|
||||
f"The quantization method {model_config.quantization} is not "
|
||||
"supported for the current GPU. "
|
||||
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||
f"Current capability: {capability}."
|
||||
)
|
||||
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||
if model_config.dtype not in supported_dtypes:
|
||||
raise ValueError(
|
||||
f"{model_config.dtype} is not supported for quantization "
|
||||
f"method {model_config.quantization}. Supported dtypes: "
|
||||
f"{supported_dtypes}"
|
||||
)
|
||||
return quant_config
|
||||
return None
|
||||
|
||||
|
||||
def _get_model_initialization_kwargs(
|
||||
model_class: Type[nn.Module],
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
) -> Dict[str, Any]:
|
||||
"""Get extra kwargs for model initialization."""
|
||||
extra_kwargs: Dict[str, Any] = {}
|
||||
|
||||
assert lora_config is None
|
||||
assert multimodal_config is None
|
||||
|
||||
return extra_kwargs
|
||||
|
||||
|
||||
def _initialize_model(
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
cache_config: CacheConfig,
|
||||
) -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
model_class = get_model_architecture(model_config)[0]
|
||||
quant_config = _get_quantization_config(model_config, load_config)
|
||||
|
||||
return model_class(
|
||||
config=model_config.hf_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
efficient_weight_load=True,
|
||||
**_get_model_initialization_kwargs(model_class, lora_config, multimodal_config),
|
||||
)
|
||||
|
||||
|
||||
class ModelLoader:
|
||||
"""Model loader that can load different file types from disk."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
self.load_config = load_config
|
||||
|
||||
def _prepare_weights(
|
||||
self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
load_format = self.load_config.load_format
|
||||
use_safetensors = False
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == LoadFormat.AUTO:
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif load_format == LoadFormat.SAFETENSORS:
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == LoadFormat.PT:
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == LoadFormat.NPCACHE:
|
||||
allow_patterns = ["*.bin"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
|
||||
if not is_local:
|
||||
hf_folder = download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
allow_patterns,
|
||||
revision,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
hf_weights_files: List[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
if use_safetensors:
|
||||
# For models like Mistral-7B-Instruct-v0.3
|
||||
# there are both sharded safetensors files and a consolidated
|
||||
# safetensors file. Using both breaks.
|
||||
# Here, we download the `model.safetensors.index.json` and filter
|
||||
# any files not found in the index.
|
||||
if not is_local:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path, self.load_config.download_dir, revision
|
||||
)
|
||||
hf_weights_files = filter_duplicate_safetensors_files(
|
||||
hf_weights_files, hf_folder
|
||||
)
|
||||
else:
|
||||
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`"
|
||||
)
|
||||
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
model_name_or_path, revision, fall_back_to_pt
|
||||
)
|
||||
if self.load_config.load_format == LoadFormat.NPCACHE:
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
weights_iterator = np_cache_weights_iterator(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
hf_folder,
|
||||
hf_weights_files,
|
||||
)
|
||||
elif use_safetensors:
|
||||
weights_iterator = safetensors_weights_iterator(hf_weights_files)
|
||||
else:
|
||||
weights_iterator = pt_weights_iterator(hf_weights_files)
|
||||
|
||||
return weights_iterator
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig,
|
||||
) -> nn.Module:
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = _initialize_model(
|
||||
model_config,
|
||||
self.load_config,
|
||||
lora_config,
|
||||
multimodal_config,
|
||||
cache_config,
|
||||
)
|
||||
weights = self._get_weights_iterator(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
|
||||
)
|
||||
|
||||
modules = {}
|
||||
for name, module in model.named_modules():
|
||||
modules[name] = module
|
||||
|
||||
def apply_quant_method(module):
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
# print("before apply quant", module.weight, module.weight.dtype)
|
||||
quant_method.process_weights_after_loading(module)
|
||||
# print("after apply quant", module.weight, module.weight.dtype)
|
||||
# FIXME: Remove this after Mixtral is updated
|
||||
# to use quant_method.
|
||||
if hasattr(module, "process_weights_after_loading"):
|
||||
module.process_weights_after_loading()
|
||||
|
||||
if torch.cuda.current_device() == 0:
|
||||
weights = tqdm(
|
||||
weights, total=model.get_num_params() * 1.5, desc="load model"
|
||||
)
|
||||
|
||||
num_shard = {}
|
||||
num_loaded = {}
|
||||
for name, loaded_weight in weights:
|
||||
model.load_weights(None, name, loaded_weight)
|
||||
module_name, shard_num = model.get_module_name(name)
|
||||
num_shard[module_name] = shard_num
|
||||
if module_name not in num_loaded:
|
||||
num_loaded[module_name] = 1
|
||||
else:
|
||||
num_loaded[module_name] += 1
|
||||
if num_loaded[module_name] == num_shard[module_name]:
|
||||
apply_quant_method(modules[module_name])
|
||||
|
||||
return model.eval()
|
||||
|
||||
|
||||
def get_model(
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig,
|
||||
device_config: DeviceConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
cache_config: CacheConfig,
|
||||
) -> nn.Module:
|
||||
loader = ModelLoader(load_config)
|
||||
return loader.load_model(
|
||||
model_config=model_config,
|
||||
device_config=device_config,
|
||||
lora_config=lora_config,
|
||||
multimodal_config=multimodal_config,
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config,
|
||||
cache_config=cache_config,
|
||||
)
|
||||
@@ -1,275 +0,0 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# temporarily adapted from vLLM
|
||||
# FIXME: in progress of refactoring the model loader
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
import fnmatch
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Generator, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
import filelock
|
||||
import huggingface_hub.constants
|
||||
import torch
|
||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file, safe_open, save_file
|
||||
from torch import nn
|
||||
from tqdm.auto import tqdm
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
from sglang.srt.layers.quantization import get_quantization_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_default_torch_dtype(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
# Special handling for quantized Mixtral.
|
||||
# FIXME(woosuk): This is a temporary hack.
|
||||
if (
|
||||
model_config.quantization is not None
|
||||
and model_config.quantization != "fp8"
|
||||
and "MixtralForCausalLM" in architectures
|
||||
):
|
||||
architectures = ["QuantMixtralForCausalLM"]
|
||||
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry.load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return (model_cls, arch)
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}"
|
||||
)
|
||||
|
||||
|
||||
class DisabledTqdm(tqdm):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
|
||||
|
||||
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
||||
lock_dir = cache_dir or temp_dir
|
||||
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
|
||||
model_name = model_name_or_path.replace("/", "-")
|
||||
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
|
||||
# add hash to avoid conflict with old users' lock files
|
||||
lock_file_name = hash_name + model_name + ".lock"
|
||||
# mode 0o666 is required for the filelock to be shared across users
|
||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
|
||||
return lock
|
||||
|
||||
|
||||
def download_weights_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
allow_patterns: List[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Download model weights from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
allow_patterns (List[str]): The allowed patterns for the
|
||||
weight files. Files matched by any of the patterns will be
|
||||
downloaded.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
|
||||
Returns:
|
||||
str: The path to the downloaded model weights.
|
||||
"""
|
||||
if not huggingface_hub.constants.HF_HUB_OFFLINE:
|
||||
# Before we download we look at that is available:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||
|
||||
# depending on what is available we download different things
|
||||
for pattern in allow_patterns:
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
break
|
||||
|
||||
logger.info("Using model weights format %s", allow_patterns)
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
allow_patterns=allow_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=DisabledTqdm,
|
||||
revision=revision,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
)
|
||||
return hf_folder
|
||||
|
||||
|
||||
def download_safetensors_index_file_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Download hf safetensors index file from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
"""
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
try:
|
||||
# Download the safetensors index file.
|
||||
hf_hub_download(
|
||||
repo_id=model_name_or_path,
|
||||
filename=SAFE_WEIGHTS_INDEX_NAME,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
)
|
||||
# If file not found on remote or locally, we should not fail since
|
||||
# only some models will have SAFE_WEIGHTS_INDEX_NAME.
|
||||
except huggingface_hub.utils.EntryNotFoundError:
|
||||
logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
|
||||
except huggingface_hub.utils.LocalEntryNotFoundError:
|
||||
logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
|
||||
|
||||
|
||||
# For models like Mistral-7B-v0.3, there are both sharded
|
||||
# safetensors files and a consolidated safetensors file.
|
||||
# Passing both of these to the weight loader functionality breaks.
|
||||
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
|
||||
# look up which safetensors files should be used.
|
||||
def filter_duplicate_safetensors_files(
|
||||
hf_weights_files: List[str], hf_folder: str
|
||||
) -> List[str]:
|
||||
# model.safetensors.index.json is a mapping from keys in the
|
||||
# torch state_dict to safetensors file holding that weight.
|
||||
index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
|
||||
if not os.path.isfile(index_file_name):
|
||||
return hf_weights_files
|
||||
|
||||
# Iterate through the weight_map (weight_name: safetensors files)
|
||||
# to identify weights that we should use.
|
||||
with open(index_file_name) as index_file:
|
||||
weight_map = json.load(index_file)["weight_map"]
|
||||
weight_files_in_index = set()
|
||||
for weight_name in weight_map:
|
||||
weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name]))
|
||||
# Filter out any fields that are not found in the index file.
|
||||
hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index]
|
||||
return hf_weights_files
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
for st_file in hf_weights_files:
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
|
||||
|
||||
def get_quant_config(
|
||||
model_config: ModelConfig, load_config: LoadConfig
|
||||
) -> QuantizationConfig:
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
# Read the quantization config from the HF model config, if available.
|
||||
hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
|
||||
if hf_quant_config is None:
|
||||
# compressed-tensors uses a compressions_config
|
||||
hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
|
||||
if hf_quant_config is not None:
|
||||
return quant_cls.from_config(hf_quant_config)
|
||||
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
if (
|
||||
not load_config.model_loader_extra_config
|
||||
or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config
|
||||
):
|
||||
return quant_cls.from_config({"adapter_name_or_path": ""})
|
||||
model_name_or_path = load_config.model_loader_extra_config[
|
||||
"qlora_adapter_name_or_path"
|
||||
]
|
||||
|
||||
else:
|
||||
model_name_or_path = model_config.model
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_name_or_path, load_config.download_dir):
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
revision=model_config.revision,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=load_config.download_dir,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
tqdm_class=DisabledTqdm,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
possible_config_filenames = quant_cls.get_config_filenames()
|
||||
|
||||
# If the quantization config is not found, use the default config.
|
||||
if not possible_config_filenames:
|
||||
return quant_cls()
|
||||
|
||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||
|
||||
quant_config_files = [
|
||||
f for f in config_files if any(f.endswith(x) for x in possible_config_filenames)
|
||||
]
|
||||
if len(quant_config_files) == 0:
|
||||
raise ValueError(f"Cannot find the config file for {model_config.quantization}")
|
||||
if len(quant_config_files) > 1:
|
||||
raise ValueError(
|
||||
f"Found multiple config files for {model_config.quantization}: "
|
||||
f"{quant_config_files}"
|
||||
)
|
||||
|
||||
quant_config_file = quant_config_files[0]
|
||||
with open(quant_config_file, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
config["adapter_name_or_path"] = model_name_or_path
|
||||
|
||||
return quant_cls.from_config(config)
|
||||
Reference in New Issue
Block a user