Support FP8 KV cache for the hybrid attention backend on GPT-OSS (#9783)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
|
|||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
def _enable_fused_set_kv_buffer():
|
def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
|
||||||
return _is_cuda
|
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
|
||||||
|
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
|
||||||
|
|
||||||
|
|
||||||
# TODO maybe move to a model-common utils
|
# TODO maybe move to a model-common utils
|
||||||
@@ -341,7 +342,7 @@ class GptOssAttention(nn.Module):
|
|||||||
layer=self.attn,
|
layer=self.attn,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
if _enable_fused_set_kv_buffer()
|
if _enable_fused_set_kv_buffer(forward_batch)
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -355,7 +356,7 @@ class GptOssAttention(nn.Module):
|
|||||||
attn_output = self.attn(
|
attn_output = self.attn(
|
||||||
*inner_state,
|
*inner_state,
|
||||||
sinks=self.sinks,
|
sinks=self.sinks,
|
||||||
save_kv_cache=not _enable_fused_set_kv_buffer(),
|
save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
|
||||||
)
|
)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|||||||
Reference in New Issue
Block a user