[Fix] Fix logprob and normalized_logprob (#1428)

This commit is contained in:
Lianmin Zheng
2024-09-15 06:36:06 -07:00
committed by GitHub
parent 282681b8a1
commit 9ba1f09760
22 changed files with 314 additions and 215 deletions

View File

@@ -127,7 +127,6 @@ class TestExtendAttention(unittest.TestCase):
def _test_context_attention_once(self, head_dim):
# Set up a simple test case
batch_size = 2
num_heads = 4
seq_lens = [8, 12]
max_seq_len = max(seq_lens)