Sync b7516
21
gguf-py/LICENSE
Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Georgi Gerganov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
99
gguf-py/README.md
Normal file
@@ -0,0 +1,99 @@
## gguf

This is a Python package for writing binary files in the [GGUF](https://github.com/ggml-org/ggml/pull/302)
(GGML Universal File) format.

See [convert_hf_to_gguf.py](https://github.com/ggml-org/llama.cpp/blob/master/convert_hf_to_gguf.py)
for an example of its usage.

## Installation
```sh
pip install gguf
```

Optionally, you can install gguf with the extra 'gui' to enable the visual GGUF editor.
```sh
pip install gguf[gui]
```

## API Examples/Simple Tools

[examples/writer.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model.

[examples/reader.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/reader.py) — Extracts and displays key-value pairs and tensor details from a GGUF file in a readable format.

[gguf/scripts/gguf_dump.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console.

[gguf/scripts/gguf_set_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key.

[gguf/scripts/gguf_convert_endian.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_convert_endian.py) — Allows converting the endianness of GGUF files.

[gguf/scripts/gguf_new_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values.

[gguf/scripts/gguf_editor_gui.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_editor_gui.py) — Allows for viewing, editing, adding, or removing metadata values within a GGUF file as well as viewing its tensors with a Qt interface.

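
The tools listed above all begin by parsing the fixed-size GGUF header. As a rough illustration only (a standard-library sketch, not the package's `GGUFReader`), the fields that the reader in this commit checks in its constructor, a `uint32` magic, a `uint32` version, then `uint64` tensor and key-value counts, can be unpacked with `struct`. The helper names `parse_gguf_header` and `align_offset`, the little-endian assumption, the default alignment of 32, and the synthetic header bytes are all assumptions made for this example.

```python
from __future__ import annotations

import struct

GGUF_MAGIC_BYTES = b"GGUF"  # the four bytes at the start of a GGUF file
# Assumes a little-endian file: magic, version, tensor count, key-value count.
HEADER_FORMAT = "<4sIQQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)


def parse_gguf_header(buf: bytes) -> dict[str, int]:
    """Hypothetical helper: unpack the fixed-size header fields from raw bytes."""
    if len(buf) < HEADER_SIZE:
        raise ValueError("buffer too short for a GGUF header")
    magic, version, tensor_count, kv_count = struct.unpack_from(HEADER_FORMAT, buf, 0)
    if magic != GGUF_MAGIC_BYTES:
        raise ValueError("GGUF magic invalid")
    return {"version": version, "tensor_count": tensor_count, "kv_count": kv_count}


def align_offset(offset: int, alignment: int = 32) -> int:
    """Round an offset up to the next multiple of the alignment (32 is an assumed default)."""
    padding = offset % alignment
    return offset if padding == 0 else offset + alignment - padding


# Synthetic header for demonstration: version 3, 3 tensors, 2 key-value pairs.
example = struct.pack(HEADER_FORMAT, GGUF_MAGIC_BYTES, 3, 3, 2)
header = parse_gguf_header(example)
```

The real reader additionally detects byte-swapped files and walks the key-value and tensor-info sections before applying the alignment padding, so treat this as a map of the header layout rather than a replacement for `GGUFReader`.
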
## Development
Maintainers who participate in development of this package are advised to install it in editable mode:

```sh
cd /path/to/llama.cpp/gguf-py

pip install --editable .
```

**Note**: This may require upgrading your pip installation if you see a message saying that editable installation currently requires `setup.py`.
In this case, upgrade pip to the latest version:

```sh
pip install --upgrade pip
```

## Automatic publishing with CI

There's a GitHub workflow to make a release automatically upon creation of tags in a specified format.

1. Bump the version in `pyproject.toml`.
2. Create a tag named `gguf-vx.x.x` where `x.x.x` is the semantic version number.

    ```sh
    git tag -a gguf-v1.0.0 -m "Version 1.0 release"
    ```

3. Push the tags.

    ```sh
    git push origin --tags
    ```

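
Before pushing, it can help to check that a tag name matches the `gguf-vx.x.x` format described in the steps above. The following is a minimal sketch; the helper `parse_release_tag` and the exact regular expression are assumptions for illustration, and the pattern the actual GitHub workflow matches may differ.

```python
from __future__ import annotations

import re

# Assumed pattern: "gguf-v" followed by three dot-separated integers.
TAG_PATTERN = re.compile(r"^gguf-v(\d+)\.(\d+)\.(\d+)$")


def parse_release_tag(tag: str) -> tuple[int, int, int] | None:
    """Hypothetical helper: return (major, minor, patch), or None if the tag does not match."""
    match = TAG_PATTERN.match(tag)
    if match is None:
        return None
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch
```

A tag such as `gguf-v1.0.0` parses to a version triple, while `v1.0.0` or `gguf-v1.0` is rejected, which catches the most common typos before a release is triggered.
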
## Manual publishing
If you want to publish the package manually for any reason, you need to have `twine` and `build` installed:

```sh
pip install build twine
```

Then, follow these steps to release a new version:

1. Bump the version in `pyproject.toml`.
2. Build the package:

    ```sh
    python -m build
    ```

3. Upload the generated distribution archives:

    ```sh
    python -m twine upload dist/*
    ```

## Run Unit Tests

From the root of this repository, run the following command to run all the unit tests:

```bash
python -m unittest discover ./gguf-py -v
```

## TODO
- [ ] Include conversion scripts as command line entry points in this package.
49
gguf-py/examples/reader.py
Normal file
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
import logging
import sys
from pathlib import Path

logger = logging.getLogger("reader")

# Necessary to load the local gguf package
sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf.gguf_reader import GGUFReader


def read_gguf_file(gguf_file_path):
    """
    Reads and prints key-value pairs and tensor information from a GGUF file in an improved format.

    Parameters:
    - gguf_file_path: Path to the GGUF file.
    """

    reader = GGUFReader(gguf_file_path)

    # List all key-value pairs in a columnized format
    print("Key-Value Pairs:")  # noqa: NP100
    max_key_length = max(len(key) for key in reader.fields.keys())
    for key, field in reader.fields.items():
        value = field.parts[field.data[0]]
        print(f"{key:{max_key_length}} : {value}")  # noqa: NP100
    print("----")  # noqa: NP100

    # List all tensors
    print("Tensors:")  # noqa: NP100
    tensor_info_format = "{:<30} | Shape: {:<15} | Size: {:<12} | Quantization: {}"
    print(tensor_info_format.format("Tensor Name", "Shape", "Size", "Quantization"))  # noqa: NP100
    print("-" * 80)  # noqa: NP100
    for tensor in reader.tensors:
        shape_str = "x".join(map(str, tensor.shape))
        size_str = str(tensor.n_elements)
        quantization_str = tensor.tensor_type.name
        print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str))  # noqa: NP100


if __name__ == '__main__':
    if len(sys.argv) < 2:
        logger.info("Usage: reader.py <path_to_gguf_file>")
        sys.exit(1)
    gguf_file_path = sys.argv[1]
    read_gguf_file(gguf_file_path)
39
gguf-py/examples/writer.py
Executable file
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import sys
from pathlib import Path

import numpy as np

# Necessary to load the local gguf package
sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf import GGUFWriter  # noqa: E402


# Example usage:
def writer_example() -> None:
    # Example usage with a file
    gguf_writer = GGUFWriter("example.gguf", "llama")

    gguf_writer.add_block_count(12)
    gguf_writer.add_uint32("answer", 42)  # Write a 32-bit integer
    gguf_writer.add_float32("answer_in_float", 42.0)  # Write a 32-bit float
    gguf_writer.add_custom_alignment(64)

    tensor1 = np.ones((32,), dtype=np.float32) * 100.0
    tensor2 = np.ones((64,), dtype=np.float32) * 101.0
    tensor3 = np.ones((96,), dtype=np.float32) * 102.0

    gguf_writer.add_tensor("tensor1", tensor1)
    gguf_writer.add_tensor("tensor2", tensor2)
    gguf_writer.add_tensor("tensor3", tensor3)

    gguf_writer.write_header_to_file()
    gguf_writer.write_kv_data_to_file()
    gguf_writer.write_tensors_to_file()

    gguf_writer.close()


if __name__ == '__main__':
    writer_example()
9
gguf-py/gguf/__init__.py
Normal file
@@ -0,0 +1,9 @@
from .constants import *
from .lazy import *
from .gguf_reader import *
from .gguf_writer import *
from .quants import *
from .tensor_mapping import *
from .vocab import *
from .utility import *
from .metadata import *
3550
gguf-py/gguf/constants.py
Normal file
File diff suppressed because it is too large
15
gguf-py/gguf/gguf.py
Normal file
@@ -0,0 +1,15 @@
# This file left for compatibility. If you want to use the GGUF API from Python
# then don't import gguf/gguf.py directly. If you're looking for examples, see the
# examples/ directory for gguf-py

import importlib
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

# Compatibility for people trying to import gguf/gguf.py directly instead of as a package.
importlib.invalidate_caches()
import gguf  # noqa: E402

importlib.reload(gguf)
367
gguf-py/gguf/gguf_reader.py
Normal file
@@ -0,0 +1,367 @@
#
# GGUF file reading/modification support. For API usage information,
# please see the files in scripts/ for some fairly simple examples.
#
from __future__ import annotations

import logging
import os
import sys
from collections import OrderedDict
from typing import Any, Literal, NamedTuple, TypeVar, Union

import numpy as np
import numpy.typing as npt

from .quants import quant_shape_to_byte_shape

if __name__ == "__main__":
    from pathlib import Path

    # Allow running file in package as a script.
    sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf.constants import (
    GGML_QUANT_SIZES,
    GGUF_DEFAULT_ALIGNMENT,
    GGUF_MAGIC,
    GGUF_VERSION,
    GGMLQuantizationType,
    GGUFValueType,
    GGUFEndian,
)

logger = logging.getLogger(__name__)

READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]


class ReaderField(NamedTuple):
    # Offset to start of this field.
    offset: int

    # Name of the field (not necessarily from file data).
    name: str

    # Data parts. Some types have multiple components, such as strings
    # that consist of a length followed by the string data.
    parts: list[npt.NDArray[Any]] = []

    # Indexes into parts that we can call the actual data. For example
    # an array of strings will be populated with indexes to the actual
    # string data.
    data: list[int] = [-1]

    types: list[GGUFValueType] = []

    def contents(self, index_or_slice: int | slice = slice(None)) -> Any:
        if self.types:
            to_string = lambda x: str(x.tobytes(), encoding='utf-8')  # noqa: E731
            main_type = self.types[0]

            if main_type == GGUFValueType.ARRAY:
                sub_type = self.types[-1]

                if sub_type == GGUFValueType.STRING:
                    indices = self.data[index_or_slice]

                    if isinstance(index_or_slice, int):
                        return to_string(self.parts[indices])  # type: ignore
                    else:
                        return [to_string(self.parts[idx]) for idx in indices]  # type: ignore
                else:
                    # FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too

                    # Check if it's unsafe to perform slice optimization on data
                    # if any(True for idx in self.data if len(self.parts[idx]) != 1):
                    #     optim_slice = slice(None)
                    # else:
                    #     optim_slice = index_or_slice
                    #     index_or_slice = slice(None)

                    # if isinstance(optim_slice, int):
                    #     return self.parts[self.data[optim_slice]].tolist()[0]
                    # else:
                    #     return [pv for idx in self.data[optim_slice] for pv in self.parts[idx].tolist()][index_or_slice]

                    if isinstance(index_or_slice, int):
                        return self.parts[self.data[index_or_slice]].tolist()[0]
                    else:
                        return [pv for idx in self.data[index_or_slice] for pv in self.parts[idx].tolist()]

            if main_type == GGUFValueType.STRING:
                return to_string(self.parts[-1])
            else:
                return self.parts[-1].tolist()[0]

        return None


class ReaderTensor(NamedTuple):
    name: str
    tensor_type: GGMLQuantizationType
    shape: npt.NDArray[np.uint32]
    n_elements: int
    n_bytes: int
    data_offset: int
    data: npt.NDArray[Any]
    field: ReaderField


class GGUFReader:
    # I - same as host, S - swapped
    byte_order: Literal['I', 'S'] = 'I'
    alignment: int = GGUF_DEFAULT_ALIGNMENT
    data_offset: int

    # Note: Internal helper, API may change.
    gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
        GGUFValueType.UINT8: np.uint8,
        GGUFValueType.INT8: np.int8,
        GGUFValueType.UINT16: np.uint16,
        GGUFValueType.INT16: np.int16,
        GGUFValueType.UINT32: np.uint32,
        GGUFValueType.INT32: np.int32,
        GGUFValueType.FLOAT32: np.float32,
        GGUFValueType.UINT64: np.uint64,
        GGUFValueType.INT64: np.int64,
        GGUFValueType.FLOAT64: np.float64,
        GGUFValueType.BOOL: np.bool_,
    }

    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
        self.data = np.memmap(path, mode = mode)
        offs = 0

        # Check for GGUF magic
        if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
            raise ValueError('GGUF magic invalid')
        offs += 4

        # Check GGUF version
        temp_version = self._get(offs, np.uint32)
        if temp_version[0] & 65535 == 0:
            # If we get 0 here that means it's (probably) a GGUF file created for
            # the opposite byte order of the machine this script is running on.
            self.byte_order = 'S'
            temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order))
        version = temp_version[0]
        if version not in READER_SUPPORTED_VERSIONS:
            raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
        if sys.byteorder == "little":
            # Host is little endian
            host_endian = GGUFEndian.LITTLE
            swapped_endian = GGUFEndian.BIG
        else:
            # Sorry PDP or other weird systems that don't use BE or LE.
            host_endian = GGUFEndian.BIG
            swapped_endian = GGUFEndian.LITTLE
        self.endianess = swapped_endian if self.byte_order == "S" else host_endian
        self.fields: OrderedDict[str, ReaderField] = OrderedDict()
        self.tensors: list[ReaderTensor] = []
        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))

        # Check tensor count and kv count
        temp_counts = self._get(offs, np.uint64, 2)
        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
        tensor_count, kv_count = temp_counts
        offs = self._build_fields(offs, kv_count)

        # Build Tensor Info Fields
        offs, tensors_fields = self._build_tensor_info(offs, tensor_count)
        new_align = self.fields.get('general.alignment')
        if new_align is not None:
            if new_align.types != [GGUFValueType.UINT32]:
                raise ValueError('Bad type for general.alignment field')
            self.alignment = new_align.parts[-1][0]
        padding = offs % self.alignment
        if padding != 0:
            offs += self.alignment - padding
        self.data_offset = offs
        self._build_tensors(offs, tensors_fields)

    _DT = TypeVar('_DT', bound = npt.DTypeLike)

    # Fetch a key/value metadata field by key.
    def get_field(self, key: str) -> Union[ReaderField, None]:
        return self.fields.get(key, None)

    # Fetch a tensor from the list by index.
    def get_tensor(self, idx: int) -> ReaderTensor:
        return self.tensors[idx]

    def _get(
        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
    ) -> npt.NDArray[Any]:
        count = int(count)
        itemsize = int(np.empty([], dtype = dtype).itemsize)
        end_offs = offset + itemsize * count
        arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
        return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order))

    def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
        if field.name in self.fields:
            # TODO: add option to generate error on duplicate keys
            # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')

            logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
            self.fields[field.name + '_{}'.format(field.offset)] = field
        else:
            self.fields[field.name] = field
        return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)

    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
        slen = self._get(offset, np.uint64)
        return slen, self._get(offset + 8, np.uint8, slen[0])

    def _get_field_parts(
        self, orig_offs: int, raw_type: int,
    ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
        offs = orig_offs
        types: list[GGUFValueType] = []
        gtype = GGUFValueType(raw_type)
        types.append(gtype)
        # Handle strings.
        if gtype == GGUFValueType.STRING:
            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
            size = sum(int(part.nbytes) for part in sparts)
            return size, sparts, [1], types
        # Check if it's a simple scalar type.
        nptype = self.gguf_scalar_to_np.get(gtype)
        if nptype is not None:
            val = self._get(offs, nptype)
            return int(val.nbytes), [val], [0], types
        # Handle arrays.
        if gtype == GGUFValueType.ARRAY:
            raw_itype = self._get(offs, np.uint32)
            offs += int(raw_itype.nbytes)
            alen = self._get(offs, np.uint64)
            offs += int(alen.nbytes)
            aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
            data_idxs: list[int] = []
            # FIXME: Handle multi-dimensional arrays properly instead of flattening
            for idx in range(alen[0]):
                curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
                if idx == 0:
                    types += curr_types
                idxs_offs = len(aparts)
                aparts += curr_parts
                data_idxs += (idx + idxs_offs for idx in curr_idxs)
                offs += curr_size
            return offs - orig_offs, aparts, data_idxs, types
        # We can't deal with this one.
        raise ValueError(f'Unknown/unhandled field type {gtype}')

    def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
        offs = orig_offs

        # Get Tensor Name
        name_len, name_data = self._get_str(offs)
        offs += int(name_len.nbytes + name_data.nbytes)

        # Get Tensor Dimensions Count
        n_dims = self._get(offs, np.uint32)
        offs += int(n_dims.nbytes)

        # Get Tensor Dimension Array
        dims = self._get(offs, np.uint64, n_dims[0])
        offs += int(dims.nbytes)

        # Get Tensor Encoding Scheme Type
        raw_dtype = self._get(offs, np.uint32)
        offs += int(raw_dtype.nbytes)

        # Get Tensor Offset
        offset_tensor = self._get(offs, np.uint64)
        offs += int(offset_tensor.nbytes)

        return ReaderField(
            orig_offs,
            str(bytes(name_data), encoding = 'utf-8'),
            [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
            [1, 3, 4, 5],
        )

    def _build_fields(self, offs: int, count: int) -> int:
        for _ in range(count):
            orig_offs = offs
            kv_klen, kv_kdata = self._get_str(offs)
            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
            raw_kv_type = self._get(offs, np.uint32)
            offs += int(raw_kv_type.nbytes)
            parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
            idxs_offs = len(parts)
            field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
            parts += field_parts
            self._push_field(ReaderField(
                orig_offs,
                str(bytes(kv_kdata), encoding = 'utf-8'),
                parts,
                [idx + idxs_offs for idx in field_idxs],
                field_types,
            ), skip_sum = True)
            offs += field_size
        return offs

    def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
        tensor_fields = []
        for _ in range(count):
            field = self._get_tensor_info_field(offs)
            offs += sum(int(part.nbytes) for part in field.parts)
            tensor_fields.append(field)
        return offs, tensor_fields

    def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
        tensors = []
        tensor_names = set()  # keep track of name to prevent duplicated tensors
        for field in fields:
            _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
            # check if there's any tensor having same name already in the list
            tensor_name = str(bytes(name_data), encoding = 'utf-8')
            if tensor_name in tensor_names:
                raise ValueError(f'Found duplicated tensor with name {tensor_name}')
            tensor_names.add(tensor_name)
            ggml_type = GGMLQuantizationType(raw_dtype[0])
            n_elems = int(np.prod(dims))
            np_dims = tuple(reversed(dims.tolist()))
            block_size, type_size = GGML_QUANT_SIZES[ggml_type]
            n_bytes = n_elems * type_size // block_size
            data_offs = int(start_offs + offset_tensor[0])
            item_type: npt.DTypeLike
            if ggml_type == GGMLQuantizationType.F16:
                item_count = n_elems
                item_type = np.float16
            elif ggml_type == GGMLQuantizationType.F32:
                item_count = n_elems
                item_type = np.float32
            elif ggml_type == GGMLQuantizationType.F64:
                item_count = n_elems
                item_type = np.float64
            elif ggml_type == GGMLQuantizationType.I8:
                item_count = n_elems
                item_type = np.int8
            elif ggml_type == GGMLQuantizationType.I16:
                item_count = n_elems
                item_type = np.int16
            elif ggml_type == GGMLQuantizationType.I32:
                item_count = n_elems
                item_type = np.int32
            elif ggml_type == GGMLQuantizationType.I64:
                item_count = n_elems
                item_type = np.int64
            else:
                item_count = n_bytes
                item_type = np.uint8
                np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
            tensors.append(ReaderTensor(
                name = tensor_name,
                tensor_type = ggml_type,
                shape = dims,
                n_elements = n_elems,
                n_bytes = n_bytes,
                data_offset = data_offs,
                data = self._get(data_offs, item_type, item_count).reshape(np_dims),
                field = field,
            ))
        self.tensors = tensors
1233
gguf-py/gguf/gguf_writer.py
Normal file
File diff suppressed because it is too large
228
gguf-py/gguf/lazy.py
Normal file
@@ -0,0 +1,228 @@
from __future__ import annotations
from abc import ABC, ABCMeta, abstractmethod

import logging
from typing import Any, Callable

import numpy as np
from numpy.typing import DTypeLike


logger = logging.getLogger(__name__)


class LazyMeta(ABCMeta):

    def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
        def __getattr__(self, name: str) -> Any:
            meta_attr = getattr(self._meta, name)
            if callable(meta_attr):
                return type(self)._wrap_fn(
                    (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
                    use_self=self,
                )
            elif isinstance(meta_attr, self._tensor_type):
                # e.g. self.T with torch.Tensor should still be wrapped
                return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
            else:
                # no need to wrap non-tensor properties,
                # and they likely don't depend on the actual contents of the tensor
                return meta_attr

        namespace["__getattr__"] = __getattr__

        # need to make a builder for the wrapped wrapper to copy the name,
        # or else it fails with very cryptic error messages,
        # because somehow the same string would end up in every closure
        def mk_wrap(op_name: str, *, meta_noop: bool = False):
            # need to wrap the wrapper to get self
            def wrapped_special_op(self, *args, **kwargs):
                return type(self)._wrap_fn(
                    getattr(type(self)._tensor_type, op_name),
                    meta_noop=meta_noop,
                )(self, *args, **kwargs)
            return wrapped_special_op

        # special methods bypass __getattr__, so they need to be added manually
        # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
        # NOTE: doing this from a metaclass is very convenient
        # TODO: make this even more comprehensive
        for binary_op in (
            "lt", "le", "eq", "ne", "ge", "gt",
            "add", "and", "floordiv", "lshift", "mod", "mul", "matmul",
            "or", "pow", "rshift", "sub", "truediv", "xor",
            "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
            "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
        ):
            attr_name = f"__{binary_op}__"
            # evaluation on the meta tensor is needed in case there's broadcasting
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)

        for unary_op in ("not", "abs", "invert", "neg", "pos"):
            attr_name = f"__{unary_op}__"
            # the result of these operators usually has the same shape and dtype as the input,
            # so evaluation on the meta tensor can be skipped.
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)

        for special_op in (
            "getitem", "setitem", "len",
        ):
            attr_name = f"__{special_op}__"
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)

        return super().__new__(cls, name, bases, namespace, **kwargs)


# Tree of lazy tensors
|
||||
class LazyBase(ABC, metaclass=LazyMeta):
|
||||
_tensor_type: type
|
||||
_meta: Any
|
||||
_data: Any | None
|
||||
_args: tuple
|
||||
_kwargs: dict[str, Any]
|
||||
_func: Callable[[Any], Any] | None
|
||||
|
||||
def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
|
||||
super().__init__()
|
||||
self._meta = meta
|
||||
self._data = data
|
||||
self._args = args
|
||||
self._kwargs = kwargs if kwargs is not None else {}
|
||||
self._func = func
|
||||
assert self._func is not None or self._data is not None
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
if "_tensor_type" not in cls.__dict__:
|
||||
raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
|
||||
return super().__init_subclass__()
|
||||
|
||||
@staticmethod
|
||||
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
|
||||
# TODO: dict and set
|
||||
if isinstance(o, (list, tuple)):
|
||||
L = []
|
||||
for item in o:
|
||||
L.append(LazyBase._recurse_apply(item, fn))
|
||||
if isinstance(o, tuple):
|
||||
L = tuple(L)
|
||||
return L
|
||||
elif isinstance(o, LazyBase):
|
||||
return fn(o)
|
||||
else:
|
||||
return o
|
||||
|
||||
@classmethod
|
||||
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
|
||||
def wrapped_fn(*args, **kwargs):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
args = ((use_self,) if use_self is not None else ()) + args
|
||||
|
||||
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
|
||||
# TODO: maybe handle tensors in kwargs too
|
||||
|
||||
if isinstance(meta_noop, bool) and not meta_noop:
|
||||
try:
|
||||
res = fn(*meta_args, **kwargs)
|
||||
except NotImplementedError:
|
||||
# running some operations on PyTorch's Meta tensors can cause this exception
|
||||
res = None
|
||||
else:
|
||||
# some operators don't need to actually run on the meta tensors
|
||||
assert len(args) > 0
|
||||
res = args[0]
|
||||
assert isinstance(res, cls)
|
||||
res = res._meta
|
||||
# allow operations to override the dtype and shape
|
||||
if meta_noop is not True:
|
||||
if isinstance(meta_noop, tuple):
|
||||
dtype, shape = meta_noop
|
||||
assert callable(shape)
|
||||
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
|
||||
else:
|
||||
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
|
||||
|
||||
if isinstance(res, cls._tensor_type):
|
||||
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
|
||||
elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
|
||||
# share the evaluation between lazy tuple elements
|
||||
shared_args: list = [args, None]
|
||||
|
||||
def eager_tuple_element(a: list[Any], i: int = 0, /, **kw) -> LazyBase:
|
||||
assert len(a) == 2
|
||||
if a[1] is None:
|
||||
a[1] = fn(*a[0], **kw)
|
||||
return a[1][i]
|
||||
return tuple(cls(meta=cls.eager_to_meta(res[i]), args=(shared_args, i), kwargs=kwargs, func=eager_tuple_element) for i in range(len(res)))
|
||||
else:
|
||||
del res # not needed
|
||||
# non-tensor return likely relies on the contents of the args
|
||||
# (e.g. the result of torch.equal)
|
||||
eager_args = cls.to_eager(args)
|
||||
return fn(*eager_args, **kwargs)
|
||||
return wrapped_fn
|
||||
|
||||
@classmethod
|
||||
def to_eager(cls, t: Any) -> Any:
|
||||
def simple_to_eager(_t: LazyBase) -> Any:
|
||||
if _t._data is not None:
|
||||
return _t._data
|
||||
|
||||
# NOTE: there's a recursion limit in Python (usually 1000)
|
||||
|
||||
assert _t._func is not None
|
||||
_t._args = cls._recurse_apply(_t._args, simple_to_eager)
|
||||
_t._data = _t._func(*_t._args, **_t._kwargs)
|
||||
# sanity check
|
||||
assert _t._data is not None
|
||||
assert _t._data.dtype == _t._meta.dtype
|
||||
assert _t._data.shape == _t._meta.shape
|
||||
|
||||
return _t._data
|
||||
|
||||
# recurse into lists and/or tuples, keeping their structure
|
||||
return cls._recurse_apply(t, simple_to_eager)
|
||||
|
||||
@classmethod
|
||||
def eager_to_meta(cls, t: Any) -> Any:
|
||||
return cls.meta_with_dtype_and_shape(t.dtype, t.shape)
|
||||
|
||||
# must be overridden, meta tensor init is backend-specific
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass
|
||||
|
||||
@classmethod
|
||||
def from_eager(cls, t: Any) -> Any:
|
||||
if type(t) is cls:
|
||||
# already lazy
|
||||
return t
|
||||
elif isinstance(t, cls._tensor_type):
|
||||
return cls(meta=cls.eager_to_meta(t), data=t)
|
||||
else:
|
||||
raise TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")
|
||||
|
||||
|
||||
class LazyNumpyTensor(LazyBase):
|
||||
_tensor_type = np.ndarray
|
||||
|
||||
shape: tuple[int, ...] # Makes the type checker happy in quants.py
|
||||
|
||||
@classmethod
|
||||
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
|
||||
# The initial idea was to use np.nan as the fill value,
|
||||
# but non-float types like np.int16 can't use that.
|
||||
# So zero it is.
|
||||
cheat = np.zeros(1, dtype)
|
||||
return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape))
|
||||
|
||||
def astype(self, dtype, *args, **kwargs):
|
||||
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
|
||||
full_args = (self, dtype,) + args
|
||||
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
|
||||
|
||||
def tofile(self, *args, **kwargs):
|
||||
eager = LazyNumpyTensor.to_eager(self)
|
||||
return eager.tofile(*args, **kwargs)
|
||||
|
||||
# TODO: __array_function__
|
||||
731
gguf-py/gguf/metadata.py
Normal file
@@ -0,0 +1,731 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import json
|
||||
import yaml
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .constants import Keys
|
||||
|
||||
import gguf
|
||||
|
||||
logger = logging.getLogger("metadata")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metadata:
|
||||
# Recommended Sampler Parameters to be written to GGUF KV Store
|
||||
sampling_sequence: Optional[str] = None
|
||||
sampling_top_k: Optional[int] = None
|
||||
sampling_top_p: Optional[float] = None
|
||||
sampling_min_p: Optional[float] = None
|
||||
sampling_xtc_probability: Optional[float] = None
|
||||
sampling_xtc_threshold: Optional[float] = None
|
||||
sampling_temp: Optional[float] = None
|
||||
sampling_penalty_last_n: Optional[int] = None
|
||||
sampling_penalty_repeat: Optional[float] = None
|
||||
sampling_mirostat: Optional[int] = None
|
||||
sampling_mirostat_tau: Optional[float] = None
|
||||
sampling_mirostat_eta: Optional[float] = None
|
||||
|
||||
# Authorship Metadata to be written to GGUF KV Store
|
||||
name: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
organization: Optional[str] = None
|
||||
finetune: Optional[str] = None
|
||||
basename: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
quantized_by: Optional[str] = None
|
||||
size_label: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
doi: Optional[str] = None
|
||||
uuid: Optional[str] = None
|
||||
repo_url: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
source_doi: Optional[str] = None
|
||||
source_uuid: Optional[str] = None
|
||||
source_repo_url: Optional[str] = None
|
||||
license: Optional[str] = None
|
||||
license_name: Optional[str] = None
|
||||
license_link: Optional[str] = None
|
||||
base_models: Optional[list[dict]] = None
|
||||
tags: Optional[list[str]] = None
|
||||
languages: Optional[list[str]] = None
|
||||
datasets: Optional[list[dict]] = None
|
||||
|
||||
@staticmethod
|
||||
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
|
||||
# This grabs as much contextual authorship metadata as possible from the model repository,
|
||||
# making any conversions required to match the GGUF KV store metadata format,
|
||||
# as well as giving users the ability to override any authorship metadata that may be incorrect
|
||||
|
||||
# Create a new Metadata instance
|
||||
metadata = Metadata()
|
||||
|
||||
model_card = Metadata.load_model_card(model_path)
|
||||
hf_params = Metadata.load_hf_parameters(model_path)
|
||||
gen_config = Metadata.load_generation_config(model_path)
|
||||
# TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
|
||||
|
||||
# heuristics
|
||||
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
|
||||
|
||||
if gen_config:
|
||||
metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence)
|
||||
metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k)
|
||||
metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p)
|
||||
metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p)
|
||||
metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
|
||||
metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold)
|
||||
metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp)
|
||||
metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n)
|
||||
metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat)
|
||||
metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat)
|
||||
metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau)
|
||||
metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta)
|
||||
|
||||
# Metadata Override File Provided
|
||||
# This is based on LLM_KV_NAMES mapping in llama.cpp
|
||||
metadata_override = Metadata.load_metadata_override(metadata_override_path)
|
||||
|
||||
metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence)
|
||||
metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k)
|
||||
metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p)
|
||||
metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p)
|
||||
metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
|
||||
metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold)
|
||||
metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp)
|
||||
metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n)
|
||||
metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat)
|
||||
metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat)
|
||||
metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau)
|
||||
metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta)
|
||||
|
||||
metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
|
||||
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
|
||||
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
|
||||
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
|
||||
|
||||
metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
|
||||
metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
|
||||
|
||||
metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
|
||||
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
|
||||
|
||||
metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
|
||||
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
|
||||
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
|
||||
|
||||
metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
|
||||
metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
|
||||
metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
|
||||
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
|
||||
|
||||
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
|
||||
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
|
||||
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
|
||||
metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
|
||||
|
||||
# Base Models is received here as an array of models
|
||||
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
|
||||
|
||||
# Datasets is received here as an array of datasets
|
||||
metadata.datasets = metadata_override.get("general.datasets", metadata.datasets)
|
||||
|
||||
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
|
||||
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
|
||||
|
||||
# Direct Metadata Override (via direct cli argument)
|
||||
if model_name is not None:
|
||||
metadata.name = model_name
|
||||
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
|
||||
if metadata_override_path is None or not metadata_override_path.is_file():
|
||||
return {}
|
||||
|
||||
with open(metadata_override_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
|
||||
if model_path is None or not model_path.is_dir():
|
||||
return {}
|
||||
|
||||
model_card_path = model_path / "README.md"
|
||||
|
||||
if not model_card_path.is_file():
|
||||
return {}
|
||||
|
||||
# The model card metadata is assumed to always be in YAML (frontmatter)
|
||||
# ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
|
||||
yaml_content: str = ""
|
||||
with open(model_card_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
lines = content.splitlines()
|
||||
lines_yaml = []
|
||||
if len(lines) == 0:
|
||||
# Empty file
|
||||
return {}
|
||||
if len(lines) > 0 and lines[0] != "---":
|
||||
# No frontmatter
|
||||
return {}
|
||||
for line in lines[1:]:
|
||||
if line == "---":
|
||||
break # End of frontmatter
|
||||
else:
|
||||
lines_yaml.append(line)
|
||||
yaml_content = "\n".join(lines_yaml) + "\n"
|
||||
|
||||
# Quick hack to fix the Norway problem
|
||||
# https://hitchdev.com/strictyaml/why/implicit-typing-removed/
|
||||
yaml_content = yaml_content.replace("- no\n", "- \"no\"\n")
|
||||
# YAML should use 2 spaces instead of tabs
|
||||
# this issue came up with the Qwen/Qwen3-235B-A22B-Instruct-2507 model card
|
||||
# (I've also sent a PR to fix the model card)
|
||||
yaml_content = yaml_content.replace("\t", " ")
|
||||
|
||||
if yaml_content:
|
||||
data = yaml.safe_load(yaml_content)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
else:
|
||||
logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
|
||||
if model_path is None or not model_path.is_dir():
|
||||
return {}
|
||||
|
||||
config_path = model_path / "config.json"
|
||||
|
||||
if not config_path.is_file():
|
||||
return {}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
|
||||
if model_path is None or not model_path.is_dir():
|
||||
return {}
|
||||
|
||||
generation_config_path = model_path / "generation_config.json"
|
||||
|
||||
if not generation_config_path.is_file():
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(generation_config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, IOError):
|
||||
# not all models have valid generation_config.json
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def id_to_title(string):
|
||||
# Convert capitalization into title form unless acronym or version number
|
||||
return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
|
||||
|
||||
@staticmethod
|
||||
def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
|
||||
# Hugging Face often stores the model ID as '<org>/<model name>',
|
||||
# so let's parse it and apply some heuristics if possible for model name components
|
||||
|
||||
if model_id is None:
|
||||
# model ID missing
|
||||
return None, None, None, None, None, None
|
||||
|
||||
if ' ' in model_id:
|
||||
# model ID is actually a normal human sentence
|
||||
# which means it's most likely just a plain model name,
|
||||
# not part of the hugging face naming standard, but whatever
|
||||
return model_id, None, None, None, None, None
|
||||
|
||||
if '/' in model_id:
|
||||
# model ID (huggingface style)
|
||||
org_component, model_full_name_component = model_id.split('/', 1)
|
||||
else:
|
||||
# model ID but missing org components
|
||||
org_component, model_full_name_component = None, model_id
|
||||
|
||||
# Check if we erroneously matched against './' or '../' etc...
|
||||
if org_component is not None and len(org_component) > 0 and org_component[0] == '.':
|
||||
org_component = None
|
||||
|
||||
name_parts: list[str] = model_full_name_component.split('-')
|
||||
|
||||
# Remove empty parts
|
||||
for i in reversed(range(len(name_parts))):
|
||||
if len(name_parts[i]) == 0:
|
||||
del name_parts[i]
|
||||
|
||||
name_types: list[
|
||||
set[Literal["basename", "size_label", "finetune", "version", "type"]]
|
||||
] = [set() for _ in name_parts]
|
||||
|
||||
# Annotate the name
|
||||
for i, part in enumerate(name_parts):
|
||||
# Version
|
||||
if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
|
||||
name_types[i].add("version")
|
||||
# Quant type (should not be there for base models, but still annotated)
|
||||
elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
|
||||
name_types[i].add("type")
|
||||
name_parts[i] = part.upper()
|
||||
# Model size
|
||||
elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
|
||||
part = part.replace("_", ".")
|
||||
# Handle weird bloom-7b1 notation
|
||||
if part[-1].isdecimal():
|
||||
part = part[:-2] + "." + part[-1] + part[-2]
|
||||
# Normalize the size suffixes
|
||||
if len(part) > 1 and part[-2].isdecimal():
|
||||
if part[-1] in "kmbt":
|
||||
part = part[:-1] + part[-1].upper()
|
||||
if total_params != 0:
|
||||
try:
|
||||
label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
|
||||
# Only use it as a size label if it's close to or bigger than the model size
|
||||
# Note that LoRA adapters don't necessarily include all layers,
|
||||
# so this is why bigger label sizes are accepted.
|
||||
# Do not use the size label when it's smaller than 1/8 of the model size
|
||||
if (total_params < 0 and label_params < abs(total_params) // 8) or (
|
||||
# Check both directions when the current model isn't a LoRA adapter
|
||||
total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
|
||||
):
|
||||
# Likely a context length
|
||||
name_types[i].add("finetune")
|
||||
# Lowercase the size when it's a context length
|
||||
part = part[:-1] + part[-1].lower()
|
||||
except ValueError:
|
||||
# Failed to convert the size label to float, use it anyway
|
||||
pass
|
||||
if len(name_types[i]) == 0:
|
||||
name_types[i].add("size_label")
|
||||
name_parts[i] = part
|
||||
# Some easy to recognize finetune names
|
||||
elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
|
||||
if total_params < 0 and part.lower() == "lora":
|
||||
# ignore redundant "lora" in the finetune part when the output is a lora adapter
|
||||
name_types[i].add("type")
|
||||
else:
|
||||
name_types[i].add("finetune")
|
||||
|
||||
# Ignore word-based size labels when there is at least a number-based one present
|
||||
# TODO: should word-based size labels always be removed instead?
|
||||
if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
|
||||
for n, t in zip(name_parts, name_types):
|
||||
if "size_label" in t:
|
||||
if all(c.isalpha() for c in n):
|
||||
t.remove("size_label")
|
||||
|
||||
at_start = True
|
||||
# Find the basename through the annotated name
|
||||
for part, t in zip(name_parts, name_types):
|
||||
if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
|
||||
t.add("basename")
|
||||
else:
|
||||
if at_start:
|
||||
at_start = False
|
||||
if len(t) == 0:
|
||||
t.add("finetune")
|
||||
|
||||
# Remove the basename annotation from trailing version
|
||||
for part, t in zip(reversed(name_parts), reversed(name_types)):
|
||||
if "basename" in t and len(t) > 1:
|
||||
t.remove("basename")
|
||||
else:
|
||||
break
|
||||
|
||||
basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
|
||||
# Deduplicate size labels using an order-preserving 'dict' ('set' does not guarantee insertion order)
|
||||
size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
|
||||
finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
|
||||
# TODO: should the basename version always be excluded?
|
||||
# NOTE: multiple finetune versions are joined together
|
||||
version = "-".join(v for v, t in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None
|
||||
|
||||
if size_label is None and finetune is None and version is None:
|
||||
# Too ambiguous, output nothing
|
||||
basename = None
|
||||
|
||||
return model_full_name_component, org_component, basename, finetune, version, size_label
|
||||
|
||||
@staticmethod
|
||||
def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
|
||||
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
||||
|
||||
# Model Card Heuristics
|
||||
########################
|
||||
if model_card is not None:
|
||||
|
||||
def use_model_card_metadata(metadata_key: str, model_card_key: str):
|
||||
if model_card_key in model_card and getattr(metadata, metadata_key, None) is None:
|
||||
setattr(metadata, metadata_key, model_card.get(model_card_key))
|
||||
|
||||
def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
|
||||
# Note: appends to the existing value rather than replacing it
|
||||
tags_value = model_card.get(model_card_key, None)
|
||||
if tags_value is None:
|
||||
return
|
||||
|
||||
current_value = getattr(metadata, metadata_key, None)
|
||||
if current_value is None:
|
||||
current_value = []
|
||||
|
||||
if isinstance(tags_value, str):
|
||||
current_value.append(tags_value)
|
||||
elif isinstance(tags_value, list):
|
||||
current_value.extend(tags_value)
|
||||
|
||||
setattr(metadata, metadata_key, current_value)
|
||||
|
||||
# LLAMA.cpp's direct internal convention
|
||||
# (Definitely not part of hugging face formal/informal standard)
|
||||
#########################################
|
||||
use_model_card_metadata("name", "name")
|
||||
use_model_card_metadata("author", "author")
|
||||
use_model_card_metadata("version", "version")
|
||||
use_model_card_metadata("organization", "organization")
|
||||
use_model_card_metadata("description", "description")
|
||||
use_model_card_metadata("finetune", "finetune")
|
||||
use_model_card_metadata("basename", "basename")
|
||||
use_model_card_metadata("size_label", "size_label")
|
||||
use_model_card_metadata("source_url", "url")
|
||||
use_model_card_metadata("source_doi", "doi")
|
||||
use_model_card_metadata("source_uuid", "uuid")
|
||||
use_model_card_metadata("source_repo_url", "repo_url")
|
||||
|
||||
# LLAMA.cpp's huggingface style convention
|
||||
# (Definitely not part of hugging face formal/informal standard... but with model_ prepended to match their style)
|
||||
###########################################
|
||||
use_model_card_metadata("name", "model_name")
|
||||
use_model_card_metadata("author", "model_author")
|
||||
use_model_card_metadata("version", "model_version")
|
||||
use_model_card_metadata("organization", "model_organization")
|
||||
use_model_card_metadata("description", "model_description")
|
||||
use_model_card_metadata("finetune", "model_finetune")
|
||||
use_model_card_metadata("basename", "model_basename")
|
||||
use_model_card_metadata("size_label", "model_size_label")
|
||||
use_model_card_metadata("source_url", "model_url")
|
||||
use_model_card_metadata("source_doi", "model_doi")
|
||||
use_model_card_metadata("source_uuid", "model_uuid")
|
||||
use_model_card_metadata("source_repo_url", "model_repo_url")
|
||||
|
||||
# Hugging Face Direct Convention
|
||||
#################################
|
||||
|
||||
# Not part of the Hugging Face model card standard, but some model creators use it,
|
||||
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||
use_model_card_metadata("name", "model_name")
|
||||
use_model_card_metadata("author", "model_creator")
|
||||
use_model_card_metadata("basename", "model_type")
|
||||
|
||||
if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card:
|
||||
# This represents the parent models that this is based on
|
||||
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
|
||||
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
|
||||
metadata_base_models = []
|
||||
base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None)))
|
||||
|
||||
if base_model_value is not None:
|
||||
if isinstance(base_model_value, str):
|
||||
metadata_base_models.append(base_model_value)
|
||||
elif isinstance(base_model_value, list):
|
||||
metadata_base_models.extend(base_model_value)
|
||||
|
||||
if metadata.base_models is None:
|
||||
metadata.base_models = []
|
||||
|
||||
for model_id in metadata_base_models:
|
||||
# NOTE: model size of base model is assumed to be similar to the size of the current model
|
||||
base_model = {}
|
||||
if isinstance(model_id, str):
|
||||
if model_id.startswith(("http://", "https://", "ssh://")):
|
||||
base_model["repo_url"] = model_id
|
||||
|
||||
# Check if Hugging Face ID is present in URL
|
||||
if "huggingface.co" in model_id:
|
||||
match = re.match(r"https?://huggingface\.co/([^/]+/[^/]+)$", model_id)
|
||||
if match:
|
||||
model_id_component = match.group(1)
|
||||
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)
|
||||
|
||||
# Populate model dictionary with extracted components
|
||||
if model_full_name_component is not None:
|
||||
base_model["name"] = Metadata.id_to_title(model_full_name_component)
|
||||
if org_component is not None:
|
||||
base_model["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
base_model["version"] = version
|
||||
|
||||
else:
|
||||
# Likely a Hugging Face ID
|
||||
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
|
||||
|
||||
# Populate model dictionary with extracted components
|
||||
if model_full_name_component is not None:
|
||||
base_model["name"] = Metadata.id_to_title(model_full_name_component)
|
||||
if org_component is not None:
|
||||
base_model["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
base_model["version"] = version
|
||||
if org_component is not None and model_full_name_component is not None:
|
||||
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||
|
||||
elif isinstance(model_id, dict):
|
||||
base_model = model_id
|
||||
|
||||
else:
|
||||
logger.error(f"base model entry '{str(model_id)}' not in a known format")
|
||||
|
||||
metadata.base_models.append(base_model)
|
||||
|
||||
if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card:
|
||||
# This represents the datasets that this was trained from
|
||||
metadata_datasets = []
|
||||
dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None)))
|
||||
|
||||
if dataset_value is not None:
|
||||
if isinstance(dataset_value, str):
|
||||
metadata_datasets.append(dataset_value)
|
||||
elif isinstance(dataset_value, list):
|
||||
metadata_datasets.extend(dataset_value)
|
||||
|
||||
if metadata.datasets is None:
|
||||
metadata.datasets = []
|
||||
|
||||
for dataset_id in metadata_datasets:
|
||||
# NOTE: dataset IDs are parsed with the same name heuristics as model IDs, hence total_params is passed along
|
||||
dataset = {}
|
||||
if isinstance(dataset_id, str):
|
||||
if dataset_id.startswith(("http://", "https://", "ssh://")):
|
||||
dataset["repo_url"] = dataset_id
|
||||
|
||||
# Check if Hugging Face ID is present in URL
|
||||
if "huggingface.co" in dataset_id:
|
||||
match = re.match(r"https?://huggingface\.co/([^/]+/[^/]+)$", dataset_id)
|
||||
if match:
|
||||
dataset_id_component = match.group(1)
|
||||
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)
|
||||
|
||||
# Populate dataset dictionary with extracted components
|
||||
if dataset_name_component is not None:
|
||||
dataset["name"] = Metadata.id_to_title(dataset_name_component)
|
||||
if org_component is not None:
|
||||
dataset["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
dataset["version"] = version
|
||||
|
||||
else:
|
||||
# Likely a Hugging Face ID
|
||||
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
|
||||
|
||||
# Populate dataset dictionary with extracted components
|
||||
if dataset_name_component is not None:
|
||||
dataset["name"] = Metadata.id_to_title(dataset_name_component)
|
||||
if org_component is not None:
|
||||
dataset["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
dataset["version"] = version
|
||||
if org_component is not None and dataset_name_component is not None:
|
||||
dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
|
||||
|
||||
elif isinstance(dataset_id, dict):
|
||||
dataset = dataset_id
|
||||
|
||||
else:
|
||||
logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")
|
||||
|
||||
metadata.datasets.append(dataset)
|
||||
|
||||
use_model_card_metadata("license", "license")
|
||||
use_model_card_metadata("license_name", "license_name")
|
||||
use_model_card_metadata("license_link", "license_link")
|
||||
|
||||
use_array_model_card_metadata("tags", "tags")
|
||||
use_array_model_card_metadata("tags", "pipeline_tag")
|
||||
|
||||
use_array_model_card_metadata("languages", "languages")
|
||||
use_array_model_card_metadata("languages", "language")
|
||||
|
||||
# Hugging Face Parameter Heuristics
|
||||
####################################
|
||||
|
||||
if hf_params is not None:
|
||||
|
||||
hf_name_or_path = hf_params.get("_name_or_path")
|
||||
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
|
||||
# Use _name_or_path only if it's actually a model name and not a filesystem path
|
||||
# e.g. 'meta-llama/Llama-2-7b-hf'
|
||||
model_id = hf_name_or_path
|
||||
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
|
||||
if metadata.name is None and model_full_name_component is not None:
|
||||
metadata.name = Metadata.id_to_title(model_full_name_component)
|
||||
if metadata.organization is None and org_component is not None:
|
||||
metadata.organization = Metadata.id_to_title(org_component)
|
||||
if metadata.basename is None and basename is not None:
|
||||
metadata.basename = basename
|
||||
if metadata.finetune is None and finetune is not None:
|
||||
metadata.finetune = finetune
|
||||
if metadata.version is None and version is not None:
|
||||
metadata.version = version
|
||||
if metadata.size_label is None and size_label is not None:
|
||||
metadata.size_label = size_label
|
||||
|
||||
# Directory Folder Name Fallback Heuristics
|
||||
############################################
|
||||
if model_path is not None:
|
||||
model_id = model_path.name
|
||||
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
|
||||
if metadata.name is None and model_full_name_component is not None:
|
||||
metadata.name = Metadata.id_to_title(model_full_name_component)
|
||||
if metadata.organization is None and org_component is not None:
|
||||
metadata.organization = Metadata.id_to_title(org_component)
|
||||
if metadata.basename is None and basename is not None:
|
||||
metadata.basename = basename
|
||||
if metadata.finetune is None and finetune is not None:
|
||||
metadata.finetune = finetune
|
||||
if metadata.version is None and version is not None:
|
||||
metadata.version = version
|
||||
if metadata.size_label is None and size_label is not None:
|
||||
metadata.size_label = size_label
|
||||
|
||||
return metadata
|
||||
|
||||
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
|
||||
assert self.name is not None
|
||||
|
||||
if self.sampling_sequence is not None:
|
||||
gguf_writer.add_sampling_sequence(self.sampling_sequence)
|
||||
if self.sampling_top_k is not None:
|
||||
gguf_writer.add_sampling_top_k(self.sampling_top_k)
|
||||
if self.sampling_top_p is not None:
|
||||
gguf_writer.add_sampling_top_p(self.sampling_top_p)
|
||||
if self.sampling_min_p is not None:
|
||||
gguf_writer.add_sampling_min_p(self.sampling_min_p)
|
||||
if self.sampling_xtc_probability is not None:
|
||||
gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability)
|
||||
if self.sampling_xtc_threshold is not None:
|
||||
gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold)
|
||||
if self.sampling_temp is not None:
|
||||
gguf_writer.add_sampling_temp(self.sampling_temp)
|
||||
if self.sampling_penalty_last_n is not None:
|
||||
gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n)
|
||||
if self.sampling_penalty_repeat is not None:
|
||||
gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat)
|
||||
if self.sampling_mirostat is not None:
|
||||
gguf_writer.add_sampling_mirostat(self.sampling_mirostat)
|
||||
if self.sampling_mirostat_tau is not None:
|
||||
gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau)
|
||||
if self.sampling_mirostat_eta is not None:
|
||||
gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta)
|
||||
|
||||
gguf_writer.add_name(self.name)
|
||||
|
||||
if self.author is not None:
|
||||
gguf_writer.add_author(self.author)
|
||||
if self.version is not None:
|
||||
gguf_writer.add_version(self.version)
|
||||
if self.organization is not None:
|
||||
gguf_writer.add_organization(self.organization)
|
||||
|
||||
if self.finetune is not None:
|
||||
gguf_writer.add_finetune(self.finetune)
|
||||
if self.basename is not None:
|
||||
gguf_writer.add_basename(self.basename)
|
||||
|
||||
if self.description is not None:
|
||||
gguf_writer.add_description(self.description)
|
||||
if self.quantized_by is not None:
|
||||
gguf_writer.add_quantized_by(self.quantized_by)
|
||||
|
||||
if self.size_label is not None:
|
||||
gguf_writer.add_size_label(self.size_label)
|
||||
|
||||
if self.license is not None:
|
||||
if isinstance(self.license, list):
|
||||
gguf_writer.add_license(",".join(self.license))
|
||||
else:
|
||||
gguf_writer.add_license(self.license)
|
||||
if self.license_name is not None:
|
||||
gguf_writer.add_license_name(self.license_name)
|
||||
if self.license_link is not None:
|
||||
gguf_writer.add_license_link(self.license_link)
|
||||
|
||||
if self.url is not None:
|
||||
gguf_writer.add_url(self.url)
|
||||
if self.doi is not None:
|
||||
gguf_writer.add_doi(self.doi)
|
||||
if self.uuid is not None:
|
||||
gguf_writer.add_uuid(self.uuid)
|
||||
if self.repo_url is not None:
|
||||
gguf_writer.add_repo_url(self.repo_url)
|
||||
|
||||
if self.source_url is not None:
|
||||
gguf_writer.add_source_url(self.source_url)
|
||||
if self.source_doi is not None:
|
||||
gguf_writer.add_source_doi(self.source_doi)
|
||||
if self.source_uuid is not None:
|
||||
gguf_writer.add_source_uuid(self.source_uuid)
|
||||
if self.source_repo_url is not None:
|
||||
gguf_writer.add_source_repo_url(self.source_repo_url)
|
||||
|
||||
if self.base_models is not None:
|
||||
gguf_writer.add_base_model_count(len(self.base_models))
|
||||
for key, base_model_entry in enumerate(self.base_models):
|
||||
if "name" in base_model_entry:
|
||||
gguf_writer.add_base_model_name(key, base_model_entry["name"])
|
||||
if "author" in base_model_entry:
|
||||
gguf_writer.add_base_model_author(key, base_model_entry["author"])
|
||||
if "version" in base_model_entry:
|
||||
gguf_writer.add_base_model_version(key, base_model_entry["version"])
|
||||
if "organization" in base_model_entry:
|
||||
gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
|
||||
if "description" in base_model_entry:
|
||||
gguf_writer.add_base_model_description(key, base_model_entry["description"])
|
||||
if "url" in base_model_entry:
|
||||
gguf_writer.add_base_model_url(key, base_model_entry["url"])
|
||||
if "doi" in base_model_entry:
|
||||
gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
|
||||
if "uuid" in base_model_entry:
|
||||
gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
|
||||
if "repo_url" in base_model_entry:
|
||||
gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
|
||||
|
||||
if self.datasets is not None:
|
||||
gguf_writer.add_dataset_count(len(self.datasets))
|
||||
for key, dataset_entry in enumerate(self.datasets):
|
||||
if "name" in dataset_entry:
|
||||
gguf_writer.add_dataset_name(key, dataset_entry["name"])
|
||||
if "author" in dataset_entry:
|
||||
gguf_writer.add_dataset_author(key, dataset_entry["author"])
|
||||
if "version" in dataset_entry:
|
||||
gguf_writer.add_dataset_version(key, dataset_entry["version"])
|
||||
if "organization" in dataset_entry:
|
||||
gguf_writer.add_dataset_organization(key, dataset_entry["organization"])
|
||||
if "description" in dataset_entry:
|
||||
gguf_writer.add_dataset_description(key, dataset_entry["description"])
|
||||
if "url" in dataset_entry:
|
||||
gguf_writer.add_dataset_url(key, dataset_entry["url"])
|
||||
if "doi" in dataset_entry:
|
||||
gguf_writer.add_dataset_doi(key, dataset_entry["doi"])
|
||||
if "uuid" in dataset_entry:
|
||||
gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"])
|
||||
if "repo_url" in dataset_entry:
|
||||
gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"])
|
||||
|
||||
if self.tags is not None:
|
||||
gguf_writer.add_tags(self.tags)
|
||||
if self.languages is not None:
|
||||
gguf_writer.add_languages(self.languages)
|
||||
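The method above follows one pattern throughout: the name is mandatory, every other key is written only when its attribute is set, and a list-valued license is flattened into a single comma-separated string. A minimal sketch of that pattern, assuming a hypothetical `RecordingWriter` stand-in (it is not part of the gguf package and only records what it is given):

```python
from __future__ import annotations


class RecordingWriter:
    """Hypothetical stand-in for a GGUF writer; it only records key/value pairs."""

    def __init__(self) -> None:
        self.kv = {}

    def add(self, key: str, value: str) -> None:
        self.kv[key] = value


def write_optional(writer: RecordingWriter, name, license, author) -> None:
    # The name is mandatory, mirroring the assert at the top of the method.
    assert name is not None
    writer.add("name", name)
    # Optional values are skipped entirely when they are unset.
    if author is not None:
        writer.add("author", author)
    if license is not None:
        # A list of licenses is flattened into one comma-separated string.
        if isinstance(license, list):
            writer.add("license", ",".join(license))
        else:
            writer.add("license", license)


writer = RecordingWriter()
write_optional(writer, "Example Model", ["mit", "apache-2.0"], None)
```

Because unset attributes never reach the writer, the resulting key/value store contains only keys that carry real information.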
0
gguf-py/gguf/py.typed
Normal file
1318
gguf-py/gguf/quants.py
Normal file
File diff suppressed because it is too large
186
gguf-py/gguf/scripts/gguf_convert_endian.py
Executable file
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
from __future__ import annotations

import logging
import argparse
import os
import sys
from tqdm import tqdm
from pathlib import Path

import numpy as np

# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import gguf

logger = logging.getLogger("gguf-convert-endian")


def byteswap_noop(tensor, block_offs):
    # Used for block types that contain only single-byte fields, so no byteswapping is needed
    pass


def byteswap_q4_0(tensor, block_offs):
    # Each block_q4_0 consists of an f16 delta (scaling factor) followed by 16 bytes that pack 32 4-bit quantizations.

    # Byte-Swap f16 sized delta field
    delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
    delta.byteswap(inplace=True)


def byteswap_q8_0(tensor, block_offs):
    # Each block_q8_0 consists of an f16 delta (scaling factor) followed by 32 int8 quantizations.

    # Byte-Swap f16 sized delta field
    delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
    delta.byteswap(inplace=True)


def byteswap_q4_k(tensor, block_offs):
    # Each block_q4_k consists of 2 f16 values followed by 140 int8 values.

    # Byte-Swap f16 sized fields
    delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
    delta.byteswap(inplace=True)

    delta = tensor.data[block_offs + 2:block_offs + 4].view(dtype=np.uint16)
    delta.byteswap(inplace=True)


def byteswap_q6_k(tensor, block_offs):
    # Each block_q6_k consists of 208 int8 values followed by 1 f16 value.

    # Byte-Swap f16 sized field
    delta = tensor.data[block_offs + 208:block_offs + 210].view(dtype=np.uint16)
    delta.byteswap(inplace=True)


byteswap_tensors = {
    gguf.GGMLQuantizationType.Q4_0: byteswap_q4_0,
    gguf.GGMLQuantizationType.Q8_0: byteswap_q8_0,
    gguf.GGMLQuantizationType.Q4_K: byteswap_q4_k,
    gguf.GGMLQuantizationType.Q6_K: byteswap_q6_k,
    gguf.GGMLQuantizationType.MXFP4: byteswap_noop,
}


def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None:
    file_endian = reader.endianess.name
    if reader.byte_order == 'S':
        host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE'
    else:
        host_endian = file_endian
    order = host_endian if args.order == "native" else args.order.upper()
    logger.info(f"* Host is {host_endian} endian, GGUF file seems to be {file_endian} endian")
    if file_endian == order:
        logger.info(f"* File is already {order} endian. Nothing to do.")
        sys.exit(0)
    logger.info("* Checking tensors for conversion compatibility")
    for tensor in reader.tensors:
        if tensor.tensor_type not in byteswap_tensors and \
           tensor.tensor_type not in (
                gguf.GGMLQuantizationType.F32,
                gguf.GGMLQuantizationType.F16,
                gguf.GGMLQuantizationType.BF16,
           ):
            raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}")
    logger.info(f"* Preparing to convert from {file_endian} to {order}")
    if args.dry_run:
        return
    logger.warning("*** Warning *** Warning *** Warning **")
    logger.warning("* This conversion process may damage the file. Ensure you have a backup.")
    if order != host_endian:
        logger.warning("* Requested endian differs from host, you will not be able to load the model on this machine.")
    logger.warning("* The file will be modified immediately, so if conversion fails or is interrupted")
    logger.warning("* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:")
    response = input("YES, I am sure> ")
    if response != "YES":
        logger.warning("You didn't enter YES. Okay then, see ya!")
        sys.exit(0)
    logger.info(f"* Converting fields ({len(reader.fields)})")
    for idx, field in enumerate(reader.fields.values()):
        logger.info(f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}")
        for part in field.parts:
            part.byteswap(inplace=True)
    logger.info(f"* Converting tensors ({len(reader.tensors)})")

    for idx, tensor in enumerate(pbar := tqdm(reader.tensors, desc="Converting tensor")):
        log_message = (
            f"Converting tensor {repr(tensor.name)}, "
            f"type={tensor.tensor_type.name}, "
            f"elements={tensor.n_elements} "
        )

        # Byte-swap each part of the tensor's field
        for part in tensor.field.parts:
            part.byteswap(inplace=True)

        # Byte-swap tensor data if necessary
        if tensor.tensor_type in byteswap_tensors:
            # first flatten structure
            oldshape = tensor.data.shape
            newshape = 1
            for i in tensor.data.shape:
                newshape *= i

            tensor.data.resize(newshape)

            block_size = gguf.constants.GGML_QUANT_SIZES[tensor.tensor_type][1]
            byteswap_func = byteswap_tensors[tensor.tensor_type]

            n_blocks = len(tensor.data) // block_size
            for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)):
                block_offs = block_num * block_size

                byteswap_func(tensor, block_offs)

                if block_num % 100000 == 0:
                    inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]")

            # restore old shape in case it's ever used
            tensor.data.resize(oldshape)
        elif tensor.tensor_type == gguf.GGMLQuantizationType.BF16:
            # Special case for BF16
            # It is 2-bytes data, but by default view loads it as 1-byte data.
            # Change to correct view before byteswapping.
            tensor.data.view(dtype=np.uint16).byteswap(inplace=True)
        else:
            # Handle other tensor types
            tensor.data.byteswap(inplace=True)

        pbar.set_description(log_message)

    logger.info("* Completion")


def main() -> None:
    parser = argparse.ArgumentParser(description="Convert GGUF file byte order")
    parser.add_argument(
        "model", type=str,
        help="GGUF format model filename",
    )
    parser.add_argument(
        "order", type=str, choices=['big', 'little', 'native'],
        help="Requested byte order",
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Don't actually change anything",
    )
    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")

    args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    logger.info(f'* Loading: {args.model}')
    reader = gguf.GGUFReader(args.model, 'r' if args.dry_run else 'r+')
    convert_byteorder(reader, args)


if __name__ == "__main__":
    main()
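The per-block helpers in the script above swap only the multi-byte fields of each quantized block and leave the single-byte payload untouched. The following is a minimal standard-library sketch of that idea for a Q8_0-style layout (a 2-byte f16 scale followed by 32 one-byte values, 34 bytes per block). The function name `byteswap_f16_headers` and the sample buffer are illustrative assumptions, not part of the gguf package, and the sketch works on a plain `bytearray` rather than a NumPy view.

```python
import struct

# Assumed layout for illustration: 2-byte f16 scale + 32 one-byte quantized values.
Q8_0_BLOCK_SIZE = 34


def byteswap_f16_headers(buf: bytearray, block_size: int = Q8_0_BLOCK_SIZE) -> None:
    """Swap the two bytes of the leading f16 field of every block, in place."""
    if len(buf) % block_size != 0:
        raise ValueError("buffer length is not a multiple of the block size")
    for offs in range(0, len(buf), block_size):
        # Only the 2-byte scale is endian-sensitive; the payload bytes stay as they are.
        buf[offs], buf[offs + 1] = buf[offs + 1], buf[offs]


# Build two little-endian blocks with scales 1.0 and 0.5 and payload bytes 0..31.
payload = bytes(range(32))
data = bytearray(struct.pack("<e", 1.0) + payload + struct.pack("<e", 0.5) + payload)
byteswap_f16_headers(data)

# After the swap, the scales decode correctly when read as big-endian f16.
scales = [struct.unpack(">e", bytes(data[o:o + 2]))[0] for o in (0, Q8_0_BLOCK_SIZE)]
```

Swapping field by field, rather than byte-swapping the whole buffer, is what keeps the one-byte quantized values in their original order.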
477
gguf-py/gguf/scripts/gguf_dump.py
Executable file
@@ -0,0 +1,477 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Necessary to load the local gguf package
|
||||
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from gguf import GGUFReader, GGUFValueType, ReaderTensor # noqa: E402
|
||||
|
||||
logger = logging.getLogger("gguf-dump")
|
||||
|
||||
|
||||
def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:
|
||||
file_endian = reader.endianess.name
|
||||
if reader.byte_order == 'S':
|
||||
host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE'
|
||||
else:
|
||||
host_endian = file_endian
|
||||
return (host_endian, file_endian)
|
||||
|
||||
|
||||
# For more information about what field.parts and field.data represent,
|
||||
# please see the comments in the modify_gguf.py example.
|
||||
def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
|
||||
host_endian, file_endian = get_file_host_endian(reader)
|
||||
print(f'* File is {file_endian} endian, script is running on a {host_endian} endian host.') # noqa: NP100
|
||||
print(f'* Dumping {len(reader.fields)} key/value pair(s)') # noqa: NP100
|
||||
for n, field in enumerate(reader.fields.values(), 1):
|
||||
if not field.types:
|
||||
pretty_type = 'N/A'
|
||||
elif field.types[0] == GGUFValueType.ARRAY:
|
||||
nest_count = len(field.types) - 1
|
||||
pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count
|
||||
else:
|
||||
pretty_type = str(field.types[-1].name)
|
||||
|
||||
log_message = f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}'
|
||||
if field.types:
|
||||
curr_type = field.types[0]
|
||||
if curr_type == GGUFValueType.STRING:
|
||||
content = field.contents()
|
||||
if len(content) > 60:
|
||||
content = content[:57] + '...'
|
||||
log_message += ' = {0}'.format(repr(content))
|
||||
elif curr_type in reader.gguf_scalar_to_np:
|
||||
log_message += ' = {0}'.format(field.contents())
|
||||
else:
|
||||
content = repr(field.contents(slice(6)))
|
||||
if len(field.data) > 6:
|
||||
content = content[:-1] + ', ...]'
|
||||
log_message += ' = {0}'.format(content)
|
||||
print(log_message) # noqa: NP100
|
||||
if args.no_tensors:
|
||||
return
|
||||
print(f'* Dumping {len(reader.tensors)} tensor(s)') # noqa: NP100
|
||||
for n, tensor in enumerate(reader.tensors, 1):
|
||||
prettydims = ', '.join('{0:5}'.format(d) for d in list(tensor.shape) + [1] * (4 - len(tensor.shape)))
|
||||
print(f' {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}') # noqa: NP100
|
||||
|
||||
|
||||
def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None:
|
||||
import json
|
||||
host_endian, file_endian = get_file_host_endian(reader)
|
||||
metadata: dict[str, Any] = {}
|
||||
tensors: dict[str, Any] = {}
|
||||
result = {
|
||||
"filename": args.model,
|
||||
"endian": file_endian,
|
||||
"metadata": metadata,
|
||||
"tensors": tensors,
|
||||
}
|
||||
for idx, field in enumerate(reader.fields.values()):
|
||||
curr: dict[str, Any] = {
|
||||
"index": idx,
|
||||
"type": field.types[0].name if field.types else 'UNKNOWN',
|
||||
"offset": field.offset,
|
||||
}
|
||||
metadata[field.name] = curr
|
||||
if field.types[:1] == [GGUFValueType.ARRAY]:
|
||||
curr["array_types"] = [t.name for t in field.types][1:]
|
||||
if not args.json_array:
|
||||
continue
|
||||
curr["value"] = field.contents()
|
||||
else:
|
||||
curr["value"] = field.contents()
|
||||
if not args.no_tensors:
|
||||
for idx, tensor in enumerate(reader.tensors):
|
||||
tensors[tensor.name] = {
|
||||
"index": idx,
|
||||
"shape": tensor.shape.tolist(),
|
||||
"type": tensor.tensor_type.name,
|
||||
"offset": tensor.field.offset,
|
||||
}
|
||||
json.dump(result, sys.stdout)
|
||||
|
||||
|
||||
def markdown_table_with_alignment_support(header_map: list[dict[str, str]], data: list[dict[str, Any]]):
|
||||
# JSON to Markdown table formatting: https://stackoverflow.com/a/72983854/2850957
|
||||
|
||||
# Alignment Utility Function
|
||||
def strAlign(padding: int, alignMode: str | None, strVal: str):
|
||||
if alignMode == 'center':
|
||||
return strVal.center(padding)
|
||||
elif alignMode == 'right':
|
||||
return strVal.rjust(padding - 1) + ' '
|
||||
elif alignMode == 'left':
|
||||
return ' ' + strVal.ljust(padding - 1)
|
||||
else: # default left
|
||||
return ' ' + strVal.ljust(padding - 1)
|
||||
|
||||
def dashAlign(padding: int, alignMode: str | None):
|
||||
if alignMode == 'center':
|
||||
return ':' + '-' * (padding - 2) + ':'
|
||||
elif alignMode == 'right':
|
||||
return '-' * (padding - 1) + ':'
|
||||
elif alignMode == 'left':
|
||||
return ':' + '-' * (padding - 1)
|
||||
else: # default left
|
||||
return '-' * (padding)
|
||||
|
||||
# Calculate Padding For Each Column Based On Header and Data Length
|
||||
rowsPadding = {}
|
||||
for index, columnEntry in enumerate(header_map):
|
||||
padCount = max([len(str(v)) for d in data for k, v in d.items() if k == columnEntry['key_name']], default=0) + 2
|
||||
headerPadCount = len(columnEntry['header_name']) + 2
|
||||
rowsPadding[index] = headerPadCount if padCount <= headerPadCount else padCount
|
||||
|
||||
# Render Markdown Header
|
||||
rows = []
|
||||
rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(columnEntry['header_name'])) for index, columnEntry in enumerate(header_map)))
|
||||
rows.append('|'.join(dashAlign(rowsPadding[index], columnEntry.get('align')) for index, columnEntry in enumerate(header_map)))
|
||||
|
||||
# Render Tabular Data
|
||||
for item in data:
|
||||
rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(item[columnEntry['key_name']])) for index, columnEntry in enumerate(header_map)))
|
||||
|
||||
# Convert Tabular String Rows Into String
|
||||
tableString = ""
|
||||
for row in rows:
|
||||
tableString += f'|{row}|\n'
|
||||
|
||||
return tableString
|
||||
|
||||
|
||||
def element_count_rounded_notation(count: int) -> str:
    if count > 1e15:
        # Quadrillion
        scaled_amount = count * 1e-15
        scale_suffix = "Q"
    elif count > 1e12:
        # Trillions
        scaled_amount = count * 1e-12
        scale_suffix = "T"
    elif count > 1e9:
        # Billions
        scaled_amount = count * 1e-9
        scale_suffix = "B"
    elif count > 1e6:
        # Millions
        scaled_amount = count * 1e-6
        scale_suffix = "M"
    elif count > 1e3:
        # Thousands
        scaled_amount = count * 1e-3
        scale_suffix = "K"
    else:
        # Under Thousands
        scaled_amount = count
        scale_suffix = ""
    return f"{'~' if count > 1e3 else ''}{round(scaled_amount)}{scale_suffix}"
|
||||
|
||||
|
||||
def translate_tensor_name(name):
|
||||
words = name.split(".")
|
||||
|
||||
# Source: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#standardized-tensor-names
|
||||
abbreviation_dictionary = {
|
||||
'token_embd': 'Token embedding',
|
||||
'pos_embd': 'Position embedding',
|
||||
'output_norm': 'Output normalization',
|
||||
'output': 'Output',
|
||||
'attn_norm': 'Attention normalization',
|
||||
'attn_norm_2': 'Attention normalization',
|
||||
'attn_qkv': 'Attention query-key-value',
|
||||
'attn_q': 'Attention query',
|
||||
'attn_k': 'Attention key',
|
||||
'attn_v': 'Attention value',
|
||||
'attn_output': 'Attention output',
|
||||
'ffn_norm': 'Feed-forward network normalization',
|
||||
'ffn_up': 'Feed-forward network "up"',
|
||||
'ffn_gate': 'Feed-forward network "gate"',
|
||||
'ffn_down': 'Feed-forward network "down"',
|
||||
'ffn_gate_inp': 'Expert-routing layer for the Feed-forward network in Mixture of Expert models',
|
||||
'ffn_gate_exp': 'Feed-forward network "gate" layer per expert in Mixture of Expert models',
|
||||
'ffn_down_exp': 'Feed-forward network "down" layer per expert in Mixture of Expert models',
|
||||
'ffn_up_exp': 'Feed-forward network "up" layer per expert in Mixture of Expert models',
|
||||
'ssm_in': 'State space model input projections',
|
||||
'ssm_conv1d': 'State space model rolling/shift',
|
||||
'ssm_x': 'State space model selective parametrization',
|
||||
'ssm_a': 'State space model state compression',
|
||||
'ssm_d': 'State space model skip connection',
|
||||
'ssm_dt': 'State space model time step',
|
||||
'ssm_out': 'State space model output projection',
|
||||
'blk': 'Block',
|
||||
'enc': 'Encoder',
|
||||
'dec': 'Decoder',
|
||||
}
|
||||
|
||||
expanded_words = []
|
||||
for word in words:
|
||||
word_norm = word.strip().lower()
|
||||
if word_norm in abbreviation_dictionary:
|
||||
expanded_words.append(abbreviation_dictionary[word_norm].title())
|
||||
else:
|
||||
expanded_words.append(word.title())
|
||||
|
||||
return ' '.join(expanded_words)
|
||||
|
||||
|
||||
def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
|
||||
host_endian, file_endian = get_file_host_endian(reader)
|
||||
markdown_content = ""
|
||||
markdown_content += f'# {args.model} - GGUF Internal File Dump\n\n'
|
||||
markdown_content += f'- Endian: {file_endian} endian\n'
|
||||
markdown_content += '\n'
|
||||
markdown_content += '## Key Value Metadata Store\n\n'
|
||||
markdown_content += f'There are {len(reader.fields)} key-value pairs in this file\n'
|
||||
markdown_content += '\n'
|
||||
total_model_bytes = 0
|
||||
total_model_elements = 0
|
||||
|
||||
kv_dump_table: list[dict[str, str | int]] = []
|
||||
for n, field in enumerate(reader.fields.values(), 1):
|
||||
if not field.types:
|
||||
pretty_type = 'N/A'
|
||||
elif field.types[0] == GGUFValueType.ARRAY:
|
||||
nest_count = len(field.types) - 1
|
||||
pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count
|
||||
else:
|
||||
pretty_type = str(field.types[-1].name)
|
||||
|
||||
def escape_markdown_inline_code(value_string):
|
||||
# Find the longest contiguous sequence of backticks in the string then
|
||||
# wrap string with appropriate number of backticks required to escape it
|
||||
max_backticks = max((len(match.group(0)) for match in re.finditer(r'`+', value_string)), default=0)
|
||||
inline_code_marker = '`' * (max_backticks + 1)
|
||||
|
||||
# If the string starts or ends with a backtick, add a space at the beginning and end
|
||||
if value_string.startswith('`') or value_string.endswith('`'):
|
||||
value_string = f" {value_string} "
|
||||
|
||||
return f"{inline_code_marker}{value_string}{inline_code_marker}"
|
||||
|
||||
total_elements = len(field.data)
|
||||
value = ""
|
||||
if len(field.types) == 1:
|
||||
curr_type = field.types[0]
|
||||
if curr_type == GGUFValueType.STRING:
|
||||
truncate_length = 60
|
||||
value_string = str(bytes(field.parts[-1]), encoding='utf-8')
|
||||
if len(value_string) > truncate_length:
|
||||
head = escape_markdown_inline_code(value_string[:truncate_length // 2])
|
||||
tail = escape_markdown_inline_code(value_string[-truncate_length // 2:])
|
||||
value = "{head}...{tail}".format(head=head, tail=tail)
|
||||
else:
|
||||
value = escape_markdown_inline_code(value_string)
|
||||
elif curr_type in reader.gguf_scalar_to_np:
|
||||
value = str(field.parts[-1][0])
|
||||
else:
|
||||
if field.types[0] == GGUFValueType.ARRAY:
|
||||
curr_type = field.types[1]
|
||||
array_elements = []
|
||||
|
||||
if curr_type == GGUFValueType.STRING:
|
||||
render_element = min(5, total_elements)
|
||||
for element_pos in range(render_element):
|
||||
truncate_length = 30
|
||||
value_string = str(bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), encoding='utf-8')
|
||||
if len(value_string) > truncate_length:
|
||||
head = escape_markdown_inline_code(value_string[:truncate_length // 2])
|
||||
tail = escape_markdown_inline_code(value_string[-truncate_length // 2:])
|
||||
value = "{head}...{tail}".format(head=head, tail=tail)
|
||||
else:
|
||||
value = escape_markdown_inline_code(value_string)
|
||||
array_elements.append(value)
|
||||
|
||||
elif curr_type in reader.gguf_scalar_to_np:
|
||||
render_element = min(7, total_elements)
|
||||
for element_pos in range(render_element):
|
||||
array_elements.append(str(field.parts[-1 - (total_elements - element_pos - 1)][0]))
|
||||
|
||||
value = f'[ {", ".join(array_elements).strip()}{", ..." if total_elements > len(array_elements) else ""} ]'
|
||||
|
||||
kv_dump_table.append({"n":n, "pretty_type":pretty_type, "total_elements":total_elements, "field_name":field.name, "value":value})
|
||||
|
||||
kv_dump_table_header_map = [
|
||||
{'key_name':'n', 'header_name':'POS', 'align':'right'},
|
||||
{'key_name':'pretty_type', 'header_name':'TYPE', 'align':'left'},
|
||||
{'key_name':'total_elements', 'header_name':'Count', 'align':'right'},
|
||||
{'key_name':'field_name', 'header_name':'Key', 'align':'left'},
|
||||
{'key_name':'value', 'header_name':'Value', 'align':'left'},
|
||||
]
|
||||
|
||||
markdown_content += markdown_table_with_alignment_support(kv_dump_table_header_map, kv_dump_table)
|
||||
|
||||
markdown_content += "\n"
|
||||
|
||||
if not args.no_tensors:
|
||||
# Group tensors by their prefix and maintain order
|
||||
tensor_prefix_order: list[str] = []
|
||||
tensor_name_to_key: dict[str, int] = {}
|
||||
tensor_groups: dict[str, list[ReaderTensor]] = {}
|
||||
total_elements = sum(tensor.n_elements for tensor in reader.tensors)
|
||||
|
||||
# Parsing Tensors Record
|
||||
for key, tensor in enumerate(reader.tensors):
|
||||
tensor_components = tensor.name.split('.')
|
||||
|
||||
# Classify Tensor Group
|
||||
tensor_group_name = "base"
|
||||
if tensor_components[0] == 'blk':
|
||||
tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}"
|
||||
elif tensor_components[0] in ['enc', 'dec'] and tensor_components[1] == 'blk':
|
||||
tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}.{tensor_components[2]}"
|
||||
elif tensor_components[0] in ['enc', 'dec']:
|
||||
tensor_group_name = f"{tensor_components[0]}"
|
||||
|
||||
# Check if new Tensor Group
|
||||
if tensor_group_name not in tensor_groups:
|
||||
tensor_groups[tensor_group_name] = []
|
||||
tensor_prefix_order.append(tensor_group_name)
|
||||
|
||||
# Record Tensor and Tensor Position
|
||||
tensor_groups[tensor_group_name].append(tensor)
|
||||
tensor_name_to_key[tensor.name] = key
|
||||
|
||||
# Tensors Mapping Dump
|
||||
markdown_content += f'## Tensors Overview {element_count_rounded_notation(total_elements)} Elements\n\n'
|
||||
markdown_content += f'Total number of elements in all tensors: {total_elements} Elements\n'
|
||||
markdown_content += '\n'
|
||||
|
||||
for group in tensor_prefix_order:
|
||||
tensors = tensor_groups[group]
|
||||
group_elements = sum(tensor.n_elements for tensor in tensors)
|
||||
markdown_content += f"- [{translate_tensor_name(group)} Tensor Group - {element_count_rounded_notation(group_elements)} Elements](#{group.replace('.', '_')})\n"
|
||||
|
||||
markdown_content += "\n"
|
||||
|
||||
markdown_content += "### Tensor Data Offset\n"
|
||||
markdown_content += '\n'
|
||||
markdown_content += 'This table contains the offset and data segment relative to start of file\n'
|
||||
markdown_content += '\n'
|
||||
|
||||
tensor_mapping_table: list[dict[str, str | int]] = []
|
||||
for key, tensor in enumerate(reader.tensors):
|
||||
data_offset_pretty = '{0:#16x}'.format(tensor.data_offset)
|
||||
data_size_pretty = '{0:#16x}'.format(tensor.n_bytes)
|
||||
tensor_mapping_table.append({"t_id":key, "layer_name":tensor.name, "data_offset":data_offset_pretty, "data_size":data_size_pretty})
|
||||
|
||||
tensors_mapping_table_header_map = [
|
||||
{'key_name':'t_id', 'header_name':'T_ID', 'align':'right'},
|
||||
{'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'},
|
||||
{'key_name':'data_offset', 'header_name':'Data Offset (B)', 'align':'right'},
|
||||
{'key_name':'data_size', 'header_name':'Data Size (B)', 'align':'right'},
|
||||
]
|
||||
|
||||
markdown_content += markdown_table_with_alignment_support(tensors_mapping_table_header_map, tensor_mapping_table)
|
||||
markdown_content += "\n"
|
||||
|
||||
for group in tensor_prefix_order:
|
||||
tensors = tensor_groups[group]
|
||||
group_elements = sum(tensor.n_elements for tensor in tensors)
|
||||
group_percentage = group_elements / total_elements * 100
|
||||
total_group_bytes = 0
|
||||
total_group_elements = 0
|
||||
markdown_content += f"### <a name=\"{group.replace('.', '_')}\">{translate_tensor_name(group)} Tensor Group : {element_count_rounded_notation(group_elements)} Elements</a>\n\n"
|
||||
|
||||
# Precalculate column sizing for visual consistency
|
||||
prettify_element_est_count_size: int = 1
|
||||
prettify_element_count_size: int = 1
|
||||
prettify_dimension_max_widths: dict[int, int] = {}
|
||||
for tensor in tensors:
|
||||
prettify_element_est_count_size = max(prettify_element_est_count_size, len(str(element_count_rounded_notation(tensor.n_elements))))
|
||||
prettify_element_count_size = max(prettify_element_count_size, len(str(tensor.n_elements)))
|
||||
for i, dimension_size in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape))):
|
||||
prettify_dimension_max_widths[i] = max(prettify_dimension_max_widths.get(i,1), len(str(dimension_size)))
|
||||
|
||||
# Generate Tensor Layer Table Content
|
||||
tensor_dump_table: list[dict[str, str | int]] = []
|
||||
for tensor in tensors:
|
||||
human_friendly_name = translate_tensor_name(tensor.name.replace(".weight", ".(W)").replace(".bias", ".(B)"))
|
||||
pretty_dimension = ' x '.join(f'{str(d):>{prettify_dimension_max_widths[i]}}' for i, d in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape))))
|
||||
element_count_est = f"({element_count_rounded_notation(tensor.n_elements):>{prettify_element_est_count_size}})"
|
||||
element_count_string = f"{element_count_est} {tensor.n_elements:>{prettify_element_count_size}}"
|
||||
type_name_string = f"{tensor.tensor_type.name}"
|
||||
if tensor.n_elements > 0:
|
||||
bpw = (tensor.n_bytes * 8) / tensor.n_elements
|
||||
else:
|
||||
bpw = float('nan')
|
||||
tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string, "bpw": f"{bpw:.4f}"})
|
||||
total_group_bytes += tensor.n_bytes
|
||||
total_group_elements += tensor.n_elements
|
||||
|
||||
tensor_dump_table_header_map = [
|
||||
{'key_name':'t_id', 'header_name':'T_ID', 'align':'right'},
|
||||
{'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'},
|
||||
{'key_name':'human_layer_name', 'header_name':'Human Friendly Tensor Layer Name', 'align':'left'},
|
||||
{'key_name':'element_count', 'header_name':'Elements', 'align':'left'},
|
||||
{'key_name':'pretty_dimension', 'header_name':'Shape', 'align':'left'},
|
||||
{'key_name':'tensor_type', 'header_name':'Type', 'align':'left'},
|
||||
{'key_name':'bpw', 'header_name':'BPW', 'align':'right'},
|
||||
]
|
||||
|
||||
markdown_content += markdown_table_with_alignment_support(tensor_dump_table_header_map, tensor_dump_table)
|
||||
|
||||
markdown_content += "\n"
|
||||
markdown_content += f"- Total elements in {group}: ({element_count_rounded_notation(group_elements):>4}) {group_elements}\n"
|
||||
markdown_content += f"- Percentage of total elements: {group_percentage:.2f}%\n"
|
||||
if total_group_elements > 0:
|
||||
total_group_bpw = (total_group_bytes * 8) / total_group_elements
|
||||
markdown_content += f"- Bits per Weight (BPW) for {group}: {total_group_bpw:.4f} bits\n"
|
||||
else:
|
||||
markdown_content += f"- Bits per Weight (BPW) for {group}: undefined (no elements)\n"
|
||||
markdown_content += "\n\n"
|
||||
total_model_bytes += total_group_bytes
|
||||
total_model_elements += total_group_elements
|
||||
|
||||
if total_model_elements > 0:
|
||||
total_model_bpw = (total_model_bytes * 8) / total_model_elements
|
||||
markdown_content += f"Total BPW for {os.path.basename(args.model)}: {total_model_bpw:.4f} bits"
|
||||
else:
|
||||
markdown_content += f"Total BPW for {os.path.basename(args.model)}: undefined (no elements)"
|
||||
print(markdown_content) # noqa: NP100
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Dump GGUF file metadata")
|
||||
parser.add_argument("model", type=str, help="GGUF format model filename")
|
||||
parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata")
|
||||
parser.add_argument("--json", action="store_true", help="Produce JSON output")
|
||||
parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)")
|
||||
parser.add_argument("--data-offset", action="store_true", help="Start of data offset")
|
||||
parser.add_argument("--data-alignment", action="store_true", help="Data alignment applied globally to data field")
|
||||
parser.add_argument("--markdown", action="store_true", help="Produce markdown output")
|
||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||
|
||||
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
|
||||
if not args.json and not args.markdown and not args.data_offset and not args.data_alignment:
|
||||
logger.info(f'* Loading: {args.model}')
|
||||
|
||||
reader = GGUFReader(args.model, 'r')
|
||||
|
||||
if args.json:
|
||||
dump_metadata_json(reader, args)
|
||||
elif args.markdown:
|
||||
dump_markdown_metadata(reader, args)
|
||||
elif args.data_offset:
|
||||
print(reader.data_offset) # noqa: NP100
|
||||
elif args.data_alignment:
|
||||
print(reader.alignment) # noqa: NP100
|
||||
else:
|
||||
dump_metadata(reader, args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
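The bits-per-weight (BPW) figures reported by the dump script above are total bytes times eight divided by the element count, with a guard for empty tensors or groups. A minimal standalone sketch follows; the helper name `bits_per_weight` and the sample tensor sizes are illustrative assumptions, not part of the script.

```python
import math


def bits_per_weight(n_bytes: int, n_elements: int) -> float:
    """Average storage cost per element in bits; NaN when there are no elements."""
    if n_elements <= 0:
        return math.nan
    return (n_bytes * 8) / n_elements


# Hypothetical tensor: 4096 x 4096 elements stored as 16-bit floats (2 bytes each).
f16_elements = 4096 * 4096
f16_bpw = bits_per_weight(f16_elements * 2, f16_elements)

# Hypothetical block-quantized tensor: 18 bytes for every block of 32 elements.
quant_bpw = bits_per_weight(18 * 1000, 32 * 1000)

print(f"{f16_bpw:.4f}")
print(f"{quant_bpw:.4f}")
```

Because block-quantized formats store per-block scale data next to the weights, their BPW is usually a little higher than the nominal bit width.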
1621
gguf-py/gguf/scripts/gguf_editor_gui.py
Executable file
File diff suppressed because it is too large
102
gguf-py/gguf/scripts/gguf_hash.py
Executable file
@@ -0,0 +1,102 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import hashlib
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
# Necessary to load the local gguf package
|
||||
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from gguf import GGUFReader # noqa: E402
|
||||
|
||||
|
||||
logger = logging.getLogger("gguf-hash")
|
||||
|
||||
# UUID_NAMESPACE_LLAMA_CPP = uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp')
|
||||
UUID_NAMESPACE_LLAMA_CPP = uuid.UUID('ef001206-dadc-5f6d-a15f-3359e577d4e5')
|
||||
|
||||
|
||||
# For more information about what field.parts and field.data represent,
|
||||
# please see the comments in gguf_set_metadata.py.
|
||||
def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar: bool, no_layer: bool) -> None:
|
||||
sha1 = hashlib.sha1()
|
||||
sha256 = hashlib.sha256()
|
||||
uuidv5_sha1 = hashlib.sha1()
|
||||
uuidv5_sha1.update(UUID_NAMESPACE_LLAMA_CPP.bytes)
|
||||
|
||||
# Total Weight Calculation For Progress Bar
|
||||
total_weights = 0
|
||||
for n, tensor in enumerate(reader.tensors, 1):
|
||||
|
||||
# We don't need these
|
||||
if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||
continue
|
||||
|
||||
# Calculate Tensor Volume
|
||||
sum_weights_in_tensor = 1
|
||||
for dim in tensor.shape:
|
||||
sum_weights_in_tensor *= dim
|
||||
total_weights += sum_weights_in_tensor
|
||||
|
||||
# Hash Progress Bar
|
||||
bar = tqdm(desc="Hashing", total=total_weights, unit="weights", unit_scale=True, disable=disable_progress_bar)
|
||||
|
||||
# Hashing Process
|
||||
for tensor in reader.tensors:
|
||||
|
||||
# We don't need these
|
||||
if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||
continue
|
||||
|
||||
# Progressbar
|
||||
sum_weights_in_tensor = 1
|
||||
for dim in tensor.shape:
|
||||
sum_weights_in_tensor *= dim
|
||||
bar.update(sum_weights_in_tensor)
|
||||
|
||||
if not no_layer:
|
||||
|
||||
sha1_layer = hashlib.sha1()
|
||||
sha1_layer.update(tensor.data.data)
|
||||
print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
|
||||
|
||||
sha256_layer = hashlib.sha256()
|
||||
sha256_layer.update(tensor.data.data)
|
||||
print("sha256 {0} {1}:{2}".format(sha256_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
|
||||
|
||||
sha1.update(tensor.data.data)
|
||||
sha256.update(tensor.data.data)
|
||||
uuidv5_sha1.update(tensor.data.data)
|
||||
|
||||
# Flush Hash Progress Bar
|
||||
bar.close()
|
||||
|
||||
# Display Hash Output
|
||||
print("sha1 {0} {1}".format(sha1.hexdigest(), filename)) # noqa: NP100
|
||||
print("sha256 {0} {1}".format(sha256.hexdigest(), filename)) # noqa: NP100
|
||||
print("uuid {0} {1}".format(uuid.UUID(bytes=uuidv5_sha1.digest()[:16], version=5), filename)) # noqa: NP100
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Hash GGUF tensor data (SHA-1, SHA-256 and UUIDv5)")
|
||||
parser.add_argument("model", type=str, help="GGUF format model filename")
|
||||
parser.add_argument("--no-layer", action="store_true", help="exclude per layer hash")
|
||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||
parser.add_argument("--progressbar", action="store_true", help="enable progressbar")
|
||||
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
reader = GGUFReader(args.model, 'r')
|
||||
gguf_hash(reader, args.model, not args.progressbar, args.no_layer)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
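The `uuid` line printed by the hash script above is a UUIDv5 built incrementally: SHA-1 is seeded with the namespace bytes, fed every tensor's data, and the first 16 digest bytes are stamped as version 5. The sketch below shows that this streaming construction matches the standard library's `uuid.uuid5` when the chunks concatenate to the same name. The helper `uuid5_from_chunks` and the sample data are assumptions made for illustration.

```python
import hashlib
import uuid
from typing import Iterable


def uuid5_from_chunks(namespace: uuid.UUID, chunks: Iterable[bytes]) -> uuid.UUID:
    """Build a version-5 UUID from data that arrives in pieces."""
    sha1 = hashlib.sha1()
    sha1.update(namespace.bytes)      # namespace first, as UUIDv5 prescribes
    for chunk in chunks:
        sha1.update(chunk)            # then the name, here streamed chunk by chunk
    # Keep the first 16 bytes; UUID() overwrites the version and variant bits.
    return uuid.UUID(bytes=sha1.digest()[:16], version=5)


namespace = uuid.NAMESPACE_URL
streamed = uuid5_from_chunks(namespace, [b"hel", b"lo"])
reference = uuid.uuid5(namespace, "hello")

print(streamed == reference)
```

Streaming matters here because tensor data can be far larger than memory, so it cannot be passed to `uuid.uuid5` as a single name.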
216
gguf-py/gguf/scripts/gguf_new_metadata.py
Executable file
@@ -0,0 +1,216 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
from typing import Any, Sequence, NamedTuple
|
||||
|
||||
# Necessary to load the local gguf package
|
||||
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
import gguf
|
||||
|
||||
logger = logging.getLogger("gguf-new-metadata")
|
||||
|
||||
|
||||
class MetadataDetails(NamedTuple):
|
||||
type: gguf.GGUFValueType
|
||||
value: Any
|
||||
description: str = ''
|
||||
sub_type: gguf.GGUFValueType | None = None
|
||||
|
||||
|
||||
def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
|
||||
field = reader.get_field(key)
|
||||
|
||||
return field.contents() if field else None
|
||||
|
||||
|
||||
def find_token(token_list: Sequence[str], token: str) -> Sequence[int]:
|
||||
token_ids = [index for index, value in enumerate(token_list) if value == token]
|
||||
|
||||
if len(token_ids) == 0:
|
||||
raise LookupError(f'Unable to find "{token}" in token list!')
|
||||
|
||||
return token_ids
|
||||
|
||||
|
||||
def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, MetadataDetails], remove_metadata: Sequence[str]) -> None:
|
||||
for field in reader.fields.values():
|
||||
# Suppress virtual fields and fields written by GGUFWriter
|
||||
if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
|
||||
logger.debug(f'Suppressing {field.name}')
|
||||
continue
|
||||
|
||||
# Skip old chat templates if we have new ones
|
||||
if field.name.startswith(gguf.Keys.Tokenizer.CHAT_TEMPLATE) and gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
|
||||
logger.debug(f'Skipping {field.name}')
|
||||
continue
|
||||
|
||||
if field.name in remove_metadata:
|
||||
logger.debug(f'Removing {field.name}')
|
||||
continue
|
||||
|
||||
val_type = field.types[0]
|
||||
sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None
|
||||
old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type)
|
||||
val = new_metadata.get(field.name, old_val)
|
||||
|
||||
if field.name in new_metadata:
|
||||
logger.debug(f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}')
|
||||
del new_metadata[field.name]
|
||||
elif val.value is not None:
|
||||
logger.debug(f'Copying {field.name}')
|
||||
|
||||
if val.value is not None:
|
||||
writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type)
|
||||
|
||||
if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
|
||||
logger.debug('Adding chat template(s)')
|
||||
writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE].value)
|
||||
del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE]
|
||||
|
||||
for key, val in new_metadata.items():
|
||||
logger.debug(f'Adding {key}: "{val.value}" {val.description}')
|
||||
writer.add_key_value(key, val.value, val.type)
|
||||
|
||||
total_bytes = 0
|
||||
|
||||
for tensor in reader.tensors:
|
||||
total_bytes += tensor.n_bytes
|
||||
writer.add_tensor_info(tensor.name, tensor.data.shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
|
||||
|
||||
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
|
||||
|
||||
writer.write_header_to_file()
|
||||
writer.write_kv_data_to_file()
|
||||
writer.write_ti_data_to_file()
|
||||
|
||||
for tensor in reader.tensors:
|
||||
writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess)
|
||||
bar.update(tensor.n_bytes)
|
||||
|
||||
writer.close()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
tokenizer_metadata = (getattr(gguf.Keys.Tokenizer, n) for n in gguf.Keys.Tokenizer.__dict__.keys() if not n.startswith('_'))
|
||||
token_names = dict((n.split('.')[-1][:-len('_token_id')], n) for n in tokenizer_metadata if n.endswith('_token_id'))
|
||||
|
||||
parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata")
|
||||
parser.add_argument("input", type=Path, help="GGUF format model input filename")
|
||||
parser.add_argument("output", type=Path, help="GGUF format model output filename")
|
||||
parser.add_argument("--general-name", type=str, help="The model's general.name", metavar='"name"')
|
||||
parser.add_argument("--general-description", type=str, help="The model's general.description", metavar='"Description ..."')
|
||||
parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."')
|
||||
parser.add_argument("--chat-template-config", type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json')
|
||||
parser.add_argument("--chat-template-file", type=Path, help="Jinja file containing chat template", metavar='chat_template.jinja')
|
||||
parser.add_argument("--pre-tokenizer", type=str, help="The model's tokenizer.ggml.pre", metavar='"pre tokenizer"')
|
||||
parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model", metavar='general.url')
|
||||
parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '"<token>"'))
|
||||
parser.add_argument("--special-token-by-id", action="append", type=str, help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0'))
|
||||
parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
|
||||
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
|
||||
args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"])
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
|
||||
new_metadata = {}
|
||||
remove_metadata = args.remove_metadata or []
|
||||
|
||||
if args.general_name:
|
||||
new_metadata[gguf.Keys.General.NAME] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_name)
|
||||
|
||||
if args.general_description:
|
||||
new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_description)
|
||||
|
||||
if args.chat_template:
|
||||
new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template)
|
||||
|
||||
if args.chat_template_config:
|
||||
with open(args.chat_template_config, 'r', encoding='utf-8') as fp:
|
||||
config = json.load(fp)
|
||||
template = config.get('chat_template')
|
||||
if template:
|
||||
new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)
|
||||
|
||||
if args.chat_template_file:
|
||||
with open(args.chat_template_file, 'r', encoding='utf-8') as fp:
|
||||
template = fp.read()
|
||||
new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)
|
||||
|
||||
if args.pre_tokenizer:
|
||||
new_metadata[gguf.Keys.Tokenizer.PRE] = MetadataDetails(gguf.GGUFValueType.STRING, args.pre_tokenizer)
|
||||
|
||||
if remove_metadata:
|
||||
logger.warning('*** Warning *** Warning *** Warning **')
|
||||
logger.warning('* Most metadata is required for a fully functional GGUF file,')
|
||||
logger.warning('* removing crucial metadata may result in a corrupt output file!')
|
||||
|
||||
if not args.force:
|
||||
logger.warning('* Enter exactly YES if you are positive you want to proceed:')
|
||||
response = input('YES, I am sure> ')
|
||||
if response != 'YES':
|
||||
logger.info("You didn't enter YES. Okay then, see ya!")
|
||||
sys.exit(0)
|
||||
|
||||
logger.info(f'* Loading: {args.input}')
|
||||
reader = gguf.GGUFReader(args.input, 'r')
|
||||
|
||||
arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
|
||||
|
||||
token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []
|
||||
|
||||
for name, token in args.special_token or []:
|
||||
if name not in token_names:
|
||||
logger.warning(f'Unknown special token "{name}", ignoring...')
|
||||
else:
|
||||
ids = find_token(token_list, token)
|
||||
new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, ids[0], f'= {token}')
|
||||
|
||||
if len(ids) > 1:
|
||||
logger.warning(f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:')
|
||||
logger.warning(', '.join(str(i) for i in ids))
|
||||
|
||||
for name, id_string in args.special_token_by_id or []:
|
||||
if name not in token_names:
|
||||
logger.warning(f'Unknown special token "{name}", ignoring...')
|
||||
elif not id_string.isdecimal():
|
||||
raise LookupError(f'Token ID "{id_string}" is not a valid ID!')
|
||||
else:
|
||||
id_int = int(id_string)
|
||||
|
||||
if id_int >= 0 and id_int < len(token_list):
|
||||
new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, id_int, f'= {token_list[id_int]}')
|
||||
else:
|
||||
raise LookupError(f'Token ID {id_int} is not within token list!')
|
||||
|
||||
if os.path.isfile(args.output) and not args.force:
|
||||
logger.warning('*** Warning *** Warning *** Warning **')
|
||||
logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
|
||||
logger.warning('* Enter exactly YES if you are positive you want to proceed:')
|
||||
response = input('YES, I am sure> ')
|
||||
if response != 'YES':
|
||||
logger.info("You didn't enter YES. Okay then, see ya!")
|
||||
sys.exit(0)
|
||||
|
||||
logger.info(f'* Writing: {args.output}')
|
||||
writer = gguf.GGUFWriter(args.output, arch=arch, endianess=reader.endianess)
|
||||
|
||||
alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT)
|
||||
if alignment is not None:
|
||||
logger.debug(f'Setting custom alignment: {alignment}')
|
||||
writer.data_alignment = alignment
|
||||
|
||||
copy_with_new_metadata(reader, writer, new_metadata, remove_metadata)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
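The `--special-token` handling above looks a token string up in the vocabulary, takes the first matching ID, and warns when the string occurs more than once. A self-contained sketch of that lookup is shown here; `find_token_ids` and the toy vocabulary are hypothetical stand-ins for illustration, not the script's own objects.

```python
from __future__ import annotations

from typing import Sequence


def find_token_ids(token_list: Sequence[str], token: str) -> list[int]:
    """Return every index at which `token` appears; raise if it never does."""
    ids = [index for index, value in enumerate(token_list) if value == token]
    if not ids:
        raise LookupError(f'Unable to find "{token}" in token list!')
    return ids


# Toy vocabulary with a duplicated entry (assumed data).
vocab = ["<s>", "</s>", "hello", "<s>"]

ids = find_token_ids(vocab, "<s>")
chosen = ids[0]            # the first match wins
ambiguous = len(ids) > 1   # this is the case where the script logs a warning

print(chosen, ambiguous, ids)
```

When a token is ambiguous, `--special-token-by-id` selects a specific ID instead of relying on the first match.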
95
gguf-py/gguf/scripts/gguf_set_metadata.py
Executable file
@@ -0,0 +1,95 @@
|
||||
#!/usr/bin/env python3
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Necessary to load the local gguf package
|
||||
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from gguf import GGUFReader # noqa: E402
|
||||
|
||||
logger = logging.getLogger("gguf-set-metadata")
|
||||
|
||||
|
||||
def minimal_example(filename: str) -> None:
|
||||
reader = GGUFReader(filename, 'r+')
|
||||
field = reader.get_field('tokenizer.ggml.bos_token_id')
|
||||
if field is None:
|
||||
return
|
||||
part_index = field.data[0]
|
||||
field.parts[part_index][0] = 2 # Set tokenizer.ggml.bos_token_id to 2
|
||||
#
|
||||
# So what's this field.data thing? It's helpful because field.parts contains
|
||||
# _every_ part of the GGUF field. For example, tokenizer.ggml.bos_token_id consists
|
||||
# of:
|
||||
#
|
||||
# Part index 0: Key length (27)
|
||||
# Part index 1: Key data ("tokenizer.ggml.bos_token_id")
|
||||
# Part index 2: Field type (4, the id for GGUFValueType.UINT32)
|
||||
# Part index 3: Field value
|
||||
#
|
||||
# Note also that each part is an NDArray slice, so even a part that
|
||||
# is only a single value like the key length will be a NDArray of
|
||||
# the key length type (numpy.uint64 in current GGUF versions).
|
||||
#
|
||||
# The .data attribute in the Field is a list of relevant part indexes
|
||||
# and doesn't contain internal GGUF details like the key length part.
|
||||
# In this case, .data will be [3] - just the part index of the
|
||||
# field value itself.
|
||||
|
||||
|
||||
def set_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
|
||||
field = reader.get_field(args.key)
|
||||
if field is None:
|
||||
logger.error(f'! Field {repr(args.key)} not found')
|
||||
sys.exit(1)
|
||||
# Note that field.types is a list of types. This is because the GGUF
|
||||
# format supports arrays. For example, an array of UINT32 would
|
||||
# look like [GGUFValueType.ARRAY, GGUFValueType.UINT32]
|
||||
handler = reader.gguf_scalar_to_np.get(field.types[0]) if field.types else None
|
||||
if handler is None:
|
||||
logger.error(f'! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}')
|
||||
sys.exit(1)
|
||||
current_value = field.parts[field.data[0]][0]
|
||||
new_value = handler(args.value)
|
||||
logger.info(f'* Preparing to change field {repr(args.key)} from {current_value} to {new_value}')
|
||||
if current_value == new_value:
|
||||
logger.info(f'- Key {repr(args.key)} already set to requested value {current_value}')
|
||||
sys.exit(0)
|
||||
if args.dry_run:
|
||||
sys.exit(0)
|
||||
if not args.force:
|
||||
logger.warning('*** Warning *** Warning *** Warning **')
|
||||
logger.warning('* Changing fields in a GGUF file can make it unusable. Proceed at your own risk.')
|
||||
logger.warning('* Enter exactly YES if you are positive you want to proceed:')
|
||||
response = input('YES, I am sure> ')
|
||||
if response != 'YES':
|
||||
logger.info("You didn't enter YES. Okay then, see ya!")
|
||||
sys.exit(0)
|
||||
field.parts[field.data[0]][0] = new_value
|
||||
logger.info('* Field changed. Successful completion.')
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Set a simple value in GGUF file metadata")
|
||||
parser.add_argument("model", type=str, help="GGUF format model filename")
|
||||
parser.add_argument("key", type=str, help="Metadata key to set")
|
||||
parser.add_argument("value", type=str, help="Metadata value to set")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Don't actually change anything")
|
||||
parser.add_argument("--force", action="store_true", help="Change the field without confirmation")
|
||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||
|
||||
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
|
||||
logger.info(f'* Loading: {args.model}')
|
||||
reader = GGUFReader(args.model, 'r' if args.dry_run else 'r+')
|
||||
set_metadata(reader, args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
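The comments in `minimal_example` above describe a field as a list of parts (key length, key data, type ID, value) plus a `data` list that points at the value part only. The sketch below models that layout with the standard library so the in-place edit can be followed without a GGUF file. It assumes little-endian encoding and a 64-bit key length; the names `parts`, `read_part` and `write_part` are illustrative and not the `GGUFReader` API.

```python
import struct

KEY = b"tokenizer.ggml.bos_token_id"
UINT32_TYPE_ID = 4  # id for GGUFValueType.UINT32, per the comment in the script

# One key/value entry laid out back to back: key length, key, type id, value.
buf = bytearray(
    struct.pack("<Q", len(KEY)) + KEY + struct.pack("<I", UINT32_TYPE_ID) + struct.pack("<I", 1)
)

# Each part is (byte offset, struct format); index 3 is the field value.
parts = [
    (0, "<Q"),
    (8, f"<{len(KEY)}s"),
    (8 + len(KEY), "<I"),
    (8 + len(KEY) + 4, "<I"),
]
data = [3]  # indexes of the parts that hold the user-visible value


def read_part(index: int):
    offset, fmt = parts[index]
    return struct.unpack_from(fmt, buf, offset)[0]


def write_part(index: int, value) -> None:
    offset, fmt = parts[index]
    struct.pack_into(fmt, buf, offset, value)  # edits the buffer in place


old_value = read_part(data[0])
write_part(data[0], 2)  # set the BOS token id to 2
new_value = read_part(data[0])

print(read_part(0), old_value, new_value)
```

Only fixed-size values can be changed this way: the write overwrites bytes in place, so the entry's length, and every offset after it, stays the same.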
1786
gguf-py/gguf/tensor_mapping.py
Normal file
File diff suppressed because it is too large
340
gguf-py/gguf/utility.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
||||
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
|
||||
ftype_lowercase: str = output_type.lower() if output_type is not None else ""
|
||||
ftype_uppercase: str = output_type.upper() if output_type is not None else ""
|
||||
return filename.format(ftype_lowercase,
|
||||
outtype=ftype_lowercase, ftype=ftype_lowercase,
|
||||
OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
|
||||
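`fill_templated_filename` above relies on plain `str.format`: the type is passed once positionally and again under several keyword names, so a template may use `{}`, `{ftype}`, `{outtype}`, or the upper-case variants. A short sketch of that behavior, with made-up file names:

```python
lower, upper = "q8_0", "Q8_0"


def fill(template: str) -> str:
    # Same argument layout as the function above: one positional, four keywords.
    return template.format(lower, outtype=lower, ftype=lower, OUTTYPE=upper, FTYPE=upper)


named = fill("some-model-name.{ftype}.gguf")
shouting = fill("some-model-name.{FTYPE}.gguf")
positional = fill("some-model-name.{}.gguf")
untouched = fill("plain-name.gguf")

print(named, shouting, positional, untouched)
```

Since this is ordinary `str.format`, an unknown placeholder such as `{foo}` raises `KeyError`, and literal braces in a file name need to be doubled.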
|
||||
|
||||
def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
|
||||
if model_params_count > 1e12:
|
||||
# Trillions Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-12
|
||||
scale_suffix = "T"
|
||||
elif model_params_count > 1e9:
|
||||
# Billions Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-9
|
||||
scale_suffix = "B"
|
||||
elif model_params_count > 1e6:
|
||||
# Millions Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-6
|
||||
scale_suffix = "M"
|
||||
else:
|
||||
# Thousands Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-3
|
||||
scale_suffix = "K"
|
||||
|
||||
fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
|
||||
|
||||
return f"{scaled_model_params:.{fix}f}{scale_suffix}"
|
||||
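The rounding rule in `model_weight_count_rounded_notation` above picks a scale suffix, then adds only as many decimals as are needed to reach `min_digits` significant digits. The sketch below re-expresses that rule in a table-driven form. It is an assumed equivalent rather than the library function, and it divides by the threshold where the original multiplies by its reciprocal.

```python
def rounded_notation(count: int, min_digits: int = 2) -> str:
    """Format a parameter count such as 7000000000 as '7.0B'."""
    for threshold, suffix in ((1e12, "T"), (1e9, "B"), (1e6, "M")):
        if count > threshold:          # strict comparison, as in the original
            scaled = count / threshold
            break
    else:
        scaled, suffix = count / 1e3, "K"

    # Digits before the decimal point (a leading 0 does not count).
    int_digits = len(str(round(scaled)).lstrip("0"))
    decimals = max(min_digits - int_digits, 0)
    return f"{scaled:.{decimals}f}{suffix}"


for n in (7_000_000_000, 125_000_000, 1_500, 1_000_000_000):
    print(n, rounded_notation(n))
```

Because the comparisons are strict, a count of exactly one billion falls into the millions branch and is printed as `1000M`.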
|
||||
|
||||
def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
|
||||
|
||||
if expert_count > 0:
|
||||
pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
|
||||
size_class = f"{expert_count}x{pretty_size}"
|
||||
else:
|
||||
size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
|
||||
|
||||
return size_class
|
||||
|
||||
|
||||
def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
|
||||
# Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention
|
||||
|
||||
if base_name is not None:
|
||||
name = base_name.strip().replace(' ', '-').replace('/', '-')
|
||||
elif model_name is not None:
|
||||
name = model_name.strip().replace(' ', '-').replace('/', '-')
|
||||
else:
|
||||
name = "ggml-model"
|
||||
|
||||
parameters = f"-{size_label}" if size_label is not None else ""
|
||||
|
||||
finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""
|
||||
|
||||
version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
|
||||
|
||||
encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
|
||||
|
||||
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
|
||||
|
||||
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemoteTensor:
|
||||
dtype: str
|
||||
shape: tuple[int, ...]
|
||||
offset_start: int
|
||||
size: int
|
||||
url: str
|
||||
|
||||
def data(self) -> bytearray:
|
||||
# TODO: handle request errors (maybe with limited retries?)
|
||||
# NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
|
||||
data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
|
||||
return data
|
||||
|
||||
|
||||
class SafetensorRemote:
|
||||
"""
|
||||
Utility class to handle remote safetensor files.
|
||||
This class is designed to work with Hugging Face model repositories.
|
||||
|
||||
Example (one model has single safetensor file, the other has multiple):
|
||||
for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
|
||||
tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
|
||||
print(tensors)
|
||||
|
||||
Example reading tensor data:
|
||||
tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
|
||||
for name, meta in tensors.items():
|
||||
# meta is a RemoteTensor with fields dtype, shape, offset_start, size, url
|
||||
# read the tensor data
|
||||
data = SafetensorRemote.get_data_by_range(meta.url, meta.offset_start, meta.size)
|
||||
print(data)
|
||||
"""
|
||||
|
||||
BASE_DOMAIN = "https://huggingface.co"
|
||||
|
||||
@classmethod
|
||||
def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
|
||||
"""
|
||||
Get list of tensors from a Hugging Face model repository.
|
||||
|
||||
Returns a dictionary of tensor names and their metadata.
|
||||
Each tensor is represented as a RemoteTensor (dtype, shape, offset_start, size, url).
|
||||
"""
|
||||
# case 1: model has only one single model.safetensor file
|
||||
is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
|
||||
if is_single_file:
|
||||
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
|
||||
return cls.get_list_tensors(url)
|
||||
|
||||
# case 2: model has multiple files
|
||||
index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
|
||||
is_multiple_files = cls.check_file_exist(index_url)
|
||||
if is_multiple_files:
|
||||
# read the index file
|
||||
index_data = cls.get_data_by_range(index_url, 0)
|
||||
index_str = index_data.decode('utf-8')
|
||||
index_json = json.loads(index_str)
|
||||
assert index_json.get("weight_map") is not None, "weight_map not found in index file"
|
||||
weight_map = index_json["weight_map"]
|
||||
# get the list of files
|
||||
all_files = list(set(weight_map.values()))
|
||||
all_files.sort() # make sure we load shard files in order
|
||||
# get the list of tensors
|
||||
tensors: dict[str, RemoteTensor] = {}
|
||||
for file in all_files:
|
||||
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
|
||||
for key, val in cls.get_list_tensors(url).items():
|
||||
tensors[key] = val
|
||||
return tensors
|
||||
|
||||
raise ValueError(
|
||||
f"No safetensor file has been found for model {model_id}. "
|
||||
"If the repo has safetensor files, make sure the model is public or you have a "
|
||||
"valid Hugging Face token set in the environment variable HF_TOKEN."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
|
||||
"""
|
||||
Get list of tensors from a remote safetensor file.
|
||||
|
||||
Returns a dictionary of tensor names and their metadata.
|
||||
Each tensor is represented as a RemoteTensor (dtype, shape, offset_start, size, url).
|
||||
"""
|
||||
metadata, data_start_offset = cls.get_metadata(url)
|
||||
res: dict[str, RemoteTensor] = {}
|
||||
|
||||
for name, meta in metadata.items():
|
||||
if name == "__metadata__":
|
||||
continue
|
||||
if not isinstance(meta, dict):
|
||||
raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
|
||||
try:
|
||||
dtype = meta["dtype"]
|
||||
shape = meta["shape"]
|
||||
offset_start_relative, offset_end_relative = meta["data_offsets"]
|
||||
size = offset_end_relative - offset_start_relative
|
||||
offset_start = data_start_offset + offset_start_relative
|
||||
res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
|
||||
|
||||
# order by name (same as default safetensors behavior)
|
||||
# ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
|
||||
res = dict(sorted(res.items(), key=lambda t: t[0]))
|
||||
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls, url: str) -> tuple[dict, int]:
|
||||
"""
|
||||
Get JSON metadata from a remote safetensor file.
|
||||
|
||||
Returns tuple of (metadata, data_start_offset)
|
||||
"""
|
||||
# Request first 5MB of the file (hopefully enough for metadata)
|
||||
read_size = 5 * 1024 * 1024
|
||||
raw_data = cls.get_data_by_range(url, 0, read_size)
|
||||
|
||||
# Parse header
|
||||
# First 8 bytes contain the metadata length as u64 little-endian
|
||||
if len(raw_data) < 8:
|
||||
raise ValueError("Not enough data to read metadata size")
|
||||
metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
|
||||
|
||||
# Calculate the data start offset
|
||||
data_start_offset = 8 + metadata_length
|
||||
|
||||
# Check if we have enough data to read the metadata
|
||||
if len(raw_data) < 8 + metadata_length:
|
||||
raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
|
||||
|
||||
# Extract metadata bytes and parse as JSON
|
||||
metadata_bytes = raw_data[8:8 + metadata_length]
|
||||
metadata_str = metadata_bytes.decode('utf-8')
|
||||
try:
|
||||
metadata = json.loads(metadata_str)
|
||||
return metadata, data_start_offset
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
|
||||
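`get_metadata` above reads the safetensors header: an 8-byte little-endian length, then that many bytes of JSON, with tensor data starting right after. The sketch below runs the same steps on an in-memory blob so it needs no network access. The function name `parse_safetensors_header` and the one-tensor file are assumptions for illustration.

```python
import json
import struct


def parse_safetensors_header(raw: bytes) -> tuple:
    """Return (header dict, absolute offset where tensor data begins)."""
    if len(raw) < 8:
        raise ValueError("Not enough data to read metadata size")
    (header_len,) = struct.unpack("<Q", raw[:8])  # u64 little-endian
    data_start = 8 + header_len
    if len(raw) < data_start:
        raise ValueError(f"Need {data_start} bytes, got {len(raw)}")
    return json.loads(raw[8:data_start].decode("utf-8")), data_start


# Build a tiny fake file: one F16 tensor of shape 2x2 (8 bytes of data).
header_json = json.dumps({
    "__metadata__": {"format": "pt"},
    "w": {"dtype": "F16", "shape": [2, 2], "data_offsets": [0, 8]},
}).encode("utf-8")
blob = struct.pack("<Q", len(header_json)) + header_json + bytes(8)

header, data_start = parse_safetensors_header(blob)
rel_start, rel_end = header["w"]["data_offsets"]
abs_start = data_start + rel_start   # offsets in the header are relative
size = rel_end - rel_start

print(data_start, abs_start, size)
```

The per-tensor offsets are relative to the end of the header, which is why `get_list_tensors` adds `data_start_offset` before storing `offset_start`.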
|
||||
@classmethod
|
||||
def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
|
||||
"""
|
||||
Get raw byte data from a remote file by range.
|
||||
If size is not specified, the entire file is read and `start` is ignored.
|
||||
"""
|
||||
import requests
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
raise ValueError(f"Invalid URL: {url}")
|
||||
|
||||
headers = cls._get_request_headers()
|
||||
if size > -1:
|
||||
headers["Range"] = f"bytes={start}-{start + size}"  # range end is inclusive: this requests size + 1 bytes, trimmed on return
|
||||
response = requests.get(url, allow_redirects=True, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
# Get raw byte data
|
||||
return response.content[slice(size if size > -1 else None)]
|
||||
|
||||
@classmethod
|
||||
def check_file_exist(cls, url: str) -> bool:
|
||||
"""
|
||||
Check if a file exists at the given URL.
|
||||
Returns True if the file exists, False otherwise.
|
||||
"""
|
||||
import requests
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
raise ValueError(f"Invalid URL: {url}")
|
||||
|
||||
try:
|
||||
headers = cls._get_request_headers()
|
||||
headers["Range"] = "bytes=0-0"
|
||||
response = requests.head(url, allow_redirects=True, headers=headers)
|
||||
# Success (2xx) or redirect (3xx)
|
||||
return 200 <= response.status_code < 400
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _get_request_headers(cls) -> dict[str, str]:
|
||||
"""Prepare common headers for requests."""
|
||||
headers = {"User-Agent": "convert_hf_to_gguf"}
|
||||
if os.environ.get("HF_TOKEN"):
|
||||
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
|
||||
return headers
|
||||
|
||||
|
||||
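HTTP byte ranges are inclusive at both ends, which is why `get_data_by_range` above trims the response body to `size` bytes. The arithmetic can be sketched as follows; `range_header` and `bytes_covered` are hypothetical helpers written for illustration and are not part of gguf-py.

```python
def range_header(start: int, size: int) -> str:
    """Build a Range header value covering exactly `size` bytes from `start`."""
    if size <= 0:
        raise ValueError("size must be positive")
    # Both ends are inclusive, so the last byte index is start + size - 1
    return f"bytes={start}-{start + size - 1}"


def bytes_covered(header_value: str) -> int:
    """Count the bytes described by a single 'bytes=first-last' range."""
    first, last = header_value[len("bytes="):].split("-")
    return int(last) - int(first) + 1


exact = range_header(0, 8)              # covers bytes 0..7
one_extra = bytes_covered("bytes=0-8")  # an end of start + size covers one byte more
```

Requesting one byte too many is harmless as long as the caller slices the body to `size`, as the method above does.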
@dataclass
|
||||
class LocalTensorRange:
|
||||
filename: Path
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocalTensor:
|
||||
dtype: str
|
||||
shape: tuple[int, ...]
|
||||
data_range: LocalTensorRange
|
||||
|
||||
def mmap_bytes(self) -> np.ndarray:
|
||||
return np.memmap(self.data_range.filename, mode='c', offset=self.data_range.offset, shape=self.data_range.size)
|
||||
|
||||
|
||||
class SafetensorsLocal:
|
||||
"""
|
||||
Read a safetensors file from the local filesystem.
|
||||
|
||||
Custom parsing gives a bit more control over the memory usage.
|
||||
The official safetensors library doesn't expose file ranges.
|
||||
"""
|
||||
|
||||
tensors: dict[str, LocalTensor]
|
||||
|
||||
def __init__(self, filename: Path):
|
||||
with open(filename, "rb") as f:
|
||||
metadata_length = int.from_bytes(f.read(8), byteorder='little')
|
||||
file_size = os.stat(filename).st_size
|
||||
if file_size < 8 + metadata_length:
|
||||
raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
|
||||
|
||||
metadata_str = f.read(metadata_length).decode('utf-8')
|
||||
try:
|
||||
metadata = json.loads(metadata_str)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}")
|
||||
|
||||
data_start_offset = f.tell()
|
||||
|
||||
tensors: dict[str, LocalTensor] = {}
|
||||
for name, meta in metadata.items():
|
||||
if name == "__metadata__":
|
||||
# ignore metadata, it's not a tensor
|
||||
continue
|
||||
|
||||
tensors[name] = LocalTensor(
|
||||
dtype=meta["dtype"],
|
||||
shape=tuple(meta["shape"]),
|
||||
data_range=LocalTensorRange(
|
||||
filename,
|
||||
data_start_offset + meta["data_offsets"][0],
|
||||
meta["data_offsets"][1] - meta["data_offsets"][0],
|
||||
),
|
||||
)
|
||||
|
||||
# order by name (same as default safetensors behavior)
|
||||
# ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
|
||||
self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))
|
||||
|
||||
def __enter__(self, *args, **kwargs):
|
||||
del args, kwargs # unused
|
||||
return self.tensors
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
del args, kwargs # unused
|
||||
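Both readers above rely on the same layout: an 8-byte little-endian length, a JSON header of that length, then the tensor data, with each tensor's `data_offsets` counted from the end of the header. A minimal sketch of that parsing over an in-memory buffer; `parse_safetensors_header` and the sample tensor `w` are hypothetical names used only for illustration.

```python
from __future__ import annotations

import json
import struct


def parse_safetensors_header(raw: bytes) -> tuple[dict, int]:
    """Return (metadata, data_start_offset) for a safetensors byte buffer."""
    if len(raw) < 8:
        raise ValueError("Not enough data to read metadata size")
    # First 8 bytes: header length as unsigned 64-bit little-endian
    (metadata_length,) = struct.unpack("<Q", raw[:8])
    data_start_offset = 8 + metadata_length
    if len(raw) < data_start_offset:
        raise ValueError("Could not read complete metadata")
    metadata = json.loads(raw[8:data_start_offset].decode("utf-8"))
    return metadata, data_start_offset


# Build a tiny buffer holding one 4-byte tensor named "w" (illustrative data)
header = json.dumps({"w": {"dtype": "U8", "shape": [4], "data_offsets": [0, 4]}}).encode("utf-8")
payload = bytes([1, 2, 3, 4])
blob = struct.pack("<Q", len(header)) + header + payload

metadata, data_start = parse_safetensors_header(blob)
begin, end = metadata["w"]["data_offsets"]
# Offsets in the header are relative to the start of the data section
tensor_bytes = blob[data_start + begin:data_start + end]
```

Adding `data_start` to each relative offset gives the absolute file position, which is the value stored in `RemoteTensor.offset_start` and `LocalTensorRange.offset` above.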
891
gguf-py/gguf/vocab.py
Normal file
@@ -0,0 +1,891 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
import re
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable
|
||||
|
||||
try:
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
except ImportError:
|
||||
SentencePieceProcessor = None
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
|
||||
_filter_valid_tokenizer_files,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
except ImportError:
|
||||
_mistral_common_installed = False
|
||||
MistralTokenizer = None
|
||||
Tekkenizer = None
|
||||
SentencePieceTokenizer = None
|
||||
_filter_valid_tokenizer_files = None
|
||||
else:
|
||||
_mistral_common_installed = True
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
|
||||
get_one_valid_tokenizer_file,
|
||||
)
|
||||
except ImportError:
|
||||
# We still want the conversion to work with older mistral-common versions.
|
||||
get_one_valid_tokenizer_file = None
|
||||
|
||||
|
||||
import gguf
|
||||
|
||||
from .gguf_writer import GGUFWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SpecialVocab:
|
||||
merges: list[str]
|
||||
add_special_token: dict[str, bool]
|
||||
special_token_ids: dict[str, int]
|
||||
chat_template: str | Sequence[Mapping[str, str]] | None
|
||||
|
||||
def __init__(
|
||||
self, path: str | os.PathLike[str], load_merges: bool = False,
|
||||
special_token_types: Iterable[str] | None = None,
|
||||
n_vocab: int | None = None,
|
||||
):
|
||||
self.special_token_ids = {}
|
||||
self.add_special_token = {}
|
||||
self.n_vocab = n_vocab
|
||||
self.load_merges = load_merges
|
||||
self.merges = []
|
||||
self.chat_template = None
|
||||
if special_token_types is not None:
|
||||
self.special_token_types = special_token_types
|
||||
else:
|
||||
self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
|
||||
self._load(Path(path))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
|
||||
len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
|
||||
)
|
||||
|
||||
def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
|
||||
if self.merges:
|
||||
if not quiet:
|
||||
logger.info(f'Adding {len(self.merges)} merge(s).')
|
||||
gw.add_token_merges(self.merges)
|
||||
elif self.load_merges:
|
||||
logger.warning('Adding merges requested but no merges found, output may be non-functional.')
|
||||
for typ, tokid in self.special_token_ids.items():
|
||||
id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
|
||||
if id_handler is None:
|
||||
logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
|
||||
continue
|
||||
if not quiet:
|
||||
logger.info(f'Setting special token type {typ} to {tokid}')
|
||||
id_handler(tokid)
|
||||
for typ, value in self.add_special_token.items():
|
||||
add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
|
||||
if add_handler is None:
|
||||
logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
|
||||
continue
|
||||
if not quiet:
|
||||
logger.info(f'Setting add_{typ}_token to {value}')
|
||||
add_handler(value)
|
||||
if self.chat_template is not None:
|
||||
if not quiet:
|
||||
logger.info(f'Setting chat_template to {self.chat_template}')
|
||||
gw.add_chat_template(self.chat_template)
|
||||
|
||||
def _load(self, path: Path) -> None:
|
||||
self._try_load_from_tokenizer_json(path)
|
||||
self._try_load_from_config_json(path)
|
||||
if self.load_merges and not self.merges:
|
||||
self._try_load_merges_txt(path)
|
||||
|
||||
def _try_load_merges_txt(self, path: Path) -> bool:
|
||||
merges_file = path / 'merges.txt'
|
||||
if not merges_file.is_file():
|
||||
return False
|
||||
with open(merges_file, 'r', encoding = 'utf-8') as fp:
|
||||
first_line = next(fp, '').strip()
|
||||
if not first_line.startswith('#'):
|
||||
fp.seek(0)
|
||||
line_num = 0
|
||||
else:
|
||||
line_num = 1
|
||||
merges = []
|
||||
for line in fp:
|
||||
line_num += 1
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = line.split(None, 3)
|
||||
if len(parts) != 2:
|
||||
logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
|
||||
continue
|
||||
merges.append(f'{parts[0]} {parts[1]}')
|
||||
self.merges = merges
|
||||
return True
|
||||
|
||||
def _set_special_token(self, typ: str, tid: Any) -> None:
|
||||
if not isinstance(tid, int):
|
||||
return
|
||||
if tid < 0:
|
||||
raise ValueError(f'invalid value for special token type {typ}: {tid}')
|
||||
if self.n_vocab is None or tid < self.n_vocab:
|
||||
if typ in self.special_token_ids:
|
||||
return
|
||||
self.special_token_ids[typ] = tid
|
||||
return
|
||||
logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
|
||||
|
||||
def _try_load_from_tokenizer_json(self, path: Path) -> bool:
|
||||
tokenizer = None
|
||||
tokenizer_file = path / 'tokenizer.json'
|
||||
if tokenizer_file.is_file():
|
||||
with open(tokenizer_file, encoding = 'utf-8') as f:
|
||||
tokenizer = json.load(f)
|
||||
if self.load_merges:
|
||||
merges = tokenizer.get('model', {}).get('merges')
|
||||
if isinstance(merges, list) and merges:
|
||||
if isinstance(merges[0], str):
|
||||
self.merges = merges
|
||||
elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
|
||||
# New format since transformers 4.45 to support spaces in merges
|
||||
# ref: https://github.com/ggml-org/llama.cpp/issues/9692
|
||||
# TODO: internally store as the new format instead of converting to old
|
||||
if any(' ' in s for pair in merges for s in pair):
|
||||
logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
|
||||
self.merges = [
|
||||
' '.join(
|
||||
[
|
||||
# ensure the spaces are properly encoded
|
||||
''.join(
|
||||
chr(ord(c) + 256) if c == ' ' else c
|
||||
for c in part
|
||||
)
|
||||
for part in pair
|
||||
]
|
||||
)
|
||||
for pair in merges
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unknown tokenizer merges format")
|
||||
added_tokens = tokenizer.get('added_tokens', {})
|
||||
else:
|
||||
added_tokens = {}
|
||||
tokenizer_config = None
|
||||
tokenizer_config_file = path / 'tokenizer_config.json'
|
||||
if tokenizer_config_file.is_file():
|
||||
with open(tokenizer_config_file, encoding = 'utf-8') as f:
|
||||
tokenizer_config = json.load(f)
|
||||
if tokenizer:
|
||||
special_bos = (tokenizer_config or {}).get('bos_token')
|
||||
special_cls = (tokenizer_config or {}).get('cls_token')
|
||||
special_eos = (tokenizer_config or {}).get('eos_token')
|
||||
special_sep = (tokenizer_config or {}).get('sep_token')
|
||||
if not special_bos and special_cls and tokenizer_config:
|
||||
tokenizer_config['bos_token'] = special_bos = special_cls
|
||||
if not special_eos and special_sep and tokenizer_config:
|
||||
tokenizer_config['eos_token'] = special_eos = special_sep
|
||||
if post_processor := tokenizer.get('post_processor'):
|
||||
for processor in post_processor.get('processors', [post_processor]):
|
||||
if processor.get('type') == 'RobertaProcessing':
|
||||
self.add_special_token['bos'] = True
|
||||
self.add_special_token['eos'] = True
|
||||
self.add_special_token['sep'] = True
|
||||
if not special_cls and tokenizer_config:
|
||||
special_cls = processor.get('cls', [special_bos])[0]
|
||||
tokenizer_config['cls_token'] = special_cls
|
||||
if not special_sep and tokenizer_config:
|
||||
special_sep = processor.get('sep', [special_eos])[0]
|
||||
tokenizer_config['sep_token'] = special_sep
|
||||
continue
|
||||
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
|
||||
# Only works with simple templates, **will** get it wrong on unusual sequences
|
||||
if processor.get('type') == 'TemplateProcessing':
|
||||
tmpl_single = processor.get('single', [])
|
||||
tmpl_pair = processor.get('pair', [])
|
||||
special_first = None
|
||||
special_last = None
|
||||
if len(tmpl_single) > 1:
|
||||
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
|
||||
if not tokenizer_config:
|
||||
special_bos = special_first
|
||||
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
|
||||
if special_first not in (special_bos, special_cls):
|
||||
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
|
||||
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
|
||||
if not tokenizer_config:
|
||||
special_eos = special_last
|
||||
elif special_last != special_eos:
|
||||
if 'eot' not in self.special_token_types:
|
||||
self.special_token_types = tuple(self.special_token_types) + ('eot', )
|
||||
tokenizer_config['eot_token'] = special_eos
|
||||
elif 'eom' not in self.special_token_types:
|
||||
self.special_token_types = tuple(self.special_token_types) + ('eom', )
|
||||
tokenizer_config['eom_token'] = special_eos
|
||||
else:
|
||||
logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
|
||||
tokenizer_config['eos_token'] = special_eos = special_last
|
||||
self.add_special_token['eos'] = True if special_last == special_eos else False
|
||||
if special_last != special_eos:
|
||||
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
|
||||
if tmpl_pair:
|
||||
seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
|
||||
seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
|
||||
if (special_first and seq_start == 0) or (special_last and seq_stop is None):
|
||||
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
|
||||
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
|
||||
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
|
||||
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
|
||||
if tmpl_a != 'A' or tmpl_b != 'B':
|
||||
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
|
||||
# A [sep] [eos] B
|
||||
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
|
||||
add_sep = False
|
||||
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
|
||||
if special_entry in (special_sep, special_eos) and not special_last:
|
||||
add_sep = True
|
||||
if special_entry not in (special_sep, special_eos):
|
||||
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
|
||||
else:
|
||||
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
|
||||
if len(tmpl_pair) == 2:
|
||||
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
|
||||
if special_entry in (special_sep, special_eos):
|
||||
add_sep = True
|
||||
if special_entry not in (special_sep, special_eos):
|
||||
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
|
||||
else:
|
||||
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
|
||||
self.add_special_token['sep'] = add_sep
|
||||
if add_sep and not special_sep and tokenizer_config:
|
||||
tokenizer_config['sep_token'] = special_eos
|
||||
continue
|
||||
if not tokenizer_config:
|
||||
return True
|
||||
chat_template_alt = None
|
||||
chat_template_json = path / 'chat_template.json'
|
||||
chat_template_jinja = path / 'chat_template.jinja'
|
||||
if chat_template_jinja.is_file():
|
||||
with open(chat_template_jinja, encoding = 'utf-8') as f:
|
||||
chat_template_alt = f.read()
|
||||
if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')):
|
||||
chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
|
||||
for template_path in additional_templates:
|
||||
with open(template_path, encoding = 'utf-8') as fp:
|
||||
chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
|
||||
elif chat_template_json.is_file():
|
||||
with open(chat_template_json, encoding = 'utf-8') as f:
|
||||
chat_template_alt = json.load(f).get('chat_template')
|
||||
chat_template = tokenizer_config.get('chat_template', chat_template_alt)
|
||||
if chat_template is None or isinstance(chat_template, (str, list)):
|
||||
self.chat_template = chat_template
|
||||
else:
|
||||
logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
|
||||
for typ in self.special_token_types:
|
||||
add_entry = tokenizer_config.get(f'add_{typ}_token')
|
||||
if isinstance(add_entry, bool):
|
||||
self.add_special_token[typ] = add_entry
|
||||
entry = tokenizer_config.get(f'{typ}_token')
|
||||
if isinstance(entry, str):
|
||||
tc_content = entry
|
||||
elif isinstance(entry, dict):
|
||||
entry_content = entry.get('content')
|
||||
if not isinstance(entry_content, str):
|
||||
continue
|
||||
tc_content = entry_content
|
||||
else:
|
||||
continue
|
||||
# We only need the first match here.
|
||||
maybe_token_id = next(
|
||||
(atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
|
||||
None,
|
||||
)
|
||||
self._set_special_token(typ, maybe_token_id)
|
||||
return True
|
||||
|
||||
def _try_load_from_config_json(self, path: Path) -> bool:
|
||||
config_file = path / 'config.json'
|
||||
if not config_file.is_file():
|
||||
return False
|
||||
with open(config_file, encoding = 'utf-8') as f:
|
||||
config = json.load(f)
|
||||
for typ in self.special_token_types:
|
||||
token_id = config.get(f'{typ}_token_id')
|
||||
# If not found at root, check in text_config (for multimodal models like Kimi-VL)
|
||||
if token_id is None and 'text_config' in config:
|
||||
token_id = config['text_config'].get(f'{typ}_token_id')
|
||||
self._set_special_token(typ, token_id)
|
||||
return True
|
||||
|
||||
|
||||
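`_try_load_from_tokenizer_json` above accepts merges either as `"left right"` strings or as `[left, right]` pairs, and rewrites any space inside a pair element as `chr(ord(" ") + 256)` so the space-joined form stays unambiguous. A minimal sketch of that normalization; `encode_merges` is a hypothetical stand-alone helper, not a gguf-py function, and it skips the format validation the original performs.

```python
from __future__ import annotations


def encode_merges(merges: list) -> list[str]:
    """Normalize tokenizer.json merges to the 'left right' string form."""
    if not merges:
        return []
    if isinstance(merges[0], str):
        # Old format: entries are already 'left right' strings
        return list(merges)
    # New format: [left, right] pairs whose parts may contain spaces
    space = chr(ord(" ") + 256)  # U+0120
    return [" ".join(part.replace(" ", space) for part in pair) for pair in merges]


old_style = encode_merges(["h e", "he llo"])
new_style = encode_merges([["h", "e"], ["a b", "c"]])
```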
@runtime_checkable
|
||||
class BaseVocab(Protocol):
|
||||
tokenizer_model: ClassVar[str]
|
||||
name: ClassVar[str]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Vocab(BaseVocab, Protocol):
|
||||
vocab_size: int
|
||||
added_tokens_dict: dict[str, int]
|
||||
added_tokens_list: list[str]
|
||||
fname_tokenizer: Path
|
||||
|
||||
def __init__(self, base_path: Path): ...
|
||||
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
|
||||
|
||||
|
||||
class NoVocab(BaseVocab):
|
||||
tokenizer_model = "no_vocab"
|
||||
name = "no_vocab"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<NoVocab for a model without integrated vocabulary>"
|
||||
|
||||
|
||||
class BpeVocab(Vocab):
|
||||
tokenizer_model = "gpt2"
|
||||
name = "bpe"
|
||||
|
||||
def __init__(self, base_path: Path):
|
||||
added_tokens: dict[str, int] = {}
|
||||
|
||||
if (fname_tokenizer := base_path / 'vocab.json').exists():
|
||||
# "slow" tokenizer
|
||||
with open(fname_tokenizer, encoding="utf-8") as f:
|
||||
self.vocab = json.load(f)
|
||||
|
||||
try:
|
||||
# FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
|
||||
with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
|
||||
added_tokens = json.load(f)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
else:
|
||||
# "fast" tokenizer
|
||||
fname_tokenizer = base_path / 'tokenizer.json'
|
||||
|
||||
# if this fails, FileNotFoundError propagates to caller
|
||||
with open(fname_tokenizer, encoding="utf-8") as f:
|
||||
tokenizer_json = json.load(f)
|
||||
|
||||
tokenizer_model: dict[str, Any] = tokenizer_json['model']
|
||||
if (
|
||||
tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
|
||||
or tokenizer_json['decoder']['type'] != 'ByteLevel'
|
||||
):
|
||||
raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')
|
||||
|
||||
self.vocab = tokenizer_model["vocab"]
|
||||
|
||||
if (added := tokenizer_json.get('added_tokens')) is not None:
|
||||
# Added tokens here can be duplicates of the main vocabulary.
|
||||
added_tokens = {item['content']: item['id']
|
||||
for item in added
|
||||
if item['content'] not in self.vocab}
|
||||
|
||||
vocab_size = len(self.vocab)
|
||||
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
|
||||
actual_ids = sorted(added_tokens.values())
|
||||
if expected_ids != actual_ids:
|
||||
expected_end_id = vocab_size + len(actual_ids) - 1
|
||||
raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
|
||||
f"{vocab_size} - {expected_end_id}; got {actual_ids}")
|
||||
|
||||
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
|
||||
self.added_tokens_dict = added_tokens
|
||||
self.added_tokens_list = [text for (text, idx) in items]
|
||||
self.vocab_size_base = vocab_size
|
||||
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
|
||||
self.fname_tokenizer = fname_tokenizer
|
||||
|
||||
def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}
|
||||
|
||||
for i, _ in enumerate(self.vocab):
|
||||
yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL
|
||||
|
||||
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
for text in self.added_tokens_list:
|
||||
score = -1000.0
|
||||
yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
|
||||
|
||||
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
yield from self.bpe_tokens()
|
||||
yield from self.added_tokens()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
|
||||
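`BpeVocab.__init__` above requires added token IDs to continue directly after the base vocabulary with no gaps. A minimal sketch of that check in isolation; `check_added_token_ids` is a hypothetical helper and the token names are made up for illustration.

```python
from __future__ import annotations


def check_added_token_ids(vocab_size: int, added_tokens: dict[str, int]) -> list[str]:
    """Return added token texts ordered by ID, or raise if the IDs leave gaps."""
    expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
    actual_ids = sorted(added_tokens.values())
    if expected_ids != actual_ids:
        raise ValueError(f"expected {expected_ids}, got {actual_ids}")
    return [text for text, _ in sorted(added_tokens.items(), key=lambda kv: kv[1])]


# IDs 3 and 4 directly follow a base vocabulary of size 3
ordered = check_added_token_ids(3, {"<b>": 4, "<a>": 3})

# ID 5 leaves a gap at 4, so the check rejects it
try:
    check_added_token_ids(3, {"<a>": 3, "<c>": 5})
    gap_rejected = False
except ValueError:
    gap_rejected = True
```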
class SentencePieceVocab(Vocab):
|
||||
tokenizer_model = "llama"
|
||||
name = "spm"
|
||||
|
||||
def __init__(self, base_path: Path):
|
||||
if SentencePieceProcessor is None:
|
||||
raise RuntimeError("sentencepiece is not installed")
|
||||
|
||||
added_tokens: dict[str, int] = {}
|
||||
if (fname_tokenizer := base_path / 'tokenizer.model').exists():
|
||||
# normal location
|
||||
try:
|
||||
with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
|
||||
added_tokens = json.load(f)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
|
||||
# not found in alternate location either
|
||||
raise FileNotFoundError('Cannot find tokenizer.model')
|
||||
|
||||
self.sentencepiece_tokenizer = SentencePieceProcessor()
|
||||
self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
|
||||
vocab_size = self.sentencepiece_tokenizer.vocab_size()
|
||||
|
||||
new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
|
||||
expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
|
||||
actual_new_ids = sorted(new_tokens.keys())
|
||||
|
||||
if expected_new_ids != actual_new_ids:
|
||||
raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
|
||||
|
||||
# Token pieces that were added to the base vocabulary.
|
||||
self.added_tokens_dict = added_tokens
|
||||
self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
|
||||
self.vocab_size_base = vocab_size
|
||||
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
|
||||
self.fname_tokenizer = fname_tokenizer
|
||||
|
||||
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
tokenizer = self.sentencepiece_tokenizer
|
||||
for i in range(tokenizer.vocab_size()):
|
||||
piece = tokenizer.IdToPiece(i)
|
||||
text = piece.encode("utf-8")
|
||||
score: float = tokenizer.GetScore(i)
|
||||
|
||||
toktype = gguf.TokenType.NORMAL
|
||||
if tokenizer.IsUnknown(i):
|
||||
toktype = gguf.TokenType.UNKNOWN
|
||||
if tokenizer.IsControl(i):
|
||||
toktype = gguf.TokenType.CONTROL
|
||||
|
||||
# NOTE: I think added_tokens are user defined.
|
||||
# ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
|
||||
# if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
|
||||
|
||||
if tokenizer.IsUnused(i):
|
||||
toktype = gguf.TokenType.UNUSED
|
||||
if tokenizer.IsByte(i):
|
||||
toktype = gguf.TokenType.BYTE
|
||||
|
||||
yield text, score, toktype
|
||||
|
||||
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
for text in self.added_tokens_list:
|
||||
score = -1000.0
|
||||
yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
|
||||
|
||||
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
yield from self.sentencepiece_tokens()
|
||||
yield from self.added_tokens()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
|
||||
class LlamaHfVocab(Vocab):
|
||||
tokenizer_model = "llama"
|
||||
name = "hfft"
|
||||
|
||||
def __init__(self, base_path: Path):
|
||||
fname_tokenizer = base_path / 'tokenizer.json'
|
||||
# if this fails, FileNotFoundError propagates to caller
|
||||
with open(fname_tokenizer, encoding='utf-8') as f:
|
||||
tokenizer_json = json.load(f)
|
||||
|
||||
# pre-check so we know if we need transformers
|
||||
tokenizer_model: dict[str, Any] = tokenizer_json['model']
|
||||
is_llama3 = (
|
||||
tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
|
||||
and not tokenizer_model.get('byte_fallback', True)
|
||||
)
|
||||
if is_llama3:
|
||||
raise TypeError('Llama 3 must be converted with BpeVocab')
|
||||
|
||||
if not is_llama3 and (
|
||||
tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
|
||||
or tokenizer_json['decoder']['type'] != 'Sequence'
|
||||
):
|
||||
raise FileNotFoundError('Cannot find Llama BPE tokenizer')
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"To use LlamaHfVocab, please install the `transformers` package. "
|
||||
"You can install it with `pip install transformers`."
|
||||
) from e
|
||||
|
||||
# Allow the tokenizer to default to slow or fast versions.
|
||||
# Explicitly set tokenizer to use local paths.
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
base_path,
|
||||
cache_dir=base_path,
|
||||
local_files_only=True,
|
||||
)
|
||||
assert self.tokenizer.is_fast # assume tokenizer.json is used
|
||||
|
||||
# Initialize lists and dictionaries for added tokens
|
||||
self.added_tokens_list = []
|
||||
self.added_tokens_dict = dict()
|
||||
self.added_tokens_ids = set()
|
||||
|
||||
# Process added tokens
|
||||
for tok, tokidx in sorted(
|
||||
self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
|
||||
):
|
||||
# Only consider added tokens that are not in the base vocabulary
|
||||
if tokidx >= self.tokenizer.vocab_size:
|
||||
self.added_tokens_list.append(tok)
|
||||
self.added_tokens_dict[tok] = tokidx
|
||||
self.added_tokens_ids.add(tokidx)
|
||||
|
||||
# Store special tokens and their IDs
|
||||
self.specials = {
|
||||
tok: self.tokenizer.get_vocab()[tok]
|
||||
for tok in self.tokenizer.all_special_tokens
|
||||
}
|
||||
self.special_ids = set(self.tokenizer.all_special_ids)
|
||||
|
||||
# Set vocabulary sizes
|
||||
self.vocab_size_base = self.tokenizer.vocab_size
|
||||
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
|
||||
|
||||
self.fname_tokenizer = fname_tokenizer
|
||||
|
||||
def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
reverse_vocab = {
|
||||
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
|
||||
}
|
||||
|
||||
for token_id in range(self.vocab_size_base):
|
||||
# Skip processing added tokens here
|
||||
if token_id in self.added_tokens_ids:
|
||||
continue
|
||||
|
||||
# Convert token text to bytes
|
||||
token_text = reverse_vocab[token_id].encode("utf-8")
|
||||
|
||||
# Yield token text, score, and type
|
||||
yield token_text, self.get_token_score(token_id), self.get_token_type(
|
||||
token_id, token_text, self.special_ids # Reuse already stored special IDs
|
||||
)
|
||||
|
||||
def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
|
||||
# Special case for byte tokens
|
||||
if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
|
||||
return gguf.TokenType.BYTE
|
||||
|
||||
# Determine token type based on whether it's a special token
|
||||
return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL
|
||||
|
||||
def get_token_score(self, token_id: int) -> float:
|
||||
# Placeholder for actual logic to determine the token's score
|
||||
# This needs to be implemented based on specific requirements
|
||||
return -1000.0 # Default score
|
||||
|
||||
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
for text in self.added_tokens_list:
|
||||
if text in self.specials:
|
||||
toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
|
||||
score = self.get_token_score(self.specials[text])
|
||||
else:
|
||||
toktype = gguf.TokenType.USER_DEFINED
|
||||
score = -1000.0
|
||||
|
||||
yield text.encode("utf-8"), score, toktype
|
||||
|
||||
def has_newline_token(self):
|
||||
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
|
||||
|
||||
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
yield from self.hf_tokens()
|
||||
yield from self.added_tokens()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
|
||||
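`LlamaHfVocab.get_token_type` above checks for byte tokens first and only then falls back to the special-token test. A minimal sketch of that ordering; it returns plain strings instead of `gguf.TokenType` members so it runs without gguf installed, and `classify` is a hypothetical name.

```python
import re

# Byte tokens look like <0x0A>: exactly two hex digits between the markers
BYTE_TOKEN = re.compile(br"<0x[0-9A-Fa-f]{2}>")


def classify(token_text: bytes, is_special: bool) -> str:
    """Return 'BYTE', 'CONTROL' or 'NORMAL' for a token."""
    if BYTE_TOKEN.fullmatch(token_text):
        return "BYTE"
    return "CONTROL" if is_special else "NORMAL"


newline_byte = classify(b"<0x0A>", False)
not_a_byte = classify(b"<0x0A>x", False)  # fullmatch rejects trailing text
control = classify(b"<s>", True)
```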
class MistralTokenizerType(str, Enum):
|
||||
spm = "spm"
|
||||
tekken = "tekken"
|
||||
|
||||
|
||||
# Copied from Transformers (Apache 2.0)
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544
|
||||
|
||||
def bytes_to_unicode() -> dict[int, str]:
|
||||
"""
|
||||
Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
|
||||
characters the bpe code barfs on.
|
||||
|
||||
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
|
||||
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
|
||||
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
|
||||
tables between utf-8 bytes and unicode strings.
|
||||
"""
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1))
|
||||
+ list(range(ord("¡"), ord("¬") + 1))
|
||||
+ list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs_str = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs_str))
|
||||
|
||||
|
||||
class MistralVocab(Vocab):
|
||||
tokenizer_model = "mistral"
|
||||
name = "mistral"
|
||||
|
||||
added_tokens_dict: dict[str, int] = {}
|
||||
added_tokens_list: list[str] = []
|
||||
|
||||
def __init__(self, base_path: Path):
|
||||
if not _mistral_common_installed:
|
||||
raise ImportError(
|
||||
"To use MistralVocab, please install the `mistral-common` package. "
|
||||
"You can install it with `pip install mistral-common`."
|
||||
)
|
||||
assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed"
|
||||
assert MistralTokenizer is not None, "mistral_common is not installed"
|
||||
assert Tekkenizer is not None, "mistral_common is not installed"
|
||||
|
||||
logger.info(f"Loading Mistral tokenizer from {base_path}")
|
||||
|
||||
# Find the tokenizer files
|
||||
all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
|
||||
|
||||
if get_one_valid_tokenizer_file is not None:
|
||||
tokenizer_file_path = get_one_valid_tokenizer_file(all_files)
|
||||
else:
|
||||
valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
|
||||
|
||||
if len(valid_tokenizer_files) == 0:
|
||||
raise ValueError(f"No tokenizer file found in the directory: {base_path}")
|
||||
# If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
|
||||
if len(valid_tokenizer_files) > 1:
|
||||
if "tekken.json" in valid_tokenizer_files:
|
||||
tokenizer_file = "tekken.json"
|
||||
else:
|
||||
tokenizer_file = sorted(valid_tokenizer_files)[-1]
|
||||
logger.warning(
|
||||
f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
|
||||
)
|
||||
else:
|
||||
tokenizer_file = valid_tokenizer_files[0]
|
||||
|
||||
tokenizer_file_path = base_path / tokenizer_file
|
||||
|
||||
self.tokenizer = MistralTokenizer.from_file(
|
||||
tokenizer_file_path
|
||||
).instruct_tokenizer.tokenizer
|
||||
self.tokenizer_type = (
|
||||
MistralTokenizerType.tekken
|
||||
if isinstance(self.tokenizer, Tekkenizer)
|
||||
else MistralTokenizerType.spm
|
||||
)
|
||||
self.vocab_size = self.tokenizer.n_words
|
||||
self.fname_tokenizer = tokenizer_file_path
|
||||
self._name = (
|
||||
"mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
|
||||
)
|
||||
|
||||
@property
|
||||
def tokenizer_name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def gguf_tokenizer_model(self) -> str:
|
||||
return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2"
|
||||
|
||||
def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
assert SentencePieceTokenizer is not None, "mistral_common is not installed"
|
||||
assert isinstance(self.tokenizer, SentencePieceTokenizer), (
|
||||
f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}"
|
||||
)
|
||||
|
||||
for i in range(self.tokenizer._model.vocab_size()):
|
||||
piece = self.tokenizer._model.IdToPiece(i)
|
||||
text = piece.encode("utf-8")
|
||||
score: float = self.tokenizer._model.GetScore(i)
|
||||
|
||||
toktype = gguf.TokenType.NORMAL
|
||||
if self.tokenizer._model.IsUnknown(i):
|
||||
toktype = gguf.TokenType.UNKNOWN
|
||||
if self.tokenizer._model.IsControl(i):
|
||||
toktype = gguf.TokenType.CONTROL
|
||||
|
||||
if self.tokenizer._model.IsUnused(i):
|
||||
toktype = gguf.TokenType.UNUSED
|
||||
if self.tokenizer._model.IsByte(i):
|
||||
toktype = gguf.TokenType.BYTE
|
||||
|
||||
yield text, score, toktype
|
||||
|
||||
def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
assert Tekkenizer is not None, "mistral_common is not installed"
|
||||
assert isinstance(self.tokenizer, Tekkenizer), (
|
||||
f"Expected Tekkenizer, got {type(self.tokenizer)}"
|
||||
)
|
||||
|
||||
byte_encoder = bytes_to_unicode()
|
||||
for token_id in range(self.tokenizer.num_special_tokens):
|
||||
yield (
|
||||
self.tokenizer.id_to_piece(token_id).encode("utf-8"),
|
||||
0,
|
||||
gguf.TokenType.CONTROL
|
||||
)
|
||||
for token in self.tokenizer._tekken_token2id_nospecial:
|
||||
yield (
|
||||
self.token_bytes_to_string(token, byte_encoder).encode("utf-8"),
|
||||
0,
|
||||
gguf.TokenType.NORMAL,
|
||||
)
|
||||
|
||||
def get_token_id(self, token: str) -> int:
|
||||
assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed"
|
||||
if self.tokenizer_type == MistralTokenizerType.spm:
|
||||
assert isinstance(self.tokenizer, SentencePieceTokenizer)
|
||||
return self.tokenizer._vocab.index(token)
|
||||
elif self.tokenizer_type == MistralTokenizerType.tekken:
|
||||
assert isinstance(self.tokenizer, Tekkenizer)
|
||||
return (
|
||||
self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
|
||||
|
||||
@property
|
||||
def bos_id(self) -> int:
|
||||
return self.tokenizer.bos_id
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self.tokenizer.eos_id
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
if self.tokenizer.pad_id == -1:
|
||||
return self.eos_id
|
||||
return self.tokenizer.pad_id
|
||||
|
||||
@property
|
||||
def unk_id(self) -> int:
|
||||
return self.tokenizer.unk_id
|
||||
|
||||
@property
|
||||
def bos_token(self) -> str:
|
||||
return self.tokenizer.id_to_piece(self.tokenizer.bos_id)
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str:
|
||||
return self.tokenizer.id_to_piece(self.tokenizer.eos_id)
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return self.tokenizer.id_to_piece(self.tokenizer.pad_id)
|
||||
|
||||
@property
|
||||
def unk_token(self) -> str:
|
||||
return self.tokenizer.id_to_piece(self.tokenizer.unk_id)
|
||||
|
||||
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
if self.tokenizer_type == MistralTokenizerType.spm:
|
||||
yield from self._sentencepiece_tokens()
|
||||
|
||||
elif self.tokenizer_type == MistralTokenizerType.tekken:
|
||||
yield from self._tekken_tokens()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
|
||||
|
||||
@staticmethod
|
||||
def token_bytes_to_string(b, byte_encoder):
|
||||
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
|
||||
    def extract_vocab_merges_from_model(self):
        # Adapted from Transformers (Apache 2.0)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py
        assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), (
            f"Expected Tekkenizer, got {type(self.tokenizer)}"
        )
        mergeable_ranks = self.tokenizer._model._mergeable_ranks
        token_bytes_map = {
            rank: token_bytes for token_bytes, rank in mergeable_ranks.items()
        }
        merge_pairs = []

        # Sort vocab by rank to ensure correct merge order
        for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens):
            merged_token = token_bytes_map[i]
            local = []
            for j in range(1, len(merged_token)):
                left = merged_token[:j]
                right = merged_token[j:]
                if (
                    left in mergeable_ranks
                    and right in mergeable_ranks
                    and (left + right) in mergeable_ranks
                ):
                    local.append((left, right, i))
            if not local:
                raise ValueError(
                    f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}"
                )
            local = sorted(
                local,
                key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]),
                reverse=False,
            )
            merge_pairs.extend(local)
        merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False)

        byte_encoder = bytes_to_unicode()

        decoded_merge_pairs = [
            [
                self.token_bytes_to_string(val[0], byte_encoder),
                self.token_bytes_to_string(val[1], byte_encoder),
            ]
            for val in merge_pairs
        ]

        merges = [
            " ".join(
                [
                    # ensure the spaces are properly encoded
                    "".join(chr(ord(c) + 256) if c == " " else c for c in part)
                    for part in pair
                ]
            )
            for pair in decoded_merge_pairs
        ]

        return merges
|
||||
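The merge recovery in `extract_vocab_merges_from_model` can be tried without `mistral-common` installed. The following is a minimal sketch, assuming a toy rank table of five tokens in place of a real Tekken vocabulary (where the first 256 ranks are the single bytes). The names `byte_to_unicode_table`, `recover_merges`, and `toy_ranks` are hypothetical and exist only for this illustration.

```python
def byte_to_unicode_table() -> dict:
    """Map each byte 0..255 to a printable code point, GPT-2 style."""
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    table = {b: chr(b) for b in printable}
    shift = 0
    for b in range(256):
        if b not in table:
            # Whitespace/control bytes are moved above U+00FF.
            table[b] = chr(256 + shift)
            shift += 1
    return table


def recover_merges(ranks: dict, first_merge_rank: int) -> list:
    """Recover (left, right) byte pairs for every token at or above first_merge_rank."""
    by_rank = {rank: token for token, rank in ranks.items()}
    merges = []
    for rank in range(first_merge_rank, len(ranks)):
        token = by_rank[rank]
        # Keep every split whose two halves are themselves known tokens.
        splits = [
            (token[:j], token[j:])
            for j in range(1, len(token))
            if token[:j] in ranks and token[j:] in ranks
        ]
        if not splits:
            raise ValueError(f"No valid merge for token at rank {rank}: {token!r}")
        splits.sort(key=lambda pair: (ranks[pair[0]], ranks[pair[1]]))
        merges.extend(splits)
    return merges


# Toy vocabulary (assumption): three single-byte tokens, then two merged tokens.
toy_ranks = {b" ": 0, b"a": 1, b"b": 2, b" a": 3, b" ab": 4}

table = byte_to_unicode_table()
merges = recover_merges(toy_ranks, first_merge_rank=3)
rendered = [
    " ".join("".join(table[b] for b in part) for part in pair)
    for pair in merges
]
print(rendered)  # ['Ġ a', 'Ġa b']
```

Because the byte table already maps the space byte to `Ġ`, each rendered merge line contains exactly one literal space, the one separating the left and right parts.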
43
gguf-py/pyproject.toml
Normal file
@@ -0,0 +1,43 @@
[tool.poetry]
name = "gguf"
version = "0.17.1"
description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [
    {include = "gguf"},
    {include = "gguf/py.typed"},
]
readme = "README.md"
homepage = "https://ggml.ai"
repository = "https://github.com/ggml-org/llama.cpp"
keywords = ["ggml", "gguf", "llama.cpp"]
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

[tool.poetry.dependencies]
python = ">=3.8"
numpy = ">=1.17"
tqdm = ">=4.27"
pyyaml = ">=5.1"
sentencepiece = { version = ">=0.1.98,<=0.2.0", optional = true }
PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true }

[tool.poetry.dev-dependencies]
pytest = "^5.2"

[tool.poetry.extras]
gui = ["PySide6"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
gguf-convert-endian = "gguf.scripts.gguf_convert_endian:main"
gguf-dump = "gguf.scripts.gguf_dump:main"
gguf-set-metadata = "gguf.scripts.gguf_set_metadata:main"
gguf-new-metadata = "gguf.scripts.gguf_new_metadata:main"
gguf-editor-gui = "gguf.scripts.gguf_editor_gui:main"
|
||||
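Each entry under `[tool.poetry.scripts]` has the form `module.path:callable`. The sketch below illustrates how such a string resolves to a Python object in general; it is not Poetry's or pip's actual implementation. `resolve_entry_point` is a hypothetical helper, and the standard-library target `json:dumps` stands in for `gguf.scripts.gguf_dump:main` so the example runs without gguf installed.

```python
import importlib


def resolve_entry_point(spec: str):
    """Resolve a 'module.path:attr.path' string to the object it names."""
    module_name, _, attr_path = spec.partition(":")
    obj = importlib.import_module(module_name)
    if attr_path:
        # Walk dotted attributes, e.g. "SomeClass.method".
        for attr in attr_path.split("."):
            obj = getattr(obj, attr)
    return obj


# Stand-in target from the standard library (assumption for illustration).
dumps = resolve_entry_point("json:dumps")
print(dumps({"ok": True}))  # {"ok": true}
```

With the package installed, a console script such as `gguf-dump` amounts to calling the object that its entry string resolves to.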
1
gguf-py/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .test_metadata import *
|
||||
238
gguf-py/tests/test_metadata.py
Executable file
@@ -0,0 +1,238 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Necessary to load the local gguf package
|
||||
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import gguf
|
||||
|
||||
|
||||
class TestMetadataMethod(unittest.TestCase):
|
||||
|
||||
def test_id_to_title(self):
|
||||
self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1")
|
||||
self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B")
|
||||
self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")
|
||||
|
||||
def test_get_model_id_components(self):
|
||||
# This is the basic standard form with organization marker
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
|
||||
('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
|
||||
|
||||
# Similar to basic standard form but without organization marker
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
|
||||
('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
|
||||
|
||||
# Missing version
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
|
||||
('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
|
||||
|
||||
# Missing finetune
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
|
||||
('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
|
||||
|
||||
# Base name and size label only
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
|
||||
('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
|
||||
|
||||
# Base name and version only
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
|
||||
('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
|
||||
|
||||
## Edge Cases ##
|
||||
|
||||
# This is too ambiguous... best to err on the side of caution and output nothing
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
|
||||
('Mixtral', None, None, None, None, None))
|
||||
|
||||
# Basename has numbers mixed in and also size label provided. Must avoid capturing number in basename
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
|
||||
('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
|
||||
|
||||
# Non standard naming
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
|
||||
('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))
|
||||
|
||||
# Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"),
|
||||
('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B'))
|
||||
|
||||
# Check that it can handle a real model id with no version code
|
||||
# Note that the 4k in this string is non-standard: Microsoft was referring to context length rather than weight count
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9),
|
||||
('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini'))
|
||||
|
||||
# There are some legitimate models with only thousands of parameters
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
|
||||
('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))
|
||||
|
||||
# Non standard and not easy to disambiguate
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
|
||||
('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))
|
||||
|
||||
# This is a real model_id where they append 2DPO to refer to Direct Preference Optimization
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
|
||||
('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B'))
|
||||
|
||||
# This is a real model id where the weight size has a decimal point
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"),
|
||||
('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B'))
|
||||
|
||||
# Uses an underscore in the size label
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"),
|
||||
('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B'))
|
||||
|
||||
# Uses Iter3 for the version
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"),
|
||||
('Gemma-2-9B-It-SPPO-Iter3', 'UCLA-AGI', 'Gemma-2', 'It-SPPO', 'Iter3', '9B'))
|
||||
|
||||
# Has two potential versions in the basename
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Hermes-2-Theta-Llama-3-8B"),
|
||||
('Hermes-2-Theta-Llama-3-8B', 'NousResearch', 'Hermes-2-Theta-Llama-3', None, None, '8B'))
|
||||
|
||||
# Potential version in the basename
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"),
|
||||
('SeaLLMs-v3-7B-Chat', 'SeaLLMs', 'SeaLLMs-v3', 'Chat', None, '7B'))
|
||||
|
||||
# Underscore in the basename, and 1m for the context size
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("internlm/internlm2_5-7b-chat-1m", 7 * 10**9),
|
||||
('internlm2_5-7b-chat-1m', 'internlm', 'internlm2_5', 'chat-1m', None, '7B'))
|
||||
|
||||
# Version before the finetune name
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"),
|
||||
('jamba-900M-v0.13-KIx2', 'pszemraj', 'jamba', 'KIx2', 'v0.13', '900M'))
|
||||
|
||||
# TODO: hf suffix which could be ignored but isn't
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"),
|
||||
('mamba-2.8b-hf', 'state-spaces', 'mamba', 'hf', None, '2.8B'))
|
||||
|
||||
# Two sizes, don't merge them, the other is the number of tokens on which it was trained
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("abacaj/llama-161M-100B", 161 * 10**6),
|
||||
('llama-161M-100B', 'abacaj', 'llama', '100b', None, '161M'))
|
||||
|
||||
# It's a trap, there is no size label
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6),
|
||||
('relu-100B', 'SparseLLM', 'relu', '100b', None, None))
|
||||
|
||||
# Weird size notation
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
|
||||
('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))
|
||||
|
||||
# Ignore full-text size labels when there are number-based ones, and deduplicate size labels
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"),
|
||||
('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B'))
|
||||
|
||||
# Instruct in a name without a size label
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407"),
|
||||
('Mistral-Nemo-Instruct-2407', 'mistralai', 'Mistral-Nemo', 'Instruct', '2407', None))
|
||||
|
||||
# Non-obvious splitting relying on 'chat' keyword
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"),
|
||||
('DeepSeek-V2-Chat-0628', 'deepseek-ai', 'DeepSeek-V2', 'Chat', '0628', None))
|
||||
|
||||
# Multiple versions
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"),
|
||||
('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B'))
|
||||
|
||||
# TODO: DPO in the name
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"),
|
||||
('bagel-dpo-2.8b-v0.2', 'jondurbin', 'bagel-dpo', None, 'v0.2', '2.8B'))
|
||||
|
||||
# DPO in name, but can't be used for the finetune to keep 'LLaMA-3' in the basename
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized"),
|
||||
('SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized', 'voxmenthe', 'SFR-Iterative-DPO-LLaMA-3', 'R-unquantized', None, '8B'))
|
||||
|
||||
# Too ambiguous
|
||||
# TODO: should "base" be a 'finetune' or 'size_label'?
|
||||
# (in this case it should be a size label, but other models use it to signal that they are not finetuned)
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"),
|
||||
('Florence-2-base', 'microsoft', None, None, None, None))
|
||||
|
||||
## Invalid cases ##
|
||||
|
||||
# Start with a dash and has dashes in rows
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"),
|
||||
('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))
|
||||
|
||||
## LoRA ##
|
||||
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"),
|
||||
('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B'))
|
||||
|
||||
# Negative size --> output is a LoRA adapter --> prune "LoRA" out of the name to avoid redundancy with the suffix
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234),
|
||||
('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B'))
|
||||
|
||||
def test_apply_metadata_heuristic_from_model_card(self):
|
||||
model_card = {
|
||||
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
|
||||
'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}],
|
||||
'language': ['en'],
|
||||
'datasets': ['teknium/OpenHermes-2.5'],
|
||||
'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
|
||||
'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
|
||||
}
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
expect = gguf.Metadata()
|
||||
expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
|
||||
expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
|
||||
expect.languages=['en']
|
||||
expect.datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Base Model spec is inferred from model id
|
||||
model_card = {'base_models': 'teknium/OpenHermes-2.5'}
|
||||
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Base Model spec is only url
|
||||
model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']}
|
||||
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Base Model spec is given directly
|
||||
model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
|
||||
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Dataset spec is inferred from model id
|
||||
model_card = {'datasets': 'teknium/OpenHermes-2.5'}
|
||||
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Dataset spec is only url
|
||||
model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']}
|
||||
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Dataset spec is given directly
|
||||
model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
|
||||
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
def test_apply_metadata_heuristic_from_hf_parameters(self):
|
||||
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)
|
||||
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
def test_apply_metadata_heuristic_from_model_dir(self):
|
||||
model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path)
|
||||
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
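The expectations in `test_id_to_title` suggest a simple rule: dashes become spaces, and lowercase words are capitalized unless they look like a version (`v0.1`) or start with a digit (`8b`). The sketch below is one way to satisfy those three assertions. It is an assumption drawn from the test cases alone, not necessarily how `gguf.Metadata.id_to_title` is implemented, and `id_to_title_sketch` is a hypothetical name.

```python
import re

# Tokens left untouched: version codes like "v0.1" and anything starting with a digit.
_KEEP_AS_IS = re.compile(r"^(v\d+(\.\d+)*|\d.*)$", re.IGNORECASE)


def id_to_title_sketch(model_id: str) -> str:
    """Turn a dash-separated model id into a space-separated title."""
    words = model_id.strip().replace("-", " ").split()
    return " ".join(
        w.title() if w.islower() and not _KEEP_AS_IS.match(w) else w
        for w in words
    )


print(id_to_title_sketch("hermes-2-pro-llama-3-8b-DPO"))  # Hermes 2 Pro Llama 3 8b DPO
```

Words that already contain uppercase letters, such as `DPO` or `8x7B`, pass through unchanged because `str.islower()` is false for them.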
247
gguf-py/tests/test_quants.py
Executable file
@@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Test gguf.quants so that it exactly matches the C implementation of the (de)quantization
|
||||
|
||||
# NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from math import prod
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import ctypes
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
# Necessary to load the local gguf package
|
||||
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import gguf
|
||||
from gguf.constants import GGMLQuantizationType
|
||||
|
||||
|
||||
logger = logging.getLogger("test-quants")
|
||||
|
||||
|
||||
c_float_p = ctypes.POINTER(ctypes.c_float)
|
||||
|
||||
|
||||
class ggml_init_params(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("mem_size", ctypes.c_size_t),
|
||||
("mem_buffer", ctypes.c_void_p),
|
||||
("no_alloc", ctypes.c_bool),
|
||||
]
|
||||
|
||||
|
||||
class GGMLQuants:
|
||||
libggml: ctypes.CDLL
|
||||
|
||||
def __init__(self, libggml: Path):
|
||||
self.libggml = ctypes.CDLL(str(libggml))
|
||||
self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t
|
||||
# enum ggml_type type,
|
||||
# const float * src,
|
||||
# void * dst,
|
||||
# int64_t start,
|
||||
# int64_t nrows,
|
||||
# int64_t n_per_row,
|
||||
# const float * imatrix) {
|
||||
self.libggml.ggml_quantize_chunk.argtypes = (
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(ctypes.c_float),
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_int64,
|
||||
ctypes.c_int64,
|
||||
ctypes.c_int64,
|
||||
ctypes.POINTER(ctypes.c_float),
|
||||
)
|
||||
|
||||
self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
|
||||
self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)
|
||||
|
||||
for t in (
|
||||
"q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
|
||||
"q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
|
||||
"tq1_0", "tq2_0",
|
||||
"mxfp4",
|
||||
"iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
|
||||
"iq4_nl", "iq4_xs",
|
||||
):
|
||||
dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t)
|
||||
dequant_func.restype = None
|
||||
dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
|
||||
|
||||
self.libggml.ggml_fp16_to_fp32_row.restype = None
|
||||
self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
|
||||
self.libggml.ggml_bf16_to_fp32_row.restype = None
|
||||
self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
|
||||
|
||||
self.libggml.ggml_init.argtypes = (ggml_init_params,)
|
||||
|
||||
self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False))
|
||||
|
||||
def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
|
||||
result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C")
|
||||
if qtype == GGMLQuantizationType.F32:
|
||||
# no-op
|
||||
result = tensor.view(np.float32)
|
||||
elif qtype == GGMLQuantizationType.F16:
|
||||
self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
|
||||
elif qtype == GGMLQuantizationType.BF16:
|
||||
self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
|
||||
else:
|
||||
lw_qname = qtype.name.lower()
|
||||
if lw_qname[-1] == "k":
|
||||
lw_qname = lw_qname[:-1] + "K"
|
||||
dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname)
|
||||
dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size)
|
||||
return result
|
||||
|
||||
def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
|
||||
result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C")
|
||||
if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
|
||||
# TODO: is a column-wise sum of squares appropriate?
|
||||
qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p)
|
||||
else:
|
||||
qw = ctypes.cast(0, c_float_p)
|
||||
result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw)
|
||||
assert result.size == result_size
|
||||
return result
|
||||
|
||||
|
||||
def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool:
|
||||
same = np.array_equal(t1, t2)
|
||||
if same:
|
||||
return True
|
||||
else:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
|
||||
if t1.dtype == np.float32:
|
||||
t1 = t1.reshape((-1, block_size))
|
||||
t2 = t2.reshape((-1, block_size))
|
||||
else:
|
||||
t1 = t1.reshape((-1, type_size))
|
||||
t2 = t2.reshape((-1, type_size))
|
||||
x = t1.view(np.uint8) ^ t2.view(np.uint8)
|
||||
diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
|
||||
num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
|
||||
if num_bad_blocks == 0 and t1.shape == t2.shape:
|
||||
logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
|
||||
return True
|
||||
logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)")
|
||||
bad_block_id = np.argmax(diff_bits, axis=0)
|
||||
logger.debug(f"Worst block id: {bad_block_id}")
|
||||
logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")
|
||||
|
||||
sum_diff_bits = np.sum(diff_bits)
|
||||
logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)")
|
||||
return False
|
||||


def do_test(libggml_path: Path, quick: bool = False, user_type: GGMLQuantizationType | None = None):
    ggml_quants = GGMLQuants(libggml_path)

    np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})

    r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)
    # test zero blocks
    r[0, 0, :] = 0
    ## Maybe test infinities? (can make NANs, not really useful in practice)
    # r[0, 1, 0] = np.inf
    # r[0, 2, 0] = -np.inf
    # r[0, 3, 0] = np.inf
    # r[0, 3, 1] = -np.inf

    for qtype in ((GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()) if user_type is None else (user_type,)):
        has_dequantize = False
        has_quantize = False

        try:
            gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype)
            has_dequantize = True
        except (NotImplementedError, AssertionError) as e:
            if isinstance(e, AssertionError):
                logger.error(f"Error with {qtype.name}: {e}")
                raise e
        try:
            gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype)
            has_quantize = True
        except (NotImplementedError, AssertionError) as e:
            if isinstance(e, AssertionError):
                logger.error(f"Error with {qtype.name}: {e}")
                raise e

        if not has_dequantize and not has_quantize:
            continue

        logger.info(f"Testing {qtype.name}")

        rc = r.copy(order="C")

        pyq = None
        ggq = None

        if has_quantize:
            logger.debug(f"Quantizing to {qtype.name} with Python")
            pyq = gguf.quants.quantize(rc, qtype)

            logger.debug(f"Quantizing to {qtype.name} with C")
            ggq = ggml_quants.quantize(rc, qtype)

            if qtype == GGMLQuantizationType.F16:
                pyq = pyq.view(np.uint8)
            quant_equal = compare_tensors(pyq, ggq, qtype)

            if not quant_equal:
                logger.error(f"Quantization to {qtype.name} does not match ❌")
            else:
                logger.info(f"Quantization to {qtype.name} matches exactly ✅")

        if has_dequantize:
            if ggq is None and not quick:
                logger.debug(f"Quantizing to {qtype.name} with C")
                ggq = ggml_quants.quantize(rc, qtype)

            if ggq is not None:
                logger.debug(f"Dequantizing from {qtype.name} with Python")
                pydq = gguf.quants.dequantize(ggq, qtype)
                logger.debug(f"Dequantizing from {qtype.name} with C")
                ggdq = ggml_quants.dequantize(ggq, qtype)

                dequant_equal = compare_tensors(pydq, ggdq, qtype)

                if not dequant_equal:
                    logger.error(f"Dequantization from {qtype.name} does not match ❌")
                else:
                    logger.info(f"Dequantization from {qtype.name} matches exactly ✅")

            rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype)
            rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)

            logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
            pydq = gguf.quants.dequantize(rq, qtype)
            logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
            ggdq = ggml_quants.dequantize(rq, qtype)

            dequant_equal = compare_tensors(pydq, ggdq, qtype)

            if not dequant_equal:
                logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌")
            else:
                logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
    parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "bin" / "libggml.so", help="The path to libggml.so")
    parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
    parser.add_argument("--type", type=str, help="The quant type to test (all by default)")

    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG)

    do_test(args.libggml, args.quick, GGMLQuantizationType[args.type.upper()] if args.type is not None else None)
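The test builds its random dequantization input with `quant_shape_to_byte_shape`, which converts a shape counted in elements into the shape of the packed byte buffer. The arithmetic can be sketched as follows. This is a hedged illustration rather than the library's implementation: `byte_shape` is a hypothetical stand-in, and the `(block_size, type_size)` pairs below are assumed values for Q8_0 and Q4_0 that should be checked against `gguf.GGML_QUANT_SIZES`.

```python
def byte_shape(shape: tuple[int, ...], block_size: int, type_size: int) -> tuple[int, ...]:
    """Map an element shape to the shape of its packed byte buffer.

    Hypothetical stand-in for illustration: each run of block_size elements
    in the last dimension is assumed to be stored as type_size bytes.
    """
    if shape[-1] % block_size != 0:
        raise ValueError(f"last dimension {shape[-1]} is not a multiple of block size {block_size}")
    return (*shape[:-1], shape[-1] // block_size * type_size)


# Assumed (block_size, type_size) pairs; verify against gguf.GGML_QUANT_SIZES.
Q8_0 = (32, 34)
Q4_0 = (32, 18)

q8_shape = byte_shape((8, 1024, 1024 // 2), *Q8_0)
q4_shape = byte_shape((8, 1024, 1024 // 2), *Q4_0)
print(q8_shape)  # 512 elements per row -> 16 blocks -> 16 * 34 bytes
print(q4_shape)  # 512 elements per row -> 16 blocks -> 16 * 18 bytes
```

Only the last dimension changes, which is why the test can keep the leading `(8, 1024)` dimensions and still hand both implementations a buffer of the right size.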