diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 6ce25651f..cfc626e1d 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -1280,7 +1280,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
 
     @torch.compile(dynamic=True, backend=get_compiler_backend())
-    def forward_native(
+    def _forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -1340,6 +1340,27 @@ class MRotaryEmbedding(RotaryEmbedding):
         query: torch.Tensor,
         key: torch.Tensor,
         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass with optional Triton kernel acceleration.
+        Args:
+            positions:
+                [num_tokens,] (text only) or
+                [3, num_tokens] (T/H/W positions with multimodal inputs)
+            query: [num_tokens, num_heads * head_size]
+            key: [num_tokens, num_kv_heads * head_size]
+        """
+        assert positions.ndim == 1 or positions.ndim == 2
+
+        if positions.ndim == 2 and self.mrope_section and _is_cuda:
+            return self._forward_triton(positions, query, key)
+        else:
+            return self._forward_native(positions, query, key)
+
+    def _forward_triton(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
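
For reference, here is a minimal standalone sketch of the dispatch that the new `forward()` performs. The helper bodies are identity placeholders rather than the real rotary math, and `_is_cuda` below simply wraps `torch.cuda.is_available()` as a stand-in for sglang's module-level flag; only the branch condition and the tensor shape conventions come from the diff above, and the `mrope_section` value in the usage example is illustrative.

```python
from typing import List, Optional, Tuple

import torch

_is_cuda = torch.cuda.is_available()  # stand-in for sglang's _is_cuda flag


def _forward_triton(positions, query, key):
    # Placeholder for the fused Triton mrope kernel.
    return query, key


def _forward_native(positions, query, key):
    # Placeholder for the torch.compile'd native implementation.
    return query, key


def forward(
    positions: torch.Tensor,
    query: torch.Tensor,  # [num_tokens, num_heads * head_size]
    key: torch.Tensor,    # [num_tokens, num_kv_heads * head_size]
    mrope_section: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert positions.ndim == 1 or positions.ndim == 2
    # Multimodal requests carry stacked T/H/W positions ([3, num_tokens]);
    # those take the Triton path on CUDA, everything else falls back to
    # the native path.
    if positions.ndim == 2 and mrope_section and _is_cuda:
        return _forward_triton(positions, query, key)
    return _forward_native(positions, query, key)


num_tokens, num_heads, num_kv_heads, head_size = 8, 4, 2, 64
q = torch.randn(num_tokens, num_heads * head_size)
k = torch.randn(num_tokens, num_kv_heads * head_size)

# Text-only: 1-D positions take the native path.
forward(torch.arange(num_tokens), q, k)

# Multimodal: stacked T/H/W positions take the Triton path (on CUDA).
forward(torch.arange(num_tokens).expand(3, -1), q, k, mrope_section=[16, 24, 24])
```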