From 07440f5f349ef6c4b216e5aa6ebd0827ba9ee2ee Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 28 Sep 2025 11:17:27 -0700 Subject: [PATCH] Fix FusedSetKVBufferArg in RotaryEmbedding (#11003) --- python/sglang/srt/layers/rotary_embedding.py | 31 +++++++++++++++---- .../sgl_kernel/testing/rotary_embedding.py | 10 ++++-- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index cc4e21b58..2c7267529 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -27,7 +27,10 @@ _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() if _is_cuda: - from sgl_kernel import apply_rope_with_cos_sin_cache_inplace + from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace +else: + FusedSetKVBufferArg = None + if _use_aiter: from aiter.rotary_embedding import get_rope as aiter_get_rope @@ -146,8 +149,13 @@ class RotaryEmbedding(CustomOp): query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """A PyTorch-native implementation of forward().""" + assert ( + fused_set_kv_buffer_arg is None + ), "fused_set_kv_buffer_arg is not supported for native implementation" + if offsets is not None: positions = positions + offsets positions = positions.flatten() @@ -176,12 +184,17 @@ class RotaryEmbedding(CustomOp): query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """A PyTorch-npu implementation of forward().""" - import os + assert ( + fused_set_kv_buffer_arg is None + ), "fused_set_kv_buffer_arg is not supported for npu implementation" if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"): - return self.forward_native(positions, query, key, offsets) + return self.forward_native( + positions, query, key, offsets, fused_set_kv_buffer_arg + ) else: rotary_mode = "half" if self.is_neox_style: @@ -206,8 +219,12 @@ class RotaryEmbedding(CustomOp): query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, - fused_set_kv_buffer_arg=None, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + assert ( + fused_set_kv_buffer_arg is None + ), "fused_set_kv_buffer_arg is not supported for cpu implementation" + positions = torch.add(positions, offsets) if offsets is not None else positions if _is_cpu_amx_available: return torch.ops.sgl_kernel.rotary_embedding_cpu( @@ -219,7 +236,9 @@ class RotaryEmbedding(CustomOp): self.is_neox_style, ) else: - return self.forward_native(positions, query, key, offsets) + return self.forward_native( + positions, query, key, offsets, fused_set_kv_buffer_arg + ) def forward_cuda( self, @@ -227,7 +246,7 @@ class RotaryEmbedding(CustomOp): query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, - fused_set_kv_buffer_arg=None, # Optional[FusedSetKVBufferArg] + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: if _is_cuda and (self.head_size in [64, 128, 256, 512]): apply_rope_with_cos_sin_cache_inplace( diff --git a/sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py b/sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py index e26208048..f1506479b 100644 --- a/sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py +++ b/sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py @@ -1,6 +1,5 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Optional, Tuple, Union -import pytest import torch from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace @@ -84,8 +83,13 @@ class RotaryEmbedding(torch.nn.Module): query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """A PyTorch-native implementation of forward().""" + assert ( + fused_set_kv_buffer_arg is None + ), "fused_set_kv_buffer_arg is not supported for native implementation" + if offsets is not None: positions = positions + offsets @@ -125,8 +129,8 @@ class FlashInferRotaryEmbedding(RotaryEmbedding): positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, offsets: Optional[torch.Tensor] = None, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: apply_rope_with_cos_sin_cache_inplace(