Support FP8 E4M3 KV Cache (#2786)
Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
@@ -353,7 +353,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
|
||||
o = prefill_wrapper_paged.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
@@ -362,6 +364,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
sm_scale=layer.scaling,
|
||||
window_left=layer.sliding_window_size,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
k_scale=layer.k_scale,
|
||||
v_scale=layer.v_scale,
|
||||
)
|
||||
else:
|
||||
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
||||
@@ -387,7 +391,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
o, _ = merge_state(o1, s1, o2, s2)
|
||||
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
@@ -412,13 +418,17 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
|
||||
o = decode_wrapper.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
k_scale=layer.k_scale,
|
||||
v_scale=layer.v_scale,
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
|
||||
self.logit_cap = logit_cap
|
||||
self.sliding_window_size = sliding_window_size or -1
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.k_scale = 1.0
|
||||
self.v_scale = 1.0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -109,8 +109,8 @@ class BaseTokenToKVPool:
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
if dtype == torch.float8_e5m2:
|
||||
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
@@ -256,11 +256,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
cache_v = cache_v.to(self.dtype)
|
||||
cache_k = (cache_k / k_scale).to(self.dtype)
|
||||
cache_v = (cache_v / v_scale).to(self.dtype)
|
||||
if self.store_dtype != self.dtype:
|
||||
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
|
||||
|
||||
@@ -54,6 +54,7 @@ from sglang.srt.utils import (
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
init_custom_process_group,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
monkey_patch_vllm_p2p_access_check,
|
||||
@@ -277,6 +278,29 @@ class ModelRunner:
|
||||
device_config=DeviceConfig(self.device),
|
||||
)
|
||||
|
||||
if self.server_args.kv_cache_dtype == "fp8_e4m3":
|
||||
if self.server_args.quantization_param_path is not None:
|
||||
if callable(getattr(self.model, "load_kv_cache_scales", None)):
|
||||
self.model.load_kv_cache_scales(
|
||||
self.server_args.quantization_param_path
|
||||
)
|
||||
logger.info(
|
||||
"Loaded KV cache scaling factors from %s",
|
||||
self.server_args.quantization_param_path,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Using FP8 KV cache and scaling factors provided but "
|
||||
"model %s does not support loading scaling factors.",
|
||||
self.model.__class__,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Using FP8 KV cache but no scaling factors "
|
||||
"provided. Defaulting to scaling factors of 1.0. "
|
||||
"This may lead to less accurate results!"
|
||||
)
|
||||
|
||||
# Parse other args
|
||||
self.sliding_window_size = (
|
||||
self.model.get_attention_sliding_window_size()
|
||||
@@ -516,6 +540,9 @@ class ModelRunner:
|
||||
self.kv_cache_dtype = torch.float8_e5m2fnuz
|
||||
else:
|
||||
self.kv_cache_dtype = torch.float8_e5m2
|
||||
elif self.server_args.kv_cache_dtype == "fp8_e4m3":
|
||||
if is_cuda():
|
||||
self.kv_cache_dtype = torch.float8_e4m3fn
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
|
||||
|
||||
@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
# make sure to leave KV cache scale factors in a known good (dummy) state
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
for layer_idx, scaling_factor in kv_cache_scales_loader(
|
||||
quantization_param_path,
|
||||
tp_rank,
|
||||
tp_size,
|
||||
self.config.num_hidden_layers,
|
||||
self.config.__class__.model_type,
|
||||
):
|
||||
if not isinstance(self.layers[layer_idx], nn.Identity):
|
||||
layer_self_attn = self.layers[layer_idx].self_attn
|
||||
|
||||
if hasattr(layer_self_attn.attn, "k_scale"):
|
||||
layer_self_attn.attn.k_scale = scaling_factor
|
||||
layer_self_attn.attn.v_scale = scaling_factor
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Self attention has no KV cache scaling " "factor attribute!"
|
||||
)
|
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
|
||||
|
||||
@@ -534,6 +562,9 @@ class LlamaForCausalLM(nn.Module):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
self.model.load_kv_cache_scales(quantization_param_path)
|
||||
|
||||
|
||||
class Phi3ForCausalLM(LlamaForCausalLM):
|
||||
pass
|
||||
|
||||
@@ -32,6 +32,7 @@ from sglang.srt.utils import (
|
||||
is_hip,
|
||||
is_ipv6,
|
||||
is_port_available,
|
||||
nullable_str,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -47,6 +48,7 @@ class ServerArgs:
|
||||
trust_remote_code: bool = True
|
||||
dtype: str = "auto"
|
||||
kv_cache_dtype: str = "auto"
|
||||
quantization_param_path: nullable_str = None
|
||||
quantization: Optional[str] = None
|
||||
context_length: Optional[int] = None
|
||||
device: str = "cuda"
|
||||
@@ -350,8 +352,17 @@ class ServerArgs:
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
default=ServerArgs.kv_cache_dtype,
|
||||
choices=["auto", "fp8_e5m2"],
|
||||
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
|
||||
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
|
||||
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantization-param-path",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="Path to the JSON file containing the KV cache "
|
||||
"scaling factors. This should generally be supplied, when "
|
||||
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
|
||||
"default to 1.0, which may cause accuracy issues. ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantization",
|
||||
|
||||
@@ -1375,3 +1375,9 @@ def debug_timing(func):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def nullable_str(val: str):
|
||||
if not val or val == "None":
|
||||
return None
|
||||
return val
|
||||
|
||||
42
test/srt/kv_cache_scales_llama3_1_8b.json
Normal file
42
test/srt/kv_cache_scales_llama3_1_8b.json
Normal file
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"model_type": "llama",
|
||||
"kv_cache": {
|
||||
"dtype": "float8_e4m3fn",
|
||||
"scaling_factor": {
|
||||
"0": {
|
||||
"0": 1,
|
||||
"1": 1,
|
||||
"2": 1,
|
||||
"3": 1,
|
||||
"4": 1,
|
||||
"5": 1,
|
||||
"6": 1,
|
||||
"7": 1,
|
||||
"8": 1,
|
||||
"9": 1,
|
||||
"10": 1,
|
||||
"11": 1,
|
||||
"12": 1,
|
||||
"13": 1,
|
||||
"14": 1,
|
||||
"15": 1,
|
||||
"16": 1,
|
||||
"17": 1,
|
||||
"18": 1,
|
||||
"19": 1,
|
||||
"20": 1,
|
||||
"21": 1,
|
||||
"22": 1,
|
||||
"23": 1,
|
||||
"24": 1,
|
||||
"25": 1,
|
||||
"26": 1,
|
||||
"27": 1,
|
||||
"28": 1,
|
||||
"29": 1,
|
||||
"30": 1,
|
||||
"31": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
64
test/srt/test_fp8_kvcache.py
Normal file
64
test/srt/test_fp8_kvcache.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestFp8Kvcache(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
dirpath = os.path.dirname(__file__)
|
||||
config_file = os.path.join(dirpath, "kv_cache_scales_llama3_8b_chat.json")
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--kv-cache-dtype",
|
||||
"fp8_e4m3",
|
||||
"--quantization-param-path",
|
||||
config_file,
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mgsm_en(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mgsm_en",
|
||||
num_examples=None,
|
||||
num_threads=1024,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreater(metrics["score"], 0.835)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user