Support FP8 E4M3 KV Cache (#2786)

Co-authored-by: root <bjmsong@126.com>
Author: bjmsong
Committer: GitHub
Date: 2025-01-13 13:17:11 +08:00
Commit: 0bb0f76311 (parent 85b2e05770)
9 changed files with 205 additions and 10 deletions
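For context, a hedged usage sketch of the feature this commit adds: the FP8 E4M3 KV cache is selected through the server arguments touched in the diff below. Only `kv_cache_dtype` and `quantization_param_path` are taken from the diff; the model path is a placeholder and the exact `ServerArgs` constructor signature is an assumption.

```python
from sglang.srt.server_args import ServerArgs

# Hypothetical usage sketch: enable the FP8 E4M3 KV cache and point the
# runner at a file of pre-computed scaling factors. Field names mirror the
# server_args attributes referenced in the diff below.
server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    kv_cache_dtype="fp8_e4m3",
    quantization_param_path="/path/to/kv_cache_scales.json",
)
```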


@@ -54,6 +54,7 @@ from sglang.srt.utils import (
    enable_show_time_cost,
    get_available_gpu_memory,
    init_custom_process_group,
    is_cuda,
    is_hip,
    monkey_patch_vllm_gguf_config,
    monkey_patch_vllm_p2p_access_check,
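The newly imported `is_cuda` is what later gates the `torch.float8_e4m3fn` branch. A rough sketch of what such platform helpers typically look like (not the actual `sglang.srt.utils` implementation):

```python
import torch

def is_cuda() -> bool:
    # True on NVIDIA CUDA builds of PyTorch.
    return torch.cuda.is_available() and torch.version.cuda is not None

def is_hip() -> bool:
    # True on AMD ROCm builds of PyTorch, which expose the fnuz fp8 variants.
    return torch.version.hip is not None
```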
@@ -277,6 +278,29 @@ class ModelRunner:
    device_config=DeviceConfig(self.device),
)
if self.server_args.kv_cache_dtype == "fp8_e4m3":
    if self.server_args.quantization_param_path is not None:
        if callable(getattr(self.model, "load_kv_cache_scales", None)):
            self.model.load_kv_cache_scales(
                self.server_args.quantization_param_path
            )
            logger.info(
                "Loaded KV cache scaling factors from %s",
                self.server_args.quantization_param_path,
            )
        else:
            raise RuntimeError(
                "Using FP8 KV cache and scaling factors provided but "
                f"model {self.model.__class__} does not support loading "
                "scaling factors."
            )
    else:
        logger.warning(
            "Using FP8 KV cache but no scaling factors "
            "provided. Defaulting to scaling factors of 1.0. "
            "This may lead to less accurate results!"
        )
# Parse other args
self.sliding_window_size = (
    self.model.get_attention_sliding_window_size()
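The runner only calls into the model here; a model opts in by exposing a `load_kv_cache_scales` method. A minimal sketch of such a hook, assuming a hypothetical JSON layout for the scaling-factor file and a Llama-style module tree; the actual on-disk format expected behind `quantization_param_path` may differ.

```python
import json

def load_kv_cache_scales(model, quantization_param_path: str) -> None:
    """Hypothetical hook: read per-layer K/V scaling factors from a JSON file
    and attach them to each attention module as k_scale / v_scale."""
    # Assumed layout: {"kv_cache": {"scaling_factor": {"0": 0.0123, "1": ...}}}
    with open(quantization_param_path) as f:
        scaling = json.load(f)["kv_cache"]["scaling_factor"]
    for idx, layer in enumerate(model.model.layers):
        scale = float(scaling[str(idx)])
        layer.self_attn.k_scale = scale
        layer.self_attn.v_scale = scale
```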
@@ -516,6 +540,9 @@ class ModelRunner:
        self.kv_cache_dtype = torch.float8_e5m2fnuz
    else:
        self.kv_cache_dtype = torch.float8_e5m2
elif self.server_args.kv_cache_dtype == "fp8_e4m3":
    if is_cuda():
        self.kv_cache_dtype = torch.float8_e4m3fn
else:
    raise ValueError(
        f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."