support fp8 kvcache for hybrid attn backend on GPT-OSS (#9783)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans


-def _enable_fused_set_kv_buffer():
-    return _is_cuda
+def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16


 # TODO maybe move to a model-common utils
@@ -341,7 +342,7 @@ class GptOssAttention(nn.Module):
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer()
+                if _enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -355,7 +356,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(),
+            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
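For context, a minimal self-contained sketch of the dtype gate this change introduces: the fused set_kv_buffer path is taken only when the KV cache is stored as bfloat16, so an fp8 KV cache falls back to letting the attention backend save KV itself via save_kv_cache=True. The SimpleNamespace stand-ins for ForwardBatch and token_to_kv_pool below are illustrative placeholders, not the real SGLang types.

import torch
from types import SimpleNamespace

_is_cuda = torch.cuda.is_available()

def _enable_fused_set_kv_buffer(forward_batch) -> bool:
    # The fused path writes KV inside the rotary-embedding kernel in bf16,
    # so it must be skipped when the cache is stored in another dtype (e.g. fp8).
    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16

# Toy stand-ins, just to show how the gate routes the two cache dtypes:
bf16_batch = SimpleNamespace(token_to_kv_pool=SimpleNamespace(dtype=torch.bfloat16))
fp8_batch = SimpleNamespace(token_to_kv_pool=SimpleNamespace(dtype=torch.float8_e4m3fn))

print(_enable_fused_set_kv_buffer(bf16_batch))  # True on CUDA: fused KV write
print(_enable_fused_set_kv_buffer(fp8_batch))   # False: backend saves KV itself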