Apply deepseek cuda rope (#5385)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
Ke Bao
2025-04-15 06:22:43 +08:00
committed by GitHub
parent bdde237562
commit 5e0a9b0981

View File

@@ -645,7 +645,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
def forward_hip(self, *args, **kwargs):
    """Delegate to ``forward_native``, passing every argument through unchanged.

    NOTE(review): presumably this is the hook selected on ROCm/HIP builds;
    the dispatcher that picks it is not visible here -- confirm.
    """
    native_impl = self.forward_native
    return native_impl(*args, **kwargs)
def forward(self, *args, **kwargs):
if torch._dynamo.is_compiling:
return self.forward_native(*args, **kwargs)
if _is_cuda_available:
return self.forward_cuda(*args, **kwargs)
else:
return self.forward_native(*args, **kwargs)
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,