Optimize rope in sgl kernel (#4267)

This commit is contained in:
Lianmin Zheng
2025-03-10 10:07:45 -07:00
committed by GitHub
parent 3d56585a97
commit cf0ccd406e
2 changed files with 2 additions and 3 deletions

View File

@@ -139,14 +139,13 @@ def apply_rope_with_cos_sin_cache_inplace(
if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32")
positions = positions.int()
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
q=query.view(query.shape[0], -1, head_size),
k=key.view(key.shape[0], -1, head_size),
q_rope=query.view(query.shape[0], -1, head_size),
k_rope=key.view(key.shape[0], -1, head_size),
cos_sin_cache=cos_sin_cache,
pos_ids=positions,
pos_ids=positions.long(),
interleave=(not is_neox),
cuda_stream=get_cuda_stream(),
)