Fix ut mla-test-1-gpu-amd (#4813)

Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
strgrb
2025-03-27 23:27:51 +08:00
committed by GitHub
parent 886fcbdd09
commit 668ecc6c5b
2 changed files with 13 additions and 0 deletions

View File

@@ -645,6 +645,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if _is_cuda_available:
return self.forward_cuda(positions, query, key, offsets)
else:
return self.forward_native(positions, query, key, offsets)
def forward_native(
self,
positions: torch.Tensor,