[CPU] Optimize FP16 decode_attention_cpu (#10652)

This commit is contained in:
blzheng
2025-10-23 12:39:51 +08:00
committed by GitHub
parent 81fd2b0ee0
commit 13fb8b5489
4 changed files with 181 additions and 9 deletions

View File

@@ -59,8 +59,7 @@ class TestDecodeAttention(CustomTestCase):
         return output

-    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
-        dtype = torch.bfloat16
+    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, dtype, device):
         # This represents the number of tokens already in the sequence
         seq_len = 1024
         total_tokens = B * seq_len
@@ -158,9 +157,10 @@ class TestDecodeAttention(CustomTestCase):
         ]

         for B, H_Q, H_KV, D, D_V in configs:
-            self._test_grouped_decode_attention_once(
-                B, H_Q, H_KV, D, D_V, device=device
-            )
+            for dtype in [torch.bfloat16, torch.float16]:
+                self._test_grouped_decode_attention_once(
+                    B, H_Q, H_KV, D, D_V, dtype=dtype, device=device
+                )

     def test_grouped_decode_attention(self):
         self._test_grouped_decode_attention("cpu")