Support FP8 E4M3 KV Cache (#2786)

Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
bjmsong
2025-01-13 13:17:11 +08:00
committed by GitHub
parent 85b2e05770
commit 0bb0f76311
9 changed files with 205 additions and 10 deletions

View File

@@ -353,7 +353,9 @@ class FlashInferAttnBackend(AttentionBackend):
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -362,6 +364,8 @@ class FlashInferAttnBackend(AttentionBackend):
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=logits_soft_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -387,7 +391,9 @@ class FlashInferAttnBackend(AttentionBackend):
o, _ = merge_state(o1, s1, o2, s2)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -412,13 +418,17 @@ class FlashInferAttnBackend(AttentionBackend):
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)

View File

@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention
self.k_scale = 1.0
self.v_scale = 1.0
def forward(
self,

View File

@@ -109,8 +109,8 @@ class BaseTokenToKVPool:
):
self.size = size
self.dtype = dtype
if dtype == torch.float8_e5m2:
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for the fp8 dtypes (torch.float8_e5m2, torch.float8_e4m3fn)
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
@@ -256,11 +256,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype)
cache_k = (cache_k / k_scale).to(self.dtype)
cache_v = (cache_v / v_scale).to(self.dtype)
if self.store_dtype != self.dtype:
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)

View File

@@ -54,6 +54,7 @@ from sglang.srt.utils import (
enable_show_time_cost,
get_available_gpu_memory,
init_custom_process_group,
is_cuda,
is_hip,
monkey_patch_vllm_gguf_config,
monkey_patch_vllm_p2p_access_check,
@@ -277,6 +278,29 @@ class ModelRunner:
device_config=DeviceConfig(self.device),
)
if self.server_args.kv_cache_dtype == "fp8_e4m3":
if self.server_args.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
self.model.load_kv_cache_scales(
self.server_args.quantization_param_path
)
logger.info(
"Loaded KV cache scaling factors from %s",
self.server_args.quantization_param_path,
)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
"model %s does not support loading scaling factors.",
self.model.__class__,
)
else:
logger.warning(
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!"
)
# Parse other args
self.sliding_window_size = (
self.model.get_attention_sliding_window_size()
@@ -516,6 +540,9 @@ class ModelRunner:
self.kv_cache_dtype = torch.float8_e5m2fnuz
else:
self.kv_cache_dtype = torch.float8_e5m2
elif self.server_args.kv_cache_dtype == "fp8_e4m3":
if is_cuda():
self.kv_cache_dtype = torch.float8_e4m3fn
else:
raise ValueError(
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."

View File

@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
    """Load per-layer KV cache scaling factors from a quantization JSON file.

    Reads one (layer_idx, scaling_factor) pair per transformer layer for this
    TP rank and writes it onto the layer's attention module as both k_scale
    and v_scale. Raises RuntimeError if a layer cannot accept the scales.
    """
    world_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    scale_pairs = kv_cache_scales_loader(
        quantization_param_path,
        rank,
        world_size,
        self.config.num_hidden_layers,
        self.config.__class__.model_type,
    )
    for idx, factor in scale_pairs:
        target_layer = self.layers[idx]
        if isinstance(target_layer, nn.Identity):
            # Placeholder layer (not materialized on this rank) -- nothing to set.
            continue
        attn_module = target_layer.self_attn.attn
        if not hasattr(attn_module, "k_scale"):
            raise RuntimeError(
                "Self attention has no KV cache scaling factor attribute!"
            )
        # The same per-layer factor is applied to both K and V.
        attn_module.k_scale = factor
        attn_module.v_scale = factor
class LlamaForCausalLM(nn.Module):
@@ -534,6 +562,9 @@ class LlamaForCausalLM(nn.Module):
torch.cuda.empty_cache()
torch.cuda.synchronize()
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
    """Delegate KV cache scale-factor loading to the inner LlamaModel."""
    self.model.load_kv_cache_scales(quantization_param_path)
class Phi3ForCausalLM(LlamaForCausalLM):
pass

View File

@@ -32,6 +32,7 @@ from sglang.srt.utils import (
is_hip,
is_ipv6,
is_port_available,
nullable_str,
)
logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class ServerArgs:
trust_remote_code: bool = True
dtype: str = "auto"
kv_cache_dtype: str = "auto"
quantization_param_path: nullable_str = None
quantization: Optional[str] = None
context_length: Optional[int] = None
device: str = "cuda"
@@ -350,8 +352,17 @@ class ServerArgs:
"--kv-cache-dtype",
type=str,
default=ServerArgs.kv_cache_dtype,
choices=["auto", "fp8_e5m2"],
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
)
parser.add_argument(
"--quantization-param-path",
type=nullable_str,
default=None,
help="Path to the JSON file containing the KV cache "
"scaling factors. This should generally be supplied, when "
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
"default to 1.0, which may cause accuracy issues. ",
)
parser.add_argument(
"--quantization",

View File

@@ -1375,3 +1375,9 @@ def debug_timing(func):
return func(*args, **kwargs)
return wrapper
def nullable_str(val: str):
    """Argparse-style type: map empty or the literal string "None" to None.

    Any other value is returned unchanged.
    """
    if val and val != "None":
        return val
    return None