From 668ecc6c5b375249578b83dbabfb47c3ed5d9dbd Mon Sep 17 00:00:00 2001
From: strgrb
Date: Thu, 27 Mar 2025 23:27:51 +0800
Subject: [PATCH] Fix ut mla-test-1-gpu-amd (#4813)

Co-authored-by: Zhang Kaihong
---
 .github/workflows/pr-test-amd.yml            |  1 +
 python/sglang/srt/layers/rotary_embedding.py | 12 ++++++++++++
 2 files changed, 13 insertions(+)

diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml
index a510a5ded..074b855e6 100644
--- a/.github/workflows/pr-test-amd.yml
+++ b/.github/workflows/pr-test-amd.yml
@@ -89,6 +89,7 @@ jobs:
           docker exec ci_sglang pip uninstall sgl-kernel -y || true
           docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"
           docker exec ci_sglang pip install -e "python[dev_hip]"
+          docker exec ci_sglang pip install py-spy || true
 
           docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git
           docker exec -w /human-eval ci_sglang pip install -e .
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 648b97cf0..319e1c04d 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -645,6 +645,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if _is_cuda_available:
+            return self.forward_cuda(positions, query, key, offsets)
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
     def forward_native(
         self,
         positions: torch.Tensor,