From 9db8025376b201d5b29e25abaedd742e7d5615c7 Mon Sep 17 00:00:00 2001 From: Rain Jiang <96632942+rainj-me@users.noreply.github.com> Date: Mon, 1 Sep 2025 12:17:12 -0700 Subject: [PATCH] support fp8 kvcache for hybrid attn backend on GPT-OSS (#9783) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/sglang/srt/models/gpt_oss.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 27b49f4ec..64efff14b 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module): return ans -def _enable_fused_set_kv_buffer(): - return _is_cuda +def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch): + """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache.""" + return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16 # TODO maybe move to a model-common utils @@ -341,7 +342,7 @@ class GptOssAttention(nn.Module): layer=self.attn, forward_batch=forward_batch, ) - if _enable_fused_set_kv_buffer() + if _enable_fused_set_kv_buffer(forward_batch) else None ), ) @@ -355,7 +356,7 @@ class GptOssAttention(nn.Module): attn_output = self.attn( *inner_state, sinks=self.sinks, - save_kv_cache=not _enable_fused_set_kv_buffer(), + save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch), ) output, _ = self.o_proj(attn_output) return output