[kernel] Fix position ids in rope (#3173)
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "sgl-kernel"
|
name = "sgl-kernel"
|
||||||
version = "0.0.2.post19"
|
version = "0.0.2.post20"
|
||||||
description = "Kernel Library for SGLang"
|
description = "Kernel Library for SGLang"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def apply_rope_with_cos_sin_cache_inplace(
|
|||||||
raise ValueError("cos_sin_cache should be float32")
|
raise ValueError("cos_sin_cache should be float32")
|
||||||
|
|
||||||
with query.device as device:
|
with query.device as device:
|
||||||
pos_ids = pos_ids.int()
|
positions = positions.int()
|
||||||
torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
|
torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
|
||||||
q=query.view(query.shape[0], -1, head_size),
|
q=query.view(query.shape[0], -1, head_size),
|
||||||
k=key.view(key.shape[0], -1, head_size),
|
k=key.view(key.shape[0], -1, head_size),
|
||||||
|
|||||||
@@ -196,3 +196,7 @@ def test_correctness(
|
|||||||
query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
|
query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
|
||||||
)
|
)
|
||||||
torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)
|
torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = "0.0.2.post19"
|
__version__ = "0.0.2.post20"
|
||||||
|
|||||||
Reference in New Issue
Block a user