Revert "Revert "[FEAT] Support GGUF format"" (#2287)

This commit is contained in:
Lianmin Zheng
2024-11-30 22:14:48 -08:00
committed by GitHub
parent 1bfa511b95
commit 4936be8acc
41 changed files with 229 additions and 132 deletions

View File

@@ -16,6 +16,7 @@
import contextlib
import os
import warnings
from pathlib import Path
from typing import Dict, Optional, Type, Union
from huggingface_hub import snapshot_download
@@ -27,6 +28,7 @@ from transformers import (
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
try:
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -60,15 +62,29 @@ def get_config(
trust_remote_code: bool,
revision: Optional[str] = None,
model_override_args: Optional[dict] = None,
**kwargs,
):
is_gguf = check_gguf_file(model)
if is_gguf:
kwargs["gguf_file"] = model
model = Path(model).parent
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
if model_override_args:
config.update(model_override_args)
# Special architecture mapping check for GGUF models
if is_gguf:
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})
return config
@@ -123,6 +139,11 @@ def get_tokenizer(
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
is_gguf = check_gguf_file(tokenizer_name)
if is_gguf:
kwargs["gguf_file"] = tokenizer_name
tokenizer_name = Path(tokenizer_name).parent
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
@@ -195,3 +216,16 @@ def attach_additional_stop_token_ids(tokenizer):
)
else:
tokenizer.additional_stop_token_ids = None
def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)
if not model.is_file():
return False
elif model.suffix == ".gguf":
return True
with open(model, "rb") as f:
header = f.read(4)
return header == b"GGUF"