diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index 27b49f4ec..64efff14b 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans
 
 
-def _enable_fused_set_kv_buffer():
-    return _is_cuda
+def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
 
 
 # TODO maybe move to a model-common utils
@@ -341,7 +342,7 @@ class GptOssAttention(nn.Module):
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer()
+                if _enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -355,7 +356,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(),
+            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output