From 5e0a9b0981ff99c15d044e72f225f60eadb2a50f Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Tue, 15 Apr 2025 06:22:43 +0800
Subject: [PATCH] Apply deepseek cuda rope (#5385)

Co-authored-by: Yineng Zhang
---
 python/sglang/srt/layers/rotary_embedding.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index aadaf4e3e..88a491e47 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -645,7 +645,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward(
+    def forward_hip(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        if torch._dynamo.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda_available:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,