Fix the FP8 E4M3 parsing offline scales failure bug (#3045)
This commit is contained in:
@@ -27,6 +27,7 @@ import huggingface_hub.constants
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
|
||||
from safetensors.torch import load_file, safe_open, save_file
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
@@ -650,6 +651,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
||||
return name
|
||||
|
||||
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
|
||||
class KVCacheQuantSchema(BaseModel):
|
||||
dtype: str
|
||||
# Each key is a TP rank. Each value is a dictionary mapping a TP rank's
|
||||
# layer indices to their per-tensor KV cache scaling factor.
|
||||
# TODO: Consider pulling this and its validation methods out into its
|
||||
# own schema class (tricky as its members are variable)
|
||||
scaling_factor: Dict[int, Dict[int, float]]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_is_fp8(self) -> "KVCacheQuantSchema":
|
||||
assert self.dtype == "float8_e4m3fn", (
|
||||
"Loaded scaling factors intended for KV cache dtype = "
|
||||
f"{self.dtype} rather than float8_e4m3fn!"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
tp_size = context["tp_size"]
|
||||
num_hidden_layers = context["num_hidden_layers"]
|
||||
assert len(self.scaling_factor) == tp_size, (
|
||||
f"Loaded dictionary has TP size {len(self.scaling_factor)} "
|
||||
f"but LLM engine is currently running with TP size {tp_size}."
|
||||
)
|
||||
for tp_rank, layer_maps in self.scaling_factor.items():
|
||||
assert len(layer_maps) == num_hidden_layers, (
|
||||
f"KV cache scales map for TP rank {tp_rank} is malformed. "
|
||||
f"Expected {num_hidden_layers} layers, got "
|
||||
f"{len(layer_maps)}."
|
||||
)
|
||||
for i in range(tp_size):
|
||||
assert (
|
||||
i in self.scaling_factor
|
||||
), f"KV cache scales map for TP rank {i} not found."
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
tp_rank = context["tp_rank"]
|
||||
num_hidden_layers = context["num_hidden_layers"]
|
||||
layer_scales_map = self.scaling_factor[tp_rank]
|
||||
for i in range(num_hidden_layers):
|
||||
assert i in layer_scales_map, (
|
||||
f"Could not find KV cache scales for layer {i} in "
|
||||
f"TP rank {tp_rank}."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class QuantParamSchema(BaseModel):
|
||||
# TODO: Generalize and extend with more fields
|
||||
# (e.g. weights/activations params) once functionality is enabled
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
model_type: Optional[str]
|
||||
kv_cache: KVCacheQuantSchema
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
model_type = context.get("model_type", None)
|
||||
if model_type is not None:
|
||||
assert model_type == self.model_type, (
|
||||
f"Model type is {model_type} but loaded "
|
||||
f"scaling factors belonging to different "
|
||||
f"model type {self.model_type}!"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
def kv_cache_scales_loader(
|
||||
filename: str,
|
||||
tp_rank: int,
|
||||
@@ -681,7 +757,7 @@ def kv_cache_scales_loader(
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding JSON in file '%s'.", filename)
|
||||
except Exception:
|
||||
logger.exception("An error occurred while reading '%s'.", filename)
|
||||
logger.error("An error occurred while reading '%s'.", filename)
|
||||
# This section is reached if and only if any of the excepts are hit
|
||||
# Return an empty iterable (list) => no KV cache scales are loaded
|
||||
# which ultimately defaults to 1.0 scales
|
||||
|
||||
Reference in New Issue
Block a user