Support FP8 E4M3 KV Cache (#2786)
Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
@@ -54,6 +54,7 @@ from sglang.srt.utils import (
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
init_custom_process_group,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
monkey_patch_vllm_p2p_access_check,
|
||||
@@ -277,6 +278,29 @@ class ModelRunner:
|
||||
device_config=DeviceConfig(self.device),
|
||||
)
|
||||
|
||||
if self.server_args.kv_cache_dtype == "fp8_e4m3":
|
||||
if self.server_args.quantization_param_path is not None:
|
||||
if callable(getattr(self.model, "load_kv_cache_scales", None)):
|
||||
self.model.load_kv_cache_scales(
|
||||
self.server_args.quantization_param_path
|
||||
)
|
||||
logger.info(
|
||||
"Loaded KV cache scaling factors from %s",
|
||||
self.server_args.quantization_param_path,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Using FP8 KV cache and scaling factors provided but "
|
||||
"model %s does not support loading scaling factors.",
|
||||
self.model.__class__,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Using FP8 KV cache but no scaling factors "
|
||||
"provided. Defaulting to scaling factors of 1.0. "
|
||||
"This may lead to less accurate results!"
|
||||
)
|
||||
|
||||
# Parse other args
|
||||
self.sliding_window_size = (
|
||||
self.model.get_attention_sliding_window_size()
|
||||
@@ -516,6 +540,9 @@ class ModelRunner:
|
||||
self.kv_cache_dtype = torch.float8_e5m2fnuz
|
||||
else:
|
||||
self.kv_cache_dtype = torch.float8_e5m2
|
||||
elif self.server_args.kv_cache_dtype == "fp8_e4m3":
|
||||
if is_cuda():
|
||||
self.kv_cache_dtype = torch.float8_e4m3fn
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
|
||||
|
||||
Reference in New Issue
Block a user