support fp8 kvcache for hybrid attn backend on GPT-OSS (#9783)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Rain Jiang
2025-09-01 12:17:12 -07:00
committed by GitHub
parent 598c0bc19d
commit 9db8025376

View File

@@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
return ans
def _enable_fused_set_kv_buffer():
return _is_cuda
def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
# TODO maybe move to a model-common utils
@@ -341,7 +342,7 @@ class GptOssAttention(nn.Module):
layer=self.attn,
forward_batch=forward_batch,
)
if _enable_fused_set_kv_buffer()
if _enable_fused_set_kv_buffer(forward_batch)
else None
),
)
@@ -355,7 +356,7 @@ class GptOssAttention(nn.Module):
attn_output = self.attn(
*inner_state,
sinks=self.sinks,
save_kv_cache=not _enable_fused_set_kv_buffer(),
save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
)
output, _ = self.o_proj(attn_output)
return output