Revert "Revert "[FEAT] Support GGUF format"" (#2287)
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
import contextlib
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
@@ -27,6 +28,7 @@ from transformers import (
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
|
||||
try:
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
||||
@@ -60,15 +62,29 @@ def get_config(
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
model_override_args: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
is_gguf = check_gguf_file(model)
|
||||
if is_gguf:
|
||||
kwargs["gguf_file"] = model
|
||||
model = Path(model).parent
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model, trust_remote_code=trust_remote_code, revision=revision
|
||||
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||
)
|
||||
if config.model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||
config = config_class.from_pretrained(model, revision=revision)
|
||||
if model_override_args:
|
||||
config.update(model_override_args)
|
||||
|
||||
# Special architecture mapping check for GGUF models
|
||||
if is_gguf:
|
||||
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
|
||||
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
|
||||
config.update({"architectures": [model_type]})
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@@ -123,6 +139,11 @@ def get_tokenizer(
|
||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
kwargs["use_fast"] = False
|
||||
|
||||
is_gguf = check_gguf_file(tokenizer_name)
|
||||
if is_gguf:
|
||||
kwargs["gguf_file"] = tokenizer_name
|
||||
tokenizer_name = Path(tokenizer_name).parent
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_name,
|
||||
@@ -195,3 +216,16 @@ def attach_additional_stop_token_ids(tokenizer):
|
||||
)
|
||||
else:
|
||||
tokenizer.additional_stop_token_ids = None
|
||||
|
||||
|
||||
def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
    """Return ``True`` if *model* points at a GGUF model file.

    A path counts as GGUF when it is an existing regular file and either
    carries the ``.gguf`` extension or starts with the 4-byte ``GGUF``
    magic number; anything else (including directories and missing
    paths) yields ``False``.
    """
    path = Path(model)
    # Directories, repo IDs, and nonexistent paths are never GGUF files.
    if not path.is_file():
        return False
    if path.suffix == ".gguf":
        return True
    # Extension is inconclusive — sniff the 4-byte magic number instead.
    with path.open("rb") as fh:
        return fh.read(4) == b"GGUF"
|
||||
|
||||
Reference in New Issue
Block a user