From 8ae9d4bb41d47fab09d6cd5ee5fb4c513be32c07 Mon Sep 17 00:00:00 2001
From: b8zhong
Date: Thu, 23 Oct 2025 12:42:59 -0700
Subject: [PATCH] Revert "[ROCm] Remove vLLM rope dependency & use AITER impl" (#12028)

Removes the AITER-based HIP rotary-embedding path: the forward_hip
override and the cached _hip_cos_cache/_hip_sin_cache buffers in
RotaryEmbedding, plus the ROCm equivalence tests that covered them.

---
 python/sglang/srt/layers/rotary_embedding.py | 120 ----------
 test/srt/test_rope_rocm.py                   | 233 -------------------
 2 files changed, 353 deletions(-)

diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index fb7d9c12b..2c8181ebe 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -124,23 +124,6 @@ class RotaryEmbedding(CustomOp):
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
-        self._hip_cached_cos: Optional[torch.Tensor] = None
-        self._hip_cached_sin: Optional[torch.Tensor] = None
-        if _use_aiter:
-            half_rotary = cache.shape[-1] // 2
-            cos_cache = (
-                cache[:, :half_rotary]
-                .contiguous()
-                .view(self.max_position_embeddings, 1, 1, half_rotary)
-            )
-            sin_cache = (
-                cache[:, half_rotary:]
-                .contiguous()
-                .view(self.max_position_embeddings, 1, 1, half_rotary)
-            )
-            self.register_buffer("_hip_cos_cache", cos_cache, persistent=False)
-            self.register_buffer("_hip_sin_cache", sin_cache, persistent=False)
-
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
@@ -201,109 +184,6 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
-    def forward_hip(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor],
-        offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg: Optional["FusedSetKVBufferArg"] = None,
-        *,
-        is_nope_first: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if not _use_aiter:
-            return self.forward_native(
-                positions, query, key, offsets, fused_set_kv_buffer_arg
-            )
-
-        if fused_set_kv_buffer_arg is not None:
-            raise NotImplementedError(
-                "fused_set_kv_buffer_arg is not supported for HIP path"
-            )
-
-        import aiter as ops
-
-        if not hasattr(self, "_hip_cos_cache") or not hasattr(self, "_hip_sin_cache"):
-            raise RuntimeError("HIP caches not initialised")
-
-        cos = self._hip_cached_cos
-        sin = self._hip_cached_sin
-        if cos is None or cos.device != query.device or cos.dtype != query.dtype:
-            cos = self._hip_cos_cache.to(query.device, dtype=query.dtype)
-            sin = self._hip_sin_cache.to(query.device, dtype=query.dtype)
-            self._hip_cached_cos = cos
-            self._hip_cached_sin = sin
-
-        rotate_style = 0 if self.is_neox_style else 1
-
-        num_tokens = positions.numel()
-        query_shape = query.shape
-        query = query.view(1, num_tokens, -1, self.head_size)
-        key_shape = key.shape if key is not None else None
-        if key is not None:
-            key = key.view(1, num_tokens, -1, self.head_size)
-
-        positions = positions.view(*query.shape[:2])
-        if offsets is not None:
-            offsets = offsets.view(*query.shape[:2])
-
-        if not is_nope_first:
-            query_rot = query[..., : self.rotary_dim]
-            key_rot = key[..., : self.rotary_dim] if key is not None else None
-        else:
-            query_rot = query[..., -self.rotary_dim :]
-            key_rot = key[..., -self.rotary_dim :] if key is not None else None
-
-        if key_rot is None:
-            if offsets is None:
-                ops.rope_cached_positions_fwd_inplace(
-                    query_rot,
-                    cos,
-                    sin,
-                    positions,
-                    rotate_style,
-                    reuse_freqs_front_part=True,
-                    nope_first=is_nope_first,
-                )
-            else:
-                ops.rope_cached_positions_offsets_fwd_inplace(
-                    query_rot,
-                    cos,
-                    sin,
-                    positions,
-                    offsets,
-                    rotate_style,
-                    reuse_freqs_front_part=True,
-                    nope_first=is_nope_first,
-                )
-            return query.view(query_shape), None
-
-        if offsets is None:
-            ops.rope_cached_positions_2c_fwd_inplace(
-                query_rot,
-                key_rot,
-                cos,
-                sin,
-                positions,
-                rotate_style,
-                reuse_freqs_front_part=True,
-                nope_first=is_nope_first,
-            )
-        else:
-            ops.rope_cached_positions_offsets_2c_fwd_inplace(
-                query_rot,
-                key_rot,
-                cos,
-                sin,
-                positions,
-                offsets,
-                rotate_style,
-                reuse_freqs_front_part=True,
-                nope_first=is_nope_first,
-            )
-
-        return query.view(query_shape), key.view(key_shape) if key is not None else None
-
     def forward_npu(
         self,
         positions: torch.Tensor,
diff --git a/test/srt/test_rope_rocm.py b/test/srt/test_rope_rocm.py
index a6ee7adfe..5850e7061 100644
--- a/test/srt/test_rope_rocm.py
+++ b/test/srt/test_rope_rocm.py
@@ -111,239 +111,6 @@ class TestRotaryEmbeddingAITer(CustomTestCase):
             with self.subTest(case=case):
                 self._run_case_aiter(*case)
 
-    def test_ops_equivalence_basic(self) -> None:
-        import aiter as ops
-        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding
-
-        (
-            head_size,
-            rotary_dim,
-            max_pos,
-            base,
-            is_neox,
-            dtype,
-            device,
-            bs,
-            seq_len,
-            num_q,
-            num_kv,
-        ) = (
-            128,
-            64,
-            2048,
-            10000,
-            True,
-            torch.bfloat16,
-            "cuda",
-            2,
-            32,
-            4,
-            2,
-        )
-
-        rope = AiterRotaryEmbedding(
-            head_size, rotary_dim, max_pos, base, is_neox, dtype
-        ).to(device)
-
-        positions = torch.arange(seq_len, device=device).repeat(bs)
-        num_tokens = positions.numel()
-
-        q2d = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
-        k2d = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)
-
-        q_ref, k_ref = rope.forward_hip(positions.clone(), q2d.clone(), k2d.clone())
-
-        q_sbhd = q2d.view(1, num_tokens, num_q, head_size)
-        k_sbhd = k2d.view(1, num_tokens, num_kv, head_size)
-
-        cos = rope.cos_cache.to(device=device, dtype=dtype)
-        sin = rope.sin_cache.to(device=device, dtype=dtype)
-        pos_b_s = positions.view(1, num_tokens)
-        rotate_style = 0 if is_neox else 1
-        ops.rope_cached_positions_2c_fwd_inplace(
-            q_sbhd,
-            k_sbhd,
-            cos,
-            sin,
-            pos_b_s,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=False,
-        )
-
-        self.assertTrue(q_ref.shape == q2d.shape)
-        self.assertTrue(k_ref.shape == k2d.shape)
-        torch.testing.assert_close(q_ref, q_sbhd.view_as(q2d), atol=1e-2, rtol=1e-2)
-        torch.testing.assert_close(k_ref, k_sbhd.view_as(k2d), atol=1e-2, rtol=1e-2)
-
-    def test_ops_equivalence_nope_first(self) -> None:
-        import aiter as ops
-        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding
-
-        (
-            head_size,
-            rotary_dim,
-            max_pos,
-            base,
-            is_neox,
-            dtype,
-            device,
-            bs,
-            seq_len,
-            num_q,
-            num_kv,
-        ) = (
-            128,
-            64,
-            2048,
-            10000,
-            True,
-            torch.bfloat16,
-            "cuda",
-            1,
-            16,
-            2,
-            2,
-        )
-
-        rope = AiterRotaryEmbedding(
-            head_size, rotary_dim, max_pos, base, is_neox, dtype
-        ).to(device)
-
-        positions = torch.arange(seq_len, device=device).repeat(bs)
-        num_tokens = positions.numel()
-
-        q2d = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
-        k2d = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)
-
-        q_ref, k_ref = rope.forward_hip(
-            positions.clone(), q2d.clone(), k2d.clone(), is_nope_first=True
-        )
-
-        q_sbhd = q2d.view(1, num_tokens, num_q, head_size)
-        k_sbhd = k2d.view(1, num_tokens, num_kv, head_size)
-
-        cos = rope.cos_cache.to(device=device, dtype=dtype)
-        sin = rope.sin_cache.to(device=device, dtype=dtype)
-        pos_b_s = positions.view(1, num_tokens)
-        rotate_style = 0 if is_neox else 1
-
-        q_rot = q_sbhd[..., -rotary_dim:]
-        k_rot = k_sbhd[..., -rotary_dim:]
-        ops.rope_cached_positions_2c_fwd_inplace(
-            q_rot,
-            k_rot,
-            cos,
-            sin,
-            pos_b_s,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=True,
-        )
-
-        torch.testing.assert_close(q_ref, q_sbhd.view_as(q2d), atol=1e-2, rtol=1e-2)
-        torch.testing.assert_close(k_ref, k_sbhd.view_as(k2d), atol=1e-2, rtol=1e-2)
-
-    def test_sglang_rotary_embedding_forward_hip_matches_native(self) -> None:
-        from sglang.srt.layers.rotary_embedding import (
-            RotaryEmbedding as SglRotaryEmbedding,
-        )
-
-        (
-            head_size,
-            rotary_dim,
-            max_pos,
-            base,
-            is_neox,
-            dtype,
-            device,
-            bs,
-            seq_len,
-            num_q,
-            num_kv,
-        ) = (
-            128,
-            64,
-            2048,
-            10000,
-            True,
-            torch.bfloat16,
-            "cuda",
-            2,
-            64,
-            4,
-            2,
-        )
-
-        rope = SglRotaryEmbedding(
-            head_size, rotary_dim, max_pos, base, is_neox, dtype
-        ).to(device)
-
-        positions = torch.arange(seq_len, device=device).repeat(bs)
-        q = torch.randn(bs * seq_len, num_q * head_size, dtype=dtype, device=device)
-        k = torch.randn(bs * seq_len, num_kv * head_size, dtype=dtype, device=device)
-
-        q_ref, k_ref = rope.forward_native(positions.clone(), q.clone(), k.clone())
-        q_hip, k_hip = rope.forward_hip(positions.clone(), q.clone(), k.clone())
-
-        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
-        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)
-
-    def test_llama3_rotary_embedding_forward_hip_matches_native(self) -> None:
-        from sglang.srt.layers.rotary_embedding import get_rope as sgl_get_rope
-
-        (
-            head_size,
-            rotary_dim,
-            max_pos,
-            base,
-            is_neox,
-            dtype,
-            device,
-            bs,
-            seq_len,
-            num_q,
-            num_kv,
-        ) = (
-            128,
-            128,
-            2048,
-            10000,
-            True,
-            torch.bfloat16,
-            "cuda",
-            2,
-            64,
-            4,
-            2,
-        )
-
-        rope = sgl_get_rope(
-            head_size,
-            rotary_dim,
-            max_pos,
-            base,
-            is_neox,
-            rope_scaling={
-                "rope_type": "llama3",
-                "factor": 1.0,
-                "low_freq_factor": 1.0,
-                "high_freq_factor": 1.0,
-                "original_max_position_embeddings": max_pos,
-            },
-            dtype=dtype,
-        ).to(device)
-
-        positions = torch.arange(seq_len, device=device).repeat(bs)
-        q = torch.randn(bs * seq_len, num_q * head_size, dtype=dtype, device=device)
-        k = torch.randn(bs * seq_len, num_kv * head_size, dtype=dtype, device=device)
-
-        q_ref, k_ref = rope.forward_native(positions.clone(), q.clone(), k.clone())
-        q_hip, k_hip = rope.forward_hip(positions.clone(), q.clone(), k.clone())
-
-        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
-        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)
-
 
 if __name__ == "__main__":
     unittest.main()
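
-- 
Reviewer note (below the signature delimiter, so ignored by git am; not part
of the patch): with the forward_hip override removed, RotaryEmbedding on ROCm
presumably falls back to CustomOp's default dispatch, i.e. the same generic
path used on other platforms. A minimal sketch of the shape/usage contract
the deleted tests exercised, restricted to the native path that survives this
revert. It assumes only the constructor and forward_native signatures visible
in the diff above; it needs no aiter and runs on CPU:

    import torch

    from sglang.srt.layers.rotary_embedding import RotaryEmbedding

    # Same geometry as the deleted test cases (partial rotary: 64 of 128 dims).
    head_size, rotary_dim, max_pos, base = 128, 64, 2048, 10000
    num_q, num_kv, bs, seq_len = 4, 2, 2, 32

    # Positional args mirror the deleted tests:
    # (head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype).
    rope = RotaryEmbedding(head_size, rotary_dim, max_pos, base, True, torch.float32)

    positions = torch.arange(seq_len).repeat(bs)         # (num_tokens,)
    q = torch.randn(bs * seq_len, num_q * head_size)     # (num_tokens, num_heads * head_size)
    k = torch.randn(bs * seq_len, num_kv * head_size)

    # forward_native applies RoPE to the first rotary_dim dims of each head;
    # the remaining "pass" dims are returned unchanged, and output shapes
    # match the inputs.
    q_out, k_out = rope.forward_native(positions, q, k)
    assert q_out.shape == q.shape and k_out.shape == k.shape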