Add main for merge state tests (#6492)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -136,3 +136,7 @@ def test_merge_state(seq_len, num_heads, head_dim):
|
|||||||
|
|
||||||
assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
|
assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
|
||||||
assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
|
assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
|
|||||||
@@ -394,3 +394,7 @@ def test_merge_attn_states(
|
|||||||
len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
|
len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
|
||||||
):
|
):
|
||||||
generate_markdown_table()
|
generate_markdown_table()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
|
|||||||
Reference in New Issue
Block a user