[CPU] Optimize FP16 decode_attention_cpu (#10652)
This commit is contained in:
@@ -59,8 +59,7 @@ class TestDecodeAttention(CustomTestCase):
|
||||
|
||||
return output
|
||||
|
||||
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
|
||||
dtype = torch.bfloat16
|
||||
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, dtype, device):
|
||||
# This represents the number of tokens already in the sequence
|
||||
seq_len = 1024
|
||||
total_tokens = B * seq_len
|
||||
@@ -158,9 +157,10 @@ class TestDecodeAttention(CustomTestCase):
|
||||
]
|
||||
|
||||
for B, H_Q, H_KV, D, D_V in configs:
|
||||
self._test_grouped_decode_attention_once(
|
||||
B, H_Q, H_KV, D, D_V, device=device
|
||||
)
|
||||
for dtype in [torch.bfloat16, torch.float16]:
|
||||
self._test_grouped_decode_attention_once(
|
||||
B, H_Q, H_KV, D, D_V, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def test_grouped_decode_attention(self):
|
||||
self._test_grouped_decode_attention("cpu")
|
||||
|
||||
Reference in New Issue
Block a user