From 74de76c685b71d7d0d30fab8bb1adaa110935cc9 Mon Sep 17 00:00:00 2001
From: Yuan Luo
Date: Tue, 21 Oct 2025 10:38:29 +0800
Subject: [PATCH] Revise MRotaryEmbedding's forward (#11859)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: luoyuan.luo
Co-authored-by: 羽癫
Co-authored-by: b8zhong
---
 python/sglang/srt/layers/rotary_embedding.py | 23 +++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 6ce25651f..cfc626e1d 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -1280,7 +1280,7 @@ class MRotaryEmbedding(RotaryEmbedding):
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
 
     @torch.compile(dynamic=True, backend=get_compiler_backend())
-    def forward_native(
+    def _forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -1340,6 +1340,27 @@ class MRotaryEmbedding(RotaryEmbedding):
         query: torch.Tensor,
         key: torch.Tensor,
         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass with optional Triton kernel acceleration.
+        Args:
+            positions:
+                [num_tokens,] (text only) or
+                [3, num_tokens] (T/H/W positions with multimodal inputs)
+            query: [num_tokens, num_heads * head_size]
+            key: [num_tokens, num_kv_heads * head_size]
+        """
+        assert positions.ndim == 1 or positions.ndim == 2
+
+        if positions.ndim == 2 and self.mrope_section and _is_cuda:
+            return self._forward_triton(positions, query, key)
+        else:
+            return self._forward_native(positions, query, key)
+
+    def _forward_triton(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None