Support FP8 E4M3 KV Cache (#2786)
Co-authored-by: root <bjmsong@126.com>
@@ -353,7 +353,9 @@ class FlashInferAttnBackend(AttentionBackend):
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -362,6 +364,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -387,7 +391,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

@@ -412,13 +418,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
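The k_scale / v_scale keyword arguments ask the FlashInfer kernels to dequantize the fp8 KV cache on the fly while computing attention. As a minimal non-paged reference of what the decode path computes (a sketch with illustrative names; decode_attention_ref is not part of either codebase):

import torch

# Reference sketch: attention over an fp8 KV cache with per-tensor scales.
# The real kernel fuses the dequantization; this spells it out.
def decode_attention_ref(q, k_fp8, v_fp8, k_scale, v_scale, sm_scale):
    k = k_fp8.to(q.dtype) * k_scale   # dequantize cached keys
    v = v_fp8.to(q.dtype) * v_scale   # dequantize cached values
    scores = (q @ k.transpose(-1, -2)) * sm_scale
    probs = torch.softmax(scores, dim=-1)
    return probs @ v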
@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.k_scale = 1.0
+        self.v_scale = 1.0

     def forward(
         self,
@@ -109,8 +109,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype == torch.float8_e5m2:
-            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
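A quick illustration of the store-as-uint8 workaround mentioned in the NOTE above (assumes a PyTorch build with float8 support, 2.1+): fp8 tensors do not support index_put, so scatter writes go through a same-width uint8 view and reads reinterpret the bytes back.

import torch

buf = torch.zeros(8, dtype=torch.float8_e4m3fn).view(torch.uint8)  # store_dtype view
vals = torch.randn(2).to(torch.float8_e4m3fn)                      # quantized entries
loc = torch.tensor([1, 5])
buf[loc] = vals.view(torch.uint8)          # index_put works on the uint8 alias
restored = buf.view(torch.float8_e4m3fn)   # reinterpret as fp8 for compute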
@@ -256,11 +256,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-            cache_v = cache_v.to(self.dtype)
+            cache_k = (cache_k / k_scale).to(self.dtype)
+            cache_v = (cache_v / v_scale).to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
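The division by k_scale / v_scale here is the inverse of the multiplication the attention kernel applies when reading the cache, so a well-chosen scale keeps cached values inside fp8 range. A small round-trip sketch, with values chosen to be exactly representable in float8_e4m3fn:

import torch

x = torch.tensor([0.5, -1.25, 2.0], dtype=torch.bfloat16)
scale = 0.0625  # per-tensor scaling factor, e.g. from the JSON scales file

quantized = (x / scale).to(torch.float8_e4m3fn)      # what set_kv_buffer stores
dequantized = quantized.to(torch.bfloat16) * scale   # what the kernel reads back
assert torch.equal(x, dequantized)  # exact here; generally close up to fp8 rounding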
@@ -54,6 +54,7 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
+    is_cuda,
     is_hip,
     monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
@@ -277,6 +278,29 @@ class ModelRunner:
             device_config=DeviceConfig(self.device),
         )

+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        f"Using FP8 KV cache and scaling factors provided "
+                        f"but model {self.model.__class__} does not support "
+                        "loading scaling factors."
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
+
         # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
@@ -516,6 +540,9 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if is_cuda():
+                self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state.
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+

 class LlamaForCausalLM(nn.Module):
@@ -534,6 +562,9 @@ class LlamaForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()

+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
@@ -32,6 +32,7 @@ from sglang.srt.utils import (
     is_hip,
     is_ipv6,
     is_port_available,
+    nullable_str,
 )

 logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
+    quantization_param_path: Optional[str] = None
     quantization: Optional[str] = None
     context_length: Optional[int] = None
     device: str = "cuda"
@@ -350,8 +352,17 @@ class ServerArgs:
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
-            choices=["auto", "fp8_e5m2"],
-            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
+        )
+        parser.add_argument(
+            "--quantization-param-path",
+            type=nullable_str,
+            default=None,
+            help="Path to the JSON file containing the KV cache "
+            "scaling factors. This should generally be supplied when "
+            "the KV cache dtype is FP8. Otherwise, KV cache scaling factors "
+            "default to 1.0, which may cause accuracy issues.",
         )
         parser.add_argument(
             "--quantization",
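A usage sketch combining the two flags (model path and file name are placeholders):

    python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct \
        --kv-cache-dtype fp8_e4m3 --quantization-param-path kv_cache_scales.json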
@@ -1375,3 +1375,9 @@ def debug_timing(func):
             return func(*args, **kwargs)

     return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
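With the helper above in scope, the argparse conversion behaves as:

assert nullable_str("") is None
assert nullable_str("None") is None
assert nullable_str("scales.json") == "scales.json"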
test/srt/kv_cache_scales_llama3_1_8b.json (new file, 42 lines)
@@ -0,0 +1,42 @@
+{
+    "model_type": "llama",
+    "kv_cache": {
+        "dtype": "float8_e4m3fn",
+        "scaling_factor": {
+            "0": {
+                "0": 1,
+                "1": 1,
+                "2": 1,
+                "3": 1,
+                "4": 1,
+                "5": 1,
+                "6": 1,
+                "7": 1,
+                "8": 1,
+                "9": 1,
+                "10": 1,
+                "11": 1,
+                "12": 1,
+                "13": 1,
+                "14": 1,
+                "15": 1,
+                "16": 1,
+                "17": 1,
+                "18": 1,
+                "19": 1,
+                "20": 1,
+                "21": 1,
+                "22": 1,
+                "23": 1,
+                "24": 1,
+                "25": 1,
+                "26": 1,
+                "27": 1,
+                "28": 1,
+                "29": 1,
+                "30": 1,
+                "31": 1
+            }
+        }
+    }
+}
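The factors above are all dummy 1.0 values. A hypothetical sketch of producing calibrated factors in this schema (make_kv_cache_scales and amax_per_layer are illustrative names; 448.0 is the largest magnitude float8_e4m3fn can represent, and the outer "0" key is assumed to be the tensor-parallel rank):

import json

def make_kv_cache_scales(amax_per_layer, path, num_layers=32):
    # Scale chosen so the observed per-layer max maps to the fp8 e4m3 max of 448.
    scales = {str(i): float(amax_per_layer[i]) / 448.0 for i in range(num_layers)}
    doc = {
        "model_type": "llama",
        "kv_cache": {
            "dtype": "float8_e4m3fn",
            "scaling_factor": {"0": scales},
        },
    }
    with open(path, "w") as f:
        json.dump(doc, f, indent=4)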
test/srt/test_fp8_kvcache.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+import os
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestFp8Kvcache(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        dirpath = os.path.dirname(__file__)
+        config_file = os.path.join(dirpath, "kv_cache_scales_llama3_1_8b.json")
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--kv-cache-dtype",
+                "fp8_e4m3",
+                "--quantization-param-path",
+                config_file,
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mgsm_en(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=None,
+            num_threads=1024,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreater(metrics["score"], 0.835)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+
+if __name__ == "__main__":
+    unittest.main()
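The test launches a real server, so it assumes a GPU and the scales JSON above; it can be run directly with python3 test/srt/test_fp8_kvcache.py.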