From b8ab989ff4666f4cee2a1d77aa3941d794ffabd7 Mon Sep 17 00:00:00 2001 From: lukec <118525388+sleepcoo@users.noreply.github.com> Date: Thu, 23 Jan 2025 06:19:33 +0800 Subject: [PATCH] Fix the FP8 E4M3 parsing offline scales failure bug (#3045) --- .../sglang/srt/model_loader/weight_utils.py | 78 ++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 77c3fcbee..f2f67ecab 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -27,6 +27,7 @@ import huggingface_hub.constants import numpy as np import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm @@ -650,6 +651,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: return name +# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py +class KVCacheQuantSchema(BaseModel): + dtype: str + # Each key is a TP rank. Each value is a dictionary mapping a TP rank's + # layer indices to their per-tensor KV cache scaling factor. + # TODO: Consider pulling this and its validation methods out into its + # own schema class (tricky as its members are variable) + scaling_factor: Dict[int, Dict[int, float]] + + @model_validator(mode="after") + def check_is_fp8(self) -> "KVCacheQuantSchema": + assert self.dtype == "float8_e4m3fn", ( + "Loaded scaling factors intended for KV cache dtype = " + f"{self.dtype} rather than float8_e4m3fn!" + ) + return self + + @model_validator(mode="after") + def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_size = context["tp_size"] + num_hidden_layers = context["num_hidden_layers"] + assert len(self.scaling_factor) == tp_size, ( + f"Loaded dictionary has TP size {len(self.scaling_factor)} " + f"but LLM engine is currently running with TP size {tp_size}." + ) + for tp_rank, layer_maps in self.scaling_factor.items(): + assert len(layer_maps) == num_hidden_layers, ( + f"KV cache scales map for TP rank {tp_rank} is malformed. " + f"Expected {num_hidden_layers} layers, got " + f"{len(layer_maps)}." + ) + for i in range(tp_size): + assert ( + i in self.scaling_factor + ), f"KV cache scales map for TP rank {i} not found." + return self + + @model_validator(mode="after") + def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_rank = context["tp_rank"] + num_hidden_layers = context["num_hidden_layers"] + layer_scales_map = self.scaling_factor[tp_rank] + for i in range(num_hidden_layers): + assert i in layer_scales_map, ( + f"Could not find KV cache scales for layer {i} in " + f"TP rank {tp_rank}." + ) + return self + + +class QuantParamSchema(BaseModel): + # TODO: Generalize and extend with more fields + # (e.g. weights/activations params) once functionality is enabled + model_config = ConfigDict(protected_namespaces=()) + model_type: Optional[str] + kv_cache: KVCacheQuantSchema + + @model_validator(mode="after") + def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": + context = info.context + if context: + model_type = context.get("model_type", None) + if model_type is not None: + assert model_type == self.model_type, ( + f"Model type is {model_type} but loaded " + f"scaling factors belonging to different " + f"model type {self.model_type}!" + ) + return self + + def kv_cache_scales_loader( filename: str, tp_rank: int, @@ -681,7 +757,7 @@ def kv_cache_scales_loader( except json.JSONDecodeError: logger.error("Error decoding JSON in file '%s'.", filename) except Exception: - logger.exception("An error occurred while reading '%s'.", filename) + logger.error("An error occurred while reading '%s'.", filename) # This section is reached if and only if any of the excepts are hit # Return an empty iterable (list) => no KV cache scales are loaded # which ultimately defaults to 1.0 scales