From cf0ccd406e38a63bdb984578ba742ca3c8ab81b8 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 10 Mar 2025 10:07:45 -0700 Subject: [PATCH] Optimize rope in sgl kernel (#4267) --- sgl-kernel/csrc/elementwise/rope.cu | 2 +- sgl-kernel/python/sgl_kernel/elementwise.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu index 49565f6f0..4274acf43 100644 --- a/sgl-kernel/csrc/elementwise/rope.cu +++ b/sgl-kernel/csrc/elementwise/rope.cu @@ -65,7 +65,7 @@ void apply_rope_pos_ids_cos_sin_cache( static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), static_cast(cos_sin_cache.data_ptr()), - static_cast(pos_ids.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index fc6d8ea00..9e0b11c2f 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -139,14 +139,13 @@ def apply_rope_with_cos_sin_cache_inplace( if cos_sin_cache.dtype != torch.float32: raise ValueError("cos_sin_cache should be float32") - positions = positions.int() torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache( q=query.view(query.shape[0], -1, head_size), k=key.view(key.shape[0], -1, head_size), q_rope=query.view(query.shape[0], -1, head_size), k_rope=key.view(key.shape[0], -1, head_size), cos_sin_cache=cos_sin_cache, - pos_ids=positions, + pos_ids=positions.long(), interleave=(not is_neox), cuda_stream=get_cuda_stream(), )