From a1c1ebe9357418ae8cf8ffdf66e3eaec066170e8 Mon Sep 17 00:00:00 2001
From: Yuhong Guo
Date: Wed, 25 Jun 2025 17:14:40 +0800
Subject: [PATCH] Fix FP8 KV Cache Support in FA3 Backend (#7148)

---
 .../attention/flashattention_backend.py       | 38 +++++++++------
 .../sglang/srt/model_executor/model_runner.py |  6 ++-
 test/srt/test_mla_deepseek_v3.py              | 46 ++++++++++++++++++-
 3 files changed, 72 insertions(+), 18 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 9cbdce35c..85899636e 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -657,12 +657,16 @@ class FlashAttentionBackend(AttentionBackend):
         )
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)
+            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         causal = not layer.is_cross_attention

         # Check if we should use local attention
@@ -776,8 +780,8 @@ class FlashAttentionBackend(AttentionBackend):

                 output, lse, *rest = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                     max_seqlen_q=metadata.max_seq_len_q,
@@ -790,8 +794,8 @@ class FlashAttentionBackend(AttentionBackend):
             # MHA for extend part of sequence without attending prefix kv cache
             output, lse, *rest = flash_attn_varlen_func(
                 q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k=metadata.cu_seqlens_q,
                 max_seqlen_q=metadata.max_seq_len_q,
@@ -803,7 +807,9 @@ class FlashAttentionBackend(AttentionBackend):
             return output, lse
         else:
             # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).to(q.dtype)
             k_rope = kv_cache[:, :, layer.v_head_dim :]
             c_kv = kv_cache[:, :, : layer.v_head_dim]
             k_rope_cache = k_rope.view(
@@ -933,14 +939,16 @@ class FlashAttentionBackend(AttentionBackend):

         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-        if self.kv_cache_dtype_str != "auto":
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
                 v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)
-
+            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         if not self.use_mla:
             # Do multi-head attention

@@ -1048,7 +1056,9 @@ class FlashAttentionBackend(AttentionBackend):
                 o = result
         else:
             # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+                q.dtype
+            )
             k_rope = kv_cache[:, :, layer.v_head_dim :]
             c_kv = kv_cache[:, :, : layer.v_head_dim]
             k_rope_cache = k_rope.view(
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index bd6a027d5..21f4b968d 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -239,7 +239,7 @@ class ModelRunner:
             "SGLANG_LOG_EXPERT_LOCATION_METADATA"
         ):
             logger.info(
-                f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
             )

         set_global_expert_distribution_recorder(
@@ -866,7 +866,9 @@ class ModelRunner:
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
-            if is_cuda():
+            if _is_hip:  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e4m3fnuz
+            else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
diff --git a/test/srt/test_mla_deepseek_v3.py b/test/srt/test_mla_deepseek_v3.py
index 1990b0857..c2d659a78 100644
--- a/test/srt/test_mla_deepseek_v3.py
+++ b/test/srt/test_mla_deepseek_v3.py
@@ -4,7 +4,7 @@ from types import SimpleNamespace
 import requests
 import torch

-from sglang.srt.utils import kill_process_tree
+from sglang.srt.utils import is_cuda, is_hip, kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -20,7 +20,7 @@ class TestMLADeepseekV3(CustomTestCase):
         cls.model = "lmsys/sglang-ci-dsv3-test"
         cls.base_url = DEFAULT_URL_FOR_TEST
         other_args = ["--trust-remote-code", "--chunked-prefill-size", "256"]
-        if torch.cuda.is_available() and torch.version.cuda:
+        if is_cuda():
             other_args.extend(["--enable-torch-compile", "--cuda-graph-max-bs", "2"])
         cls.process = popen_launch_server(
             cls.model,
@@ -49,6 +49,48 @@ class TestMLADeepseekV3(CustomTestCase):
         self.assertGreater(metrics["accuracy"], 0.62)


+@unittest.skipIf(is_hip(), "FA is not available.")
+class TestMLADeepseekV3Fa3Fp8Kvcache(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "lmsys/sglang-ci-dsv3-test"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = [
+            "--trust-remote-code",
+            "--chunked-prefill-size",
+            "256",
+            "--kv-cache-dtype",
+            "fp8_e4m3",
+        ]
+        if is_cuda():
+            other_args.extend(["--attention-backend", "fa3"])
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
 class TestDeepseekV3MTP(CustomTestCase):
     @classmethod
     def setUpClass(cls):