diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index fc3455b60..f03839462 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -353,7 +353,9 @@ class FlashInferAttnBackend(AttentionBackend):
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
 
             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -362,6 +364,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -387,7 +391,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)
 
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
@@ -412,13 +418,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 4b762c00b..a449d7188 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.k_scale = 1.0
+        self.v_scale = 1.0
 
     def forward(
         self,
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index b67f085b2..6cb186577 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -109,8 +109,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype == torch.float8_e5m2:
-            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for the fp8 dtypes
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
@@ -256,11 +256,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-            cache_v = cache_v.to(self.dtype)
+            cache_k = (cache_k / k_scale).to(self.dtype)
+            cache_v = (cache_v / v_scale).to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
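The memory-pool hunk above is the heart of the write path: keys and values are divided by the per-layer scale before the cast to fp8, and the `view(torch.uint8)` bit-cast works around the missing fp8 `index_put`. A minimal standalone sketch of that round-trip (tensor shapes, names, and the scale value are illustrative, not from the codebase):

    import torch

    k = torch.randn(4, 8, 128, dtype=torch.bfloat16, device="cuda")
    k_scale = 0.05  # per-layer scaling factor, as loaded from the JSON file

    # Write path: divide by the scale, cast to fp8, then bit-cast to uint8
    # so that plain tensor indexing can be used for the scatter.
    k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)
    buf = torch.zeros(16, 8, 128, dtype=torch.uint8, device="cuda")
    loc = torch.tensor([3, 7, 9, 12], device="cuda")
    buf[loc] = k_fp8.view(torch.uint8)

    # Read path: bit-cast back to fp8 and multiply the scale back in;
    # in the real kernel, flashinfer applies k_scale/v_scale internally.
    k_restored = buf[loc].view(torch.float8_e4m3fn).to(torch.bfloat16) * k_scale
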
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index efba8c25b..d46a2c0dc 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -54,6 +54,7 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
+    is_cuda,
     is_hip,
     monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
@@ -277,6 +278,29 @@ class ModelRunner:
             device_config=DeviceConfig(self.device),
         )
 
+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided, but "
+                        f"model {self.model.__class__} does not support loading "
+                        "scaling factors."
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
+
         # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
@@ -516,6 +540,8 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3" and is_cuda():
+            self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
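ModelRunner only duck-types the hook, so any model that defines a callable `load_kv_cache_scales` participates. The scales come from a vLLM-style JSON file (an example is added under test/srt/ at the end of this diff), and the loader yields `(layer_idx, scaling_factor)` pairs for the current TP rank. A rough reader sketch for that layout, assuming the outer key of `scaling_factor` is the TP rank (this illustrates the file format only, not vLLM's `kv_cache_scales_loader` implementation):

    import json
    from typing import Iterator, Tuple

    def iter_kv_cache_scales(path: str, tp_rank: int = 0) -> Iterator[Tuple[int, float]]:
        # Layout: {"kv_cache": {"scaling_factor": {"<tp_rank>": {"<layer_idx>": <scale>}}}}
        with open(path) as f:
            per_rank = json.load(f)["kv_cache"]["scaling_factor"][str(tp_rank)]
        for layer_idx, scale in per_rank.items():
            yield int(layer_idx), float(scale)
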
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index e1688df01..d606e52f8 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    # If this function is called, it should always initialize the KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave the KV cache scale factors in a known good (dummy) state.
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+
 
 class LlamaForCausalLM(nn.Module):
@@ -534,6 +562,9 @@ class LlamaForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
 
 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 66739652a..be85a3670 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -32,6 +32,7 @@ from sglang.srt.utils import (
     is_hip,
     is_ipv6,
     is_port_available,
+    nullable_str,
 )
 
 logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
+    quantization_param_path: Optional[str] = None
     quantization: Optional[str] = None
     context_length: Optional[int] = None
     device: str = "cuda"
@@ -350,8 +352,17 @@ class ServerArgs:
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
-            choices=["auto", "fp8_e5m2"],
-            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
+            help='Data type for kv cache storage. "auto" will use the model data type. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
         )
+        parser.add_argument(
+            "--quantization-param-path",
+            type=nullable_str,
+            default=None,
+            help="Path to the JSON file containing the KV cache "
+            "scaling factors. This should generally be supplied when "
+            "the KV cache dtype is FP8. Otherwise, KV cache scaling "
+            "factors default to 1.0, which may cause accuracy issues.",
+        )
         parser.add_argument(
             "--quantization",
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index b07f6f01d..af9bdd60b 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1375,3 +1375,9 @@ def debug_timing(func):
         return func(*args, **kwargs)
 
     return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
diff --git a/test/srt/kv_cache_scales_llama3_1_8b.json b/test/srt/kv_cache_scales_llama3_1_8b.json
new file mode 100644
index 000000000..3e890e50e
--- /dev/null
+++ b/test/srt/kv_cache_scales_llama3_1_8b.json
@@ -0,0 +1,42 @@
+{
+    "model_type": "llama",
+    "kv_cache": {
+        "dtype": "float8_e4m3fn",
+        "scaling_factor": {
+            "0": {
+                "0": 1,
+                "1": 1,
+                "2": 1,
+                "3": 1,
+                "4": 1,
+                "5": 1,
+                "6": 1,
+                "7": 1,
+                "8": 1,
+                "9": 1,
+                "10": 1,
+                "11": 1,
+                "12": 1,
+                "13": 1,
+                "14": 1,
+                "15": 1,
+                "16": 1,
+                "17": 1,
+                "18": 1,
+                "19": 1,
+                "20": 1,
+                "21": 1,
+                "22": 1,
+                "23": 1,
+                "24": 1,
+                "25": 1,
+                "26": 1,
+                "27": 1,
+                "28": 1,
+                "29": 1,
+                "30": 1,
+                "31": 1
+            }
+        }
+    }
+}
diff --git a/test/srt/test_fp8_kvcache.py b/test/srt/test_fp8_kvcache.py
new file mode 100644
index 000000000..0d6602997
--- /dev/null
+++ b/test/srt/test_fp8_kvcache.py
@@ -0,0 +1,64 @@
+import os
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestFp8Kvcache(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        dirpath = os.path.dirname(__file__)
+        config_file = os.path.join(dirpath, "kv_cache_scales_llama3_1_8b.json")
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--kv-cache-dtype",
+                "fp8_e4m3",
+                "--quantization-param-path",
+                config_file,
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mgsm_en(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=None,
+            num_threads=1024,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreater(metrics["score"], 0.835)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+
+if __name__ == "__main__":
+    unittest.main()