[kernel] Fix position ids in rope (#3173)
This commit is contained in:
@@ -196,3 +196,7 @@ def test_correctness(
|
||||
query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user