Support FP8 E4M3 KV Cache (#2786)

Author: bjmsong (committed by GitHub)
Co-authored-by: root <bjmsong@126.com>
Date: 2025-01-13 13:17:11 +08:00
Parent: 85b2e05770
Commit: 0bb0f76311

9 changed files with 205 additions and 10 deletions

View File

@@ -353,7 +353,9 @@ class FlashInferAttnBackend(AttentionBackend):
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -362,6 +364,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -387,7 +391,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -412,13 +418,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

View File

@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.k_scale = 1.0
+        self.v_scale = 1.0

     def forward(
         self,

View File

@@ -109,8 +109,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype == torch.float8_e5m2:
-            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
@@ -256,11 +256,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-            cache_v = cache_v.to(self.dtype)
+            cache_k = (cache_k / k_scale).to(self.dtype)
+            cache_v = (cache_v / v_scale).to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
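
For readers who want the storage trick above in isolation, here is a minimal sketch of the scale-then-cast pattern that set_kv_buffer now follows, written against plain PyTorch. The helper names (store_fp8 / load_fp8) and the fp16 output dtype are made up for illustration; in the actual backend the dequantization side is handled inside FlashInfer via the k_scale / v_scale arguments passed earlier in this commit.

import torch

# Hypothetical standalone helpers, not sglang APIs.
def store_fp8(buffer_u8: torch.Tensor, loc: torch.Tensor,
              cache_k: torch.Tensor, k_scale: float) -> None:
    quant = (cache_k / k_scale).to(torch.float8_e4m3fn)  # scale into e4m3 range, then cast
    buffer_u8[loc] = quant.view(torch.uint8)  # index_put is unsupported for fp8, so write a uint8 view

def load_fp8(buffer_u8: torch.Tensor, loc: torch.Tensor, k_scale: float) -> torch.Tensor:
    # reinterpret the stored bytes as fp8, then dequantize with the same per-layer scale
    return buffer_u8[loc].view(torch.float8_e4m3fn).to(torch.float16) * k_scale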

View File

@@ -54,6 +54,7 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
+    is_cuda,
     is_hip,
     monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
@@ -277,6 +278,29 @@ class ModelRunner:
             device_config=DeviceConfig(self.device),
         )

+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        f"model {self.model.__class__} does not support "
+                        "loading scaling factors."
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
+
         # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
@@ -516,6 +540,9 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if is_cuda():
+                self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
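
As background on why fp8_e4m3 needs explicit scaling factors while fp8_e5m2 usually gets by without them: e4m3 spends its bits on mantissa rather than exponent, so its representable range is far narrower and K/V values have to be pre-scaled into it. A quick range check in plain PyTorch (illustration only, not part of this commit):

import torch

# float8_e4m3fn saturates around 448, while float8_e5m2 reaches about 57344,
# so unscaled K/V activations overflow e4m3 much more easily.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0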

View File

@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state.
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+                if hasattr(layer_self_attn.attn, "k_scale"):
+                    layer_self_attn.attn.k_scale = scaling_factor
+                    layer_self_attn.attn.v_scale = scaling_factor
+                else:
+                    raise RuntimeError(
+                        "Self attention has no KV cache scaling factor attribute!"
+                    )
+

 class LlamaForCausalLM(nn.Module):
@@ -534,6 +562,9 @@ class LlamaForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()

+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass

View File

@@ -32,6 +32,7 @@ from sglang.srt.utils import (
     is_hip,
     is_ipv6,
     is_port_available,
+    nullable_str,
 )

 logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
+    quantization_param_path: nullable_str = None
     quantization: Optional[str] = None
     context_length: Optional[int] = None
     device: str = "cuda"
@@ -350,8 +352,17 @@
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
-            choices=["auto", "fp8_e5m2"],
-            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
+        )
+        parser.add_argument(
+            "--quantization-param-path",
+            type=nullable_str,
+            default=None,
+            help="Path to the JSON file containing the KV cache "
+            "scaling factors. This should generally be supplied when "
+            "the KV cache dtype is FP8. Otherwise, KV cache scaling factors "
+            "default to 1.0, which may cause accuracy issues.",
         )
         parser.add_argument(
             "--quantization",

View File

@@ -1375,3 +1375,9 @@ def debug_timing(func):
         return func(*args, **kwargs)

     return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val

View File

@@ -0,0 +1,42 @@
{
  "model_type": "llama",
  "kv_cache": {
    "dtype": "float8_e4m3fn",
    "scaling_factor": {
      "0": {
        "0": 1,
        "1": 1,
        "2": 1,
        "3": 1,
        "4": 1,
        "5": 1,
        "6": 1,
        "7": 1,
        "8": 1,
        "9": 1,
        "10": 1,
        "11": 1,
        "12": 1,
        "13": 1,
        "14": 1,
        "15": 1,
        "16": 1,
        "17": 1,
        "18": 1,
        "19": 1,
        "20": 1,
        "21": 1,
        "22": 1,
        "23": 1,
        "24": 1,
        "25": 1,
        "26": 1,
        "27": 1,
        "28": 1,
        "29": 1,
        "30": 1,
        "31": 1
      }
    }
  }
}
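
The scales in this sample file are all 1.0 placeholders. Below is a hedged sketch of how such a file could be produced for another model; the outer key under scaling_factor is assumed to be the tensor-parallel rank (matching what kv_cache_scales_loader expects), and write_kv_cache_scales is a hypothetical helper, not part of this commit.

import json

def write_kv_cache_scales(path: str, num_layers: int, tp_size: int = 1) -> None:
    # One entry per TP rank, each mapping layer index -> per-layer KV scale.
    # Real scales would come from calibration; 1.0 is just a placeholder.
    scales = {
        str(rank): {str(layer): 1.0 for layer in range(num_layers)}
        for rank in range(tp_size)
    }
    payload = {
        "model_type": "llama",
        "kv_cache": {"dtype": "float8_e4m3fn", "scaling_factor": scales},
    }
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)

write_kv_cache_scales("my_kv_cache_scales.json", num_layers=32)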

View File

@@ -0,0 +1,64 @@
import os
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestFp8Kvcache(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        dirpath = os.path.dirname(__file__)
        config_file = os.path.join(dirpath, "kv_cache_scales_llama3_8b_chat.json")
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-cache-dtype",
                "fp8_e4m3",
                "--quantization-param-path",
                config_file,
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.835)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)


if __name__ == "__main__":
    unittest.main()