support e4m3 kvcache in qwen2 & add kv scaling facotr json (#2894)

Co-authored-by: bjmsong <bjmsong@126.com>
This commit is contained in:
bjmsong
2025-01-18 11:43:22 +08:00
committed by GitHub
parent 13387e6b7a
commit d3024f4fc8
8 changed files with 227 additions and 9 deletions

View File

@@ -9,7 +9,17 @@ import logging
import os
import tempfile
from collections import defaultdict
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
Union,
)
import filelock
import gguf
@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
# If there were no matches, return the untouched param name
return name
def kv_cache_scales_loader(
filename: str,
tp_rank: int,
tp_size: int,
num_hidden_layers: int,
model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
KV cache scaling factors. The serialization should represent a dictionary
whose keys are the TP ranks and values are another dictionary mapping layers
to their KV cache scaling factors.
"""
try:
with open(filename) as f:
context = {
"model_type": model_type,
"num_hidden_layers": num_hidden_layers,
"tp_rank": tp_rank,
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct, context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()
except FileNotFoundError:
logger.error("File or directory '%s' not found.", filename)
except json.JSONDecodeError:
logger.error("Error decoding JSON in file '%s'.", filename)
except Exception:
logger.exception("An error occurred while reading '%s'.", filename)
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 for all "
"layers in TP rank %d as an error occurred during loading.",
tp_rank,
)
return []