gguf-py : fix and simplify quantized shape round-trip (#7483)
* gguf-py : fix and simplify quantized shape round-trip * gguf-py : remove unused import
This commit is contained in:
@@ -12,6 +12,8 @@ from typing import Any, Literal, NamedTuple, TypeVar, Union
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from .quants import quant_shape_to_byte_shape
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -251,6 +253,7 @@ class GGUFReader:
|
||||
tensor_names.add(tensor_name)
|
||||
ggml_type = GGMLQuantizationType(raw_dtype[0])
|
||||
n_elems = int(np.prod(dims))
|
||||
np_dims = tuple(reversed(dims.tolist()))
|
||||
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
|
||||
n_bytes = n_elems * type_size // block_size
|
||||
data_offs = int(start_offs + offset_tensor[0])
|
||||
@@ -279,6 +282,7 @@ class GGUFReader:
|
||||
else:
|
||||
item_count = n_bytes
|
||||
item_type = np.uint8
|
||||
np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
|
||||
tensors.append(ReaderTensor(
|
||||
name = tensor_name,
|
||||
tensor_type = ggml_type,
|
||||
@@ -286,7 +290,7 @@ class GGUFReader:
|
||||
n_elements = n_elems,
|
||||
n_bytes = n_bytes,
|
||||
data_offset = data_offs,
|
||||
data = self._get(data_offs, item_type, item_count),
|
||||
data = self._get(data_offs, item_type, item_count).reshape(np_dims),
|
||||
field = field,
|
||||
))
|
||||
self.tensors = tensors
|
||||
|
||||
Reference in New Issue
Block a user