Fix the FP8 E4M3 parsing offline scales failure bug (#3045)
@@ -27,6 +27,7 @@ import huggingface_hub.constants
 import numpy as np
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
+from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm

@@ -650,6 +651,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
     return name


+# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
+class KVCacheQuantSchema(BaseModel):
+    dtype: str
+    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+    # layer indices to their per-tensor KV cache scaling factor.
+    # TODO: Consider pulling this and its validation methods out into its
+    # own schema class (tricky as its members are variable)
+    scaling_factor: Dict[int, Dict[int, float]]
+
+    @model_validator(mode="after")
+    def check_is_fp8(self) -> "KVCacheQuantSchema":
+        assert self.dtype == "float8_e4m3fn", (
+            "Loaded scaling factors intended for KV cache dtype = "
+            f"{self.dtype} rather than float8_e4m3fn!"
+        )
+        return self
+
+    @model_validator(mode="after")
+    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_size = context["tp_size"]
+            num_hidden_layers = context["num_hidden_layers"]
+            assert len(self.scaling_factor) == tp_size, (
+                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+                f"but LLM engine is currently running with TP size {tp_size}."
+            )
+            for tp_rank, layer_maps in self.scaling_factor.items():
+                assert len(layer_maps) == num_hidden_layers, (
+                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
+                    f"Expected {num_hidden_layers} layers, got "
+                    f"{len(layer_maps)}."
+                )
+            for i in range(tp_size):
+                assert (
+                    i in self.scaling_factor
+                ), f"KV cache scales map for TP rank {i} not found."
+        return self
+
+    @model_validator(mode="after")
+    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_rank = context["tp_rank"]
+            num_hidden_layers = context["num_hidden_layers"]
+            layer_scales_map = self.scaling_factor[tp_rank]
+            for i in range(num_hidden_layers):
+                assert i in layer_scales_map, (
+                    f"Could not find KV cache scales for layer {i} in "
+                    f"TP rank {tp_rank}."
+                )
+        return self
+
+
+class QuantParamSchema(BaseModel):
+    # TODO: Generalize and extend with more fields
+    # (e.g. weights/activations params) once functionality is enabled
+    model_config = ConfigDict(protected_namespaces=())
+    model_type: Optional[str]
+    kv_cache: KVCacheQuantSchema
+
+    @model_validator(mode="after")
+    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+        context = info.context
+        if context:
+            model_type = context.get("model_type", None)
+            if model_type is not None:
+                assert model_type == self.model_type, (
+                    f"Model type is {model_type} but loaded "
+                    f"scaling factors belonging to different "
+                    f"model type {self.model_type}!"
+                )
+        return self
+
+
 def kv_cache_scales_loader(
     filename: str,
     tp_rank: int,
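For context, a minimal sketch of the offline scales JSON these schemas accept and how it gets validated. The values, the module path, and the "llama" model type are hypothetical; the context keys mirror the ones the validators above read (tp_rank, tp_size, num_hidden_layers, model_type).

# Hypothetical example: validate an offline kv-cache scales file the way
# kv_cache_scales_loader does, via pydantic's validation context.
from sglang.srt.model_loader.weight_utils import QuantParamSchema  # assumed path

raw = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        # TP rank -> {layer index -> per-tensor scaling factor}
        "scaling_factor": {"0": {"0": 0.0424, "1": 0.0393}},
    },
}

schema = QuantParamSchema.model_validate(
    raw,
    context={
        "model_type": "llama",  # validators only run checks when context is given
        "tp_rank": 0,
        "tp_size": 1,
        "num_hidden_layers": 2,
    },
)
layer_scales = schema.kv_cache.scaling_factor[0]  # {0: 0.0424, 1: 0.0393}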
@@ -681,7 +757,7 @@ def kv_cache_scales_loader(
     except json.JSONDecodeError:
         logger.error("Error decoding JSON in file '%s'.", filename)
     except Exception:
-        logger.exception("An error occurred while reading '%s'.", filename)
+        logger.error("An error occurred while reading '%s'.", filename)
     # This section is reached if and only if any of the excepts are hit
     # Return an empty iterable (list) => no KV cache scales are loaded
     # which ultimately defaults to 1.0 scales
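And a sketch of the failure path this hunk touches: if the file is missing or malformed, the loader logs the error and returns an empty iterable, so every layer falls back to a 1.0 scale. The parameters past (filename, tp_rank) are assumptions taken from the upstream vLLM implementation linked in the code comment above.

# Hypothetical fallback example; trailing parameters are assumed from the
# linked upstream vLLM signature, not shown in this diff.
scales = dict(
    kv_cache_scales_loader(
        "/tmp/kv_cache_scales.json",  # missing or malformed -> hits an except
        tp_rank=0,
        tp_size=1,
        num_hidden_layers=32,
        model_type="llama",
    )
)
for layer_idx in range(32):
    scaling_factor = scales.get(layer_idx, 1.0)  # no entry -> default 1.0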