diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu
index 49565f6f0..4274acf43 100644
--- a/sgl-kernel/csrc/elementwise/rope.cu
+++ b/sgl-kernel/csrc/elementwise/rope.cu
@@ -65,7 +65,7 @@ void apply_rope_pos_ids_cos_sin_cache(
       static_cast<c_type*>(q_rope.data_ptr()),
       static_cast<c_type*>(k_rope.data_ptr()),
       static_cast<float*>(cos_sin_cache.data_ptr()),
-      static_cast<int32_t*>(pos_ids.data_ptr()),
+      static_cast<int64_t*>(pos_ids.data_ptr()),
       nnz,
       num_qo_heads,
       num_kv_heads,
diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py
index fc6d8ea00..9e0b11c2f 100644
--- a/sgl-kernel/python/sgl_kernel/elementwise.py
+++ b/sgl-kernel/python/sgl_kernel/elementwise.py
@@ -139,14 +139,13 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")
 
-    positions = positions.int()
     torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
         q=query.view(query.shape[0], -1, head_size),
         k=key.view(key.shape[0], -1, head_size),
         q_rope=query.view(query.shape[0], -1, head_size),
         k_rope=key.view(key.shape[0], -1, head_size),
         cos_sin_cache=cos_sin_cache,
-        pos_ids=positions,
+        pos_ids=positions.long(),
         interleave=(not is_neox),
         cuda_stream=get_cuda_stream(),
     )