support e4m3 kvcache in qwen2 & add kv scaling facotr json (#2894)
Co-authored-by: bjmsong <bjmsong@126.com>
This commit is contained in:
@@ -9,7 +9,17 @@ import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import filelock
|
||||
import gguf
|
||||
@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
||||
|
||||
# If there were no matches, return the untouched param name
|
||||
return name
|
||||
|
||||
|
||||
def kv_cache_scales_loader(
|
||||
filename: str,
|
||||
tp_rank: int,
|
||||
tp_size: int,
|
||||
num_hidden_layers: int,
|
||||
model_type: Optional[str],
|
||||
) -> Iterable[Tuple[int, float]]:
|
||||
"""
|
||||
A simple utility to read in KV cache scaling factors that have been
|
||||
previously serialized to disk. Used by the model to populate the appropriate
|
||||
KV cache scaling factors. The serialization should represent a dictionary
|
||||
whose keys are the TP ranks and values are another dictionary mapping layers
|
||||
to their KV cache scaling factors.
|
||||
"""
|
||||
try:
|
||||
with open(filename) as f:
|
||||
context = {
|
||||
"model_type": model_type,
|
||||
"num_hidden_layers": num_hidden_layers,
|
||||
"tp_rank": tp_rank,
|
||||
"tp_size": tp_size,
|
||||
}
|
||||
schema_dct = json.load(f)
|
||||
schema = QuantParamSchema.model_validate(schema_dct, context=context)
|
||||
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
|
||||
return layer_scales_map.items()
|
||||
except FileNotFoundError:
|
||||
logger.error("File or directory '%s' not found.", filename)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding JSON in file '%s'.", filename)
|
||||
except Exception:
|
||||
logger.exception("An error occurred while reading '%s'.", filename)
|
||||
# This section is reached if and only if any of the excepts are hit
|
||||
# Return an empty iterable (list) => no KV cache scales are loaded
|
||||
# which ultimately defaults to 1.0 scales
|
||||
logger.warning(
|
||||
"Defaulting to KV cache scaling factors = 1.0 for all "
|
||||
"layers in TP rank %d as an error occurred during loading.",
|
||||
tp_rank,
|
||||
)
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user