diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 2c8181ebe..fb7d9c12b 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -124,6 +124,23 @@ class RotaryEmbedding(CustomOp):
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
+        self._hip_cached_cos: Optional[torch.Tensor] = None
+        self._hip_cached_sin: Optional[torch.Tensor] = None
+        if _use_aiter:
+            half_rotary = cache.shape[-1] // 2
+            cos_cache = (
+                cache[:, :half_rotary]
+                .contiguous()
+                .view(self.max_position_embeddings, 1, 1, half_rotary)
+            )
+            sin_cache = (
+                cache[:, half_rotary:]
+                .contiguous()
+                .view(self.max_position_embeddings, 1, 1, half_rotary)
+            )
+            self.register_buffer("_hip_cos_cache", cos_cache, persistent=False)
+            self.register_buffer("_hip_sin_cache", sin_cache, persistent=False)
+
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
@@ -184,6 +201,109 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_hip(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor],
+        offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional["FusedSetKVBufferArg"] = None,
+        *,
+        is_nope_first: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if not _use_aiter:
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
+
+        if fused_set_kv_buffer_arg is not None:
+            raise NotImplementedError(
+                "fused_set_kv_buffer_arg is not supported for HIP path"
+            )
+
+        import aiter as ops
+
+        if not hasattr(self, "_hip_cos_cache") or not hasattr(self, "_hip_sin_cache"):
+            raise RuntimeError("HIP caches not initialised")
+
+        cos = self._hip_cached_cos
+        sin = self._hip_cached_sin
+        if cos is None or cos.device != query.device or cos.dtype != query.dtype:
+            cos = self._hip_cos_cache.to(query.device, dtype=query.dtype)
+            sin = self._hip_sin_cache.to(query.device, dtype=query.dtype)
+            self._hip_cached_cos = cos
+            self._hip_cached_sin = sin
+
+        rotate_style = 0 if self.is_neox_style else 1
+
+        num_tokens = positions.numel()
+        query_shape = query.shape
+        query = query.view(1, num_tokens, -1, self.head_size)
+        key_shape = key.shape if key is not None else None
+        if key is not None:
+            key = key.view(1, num_tokens, -1, self.head_size)
+
+        positions = positions.view(*query.shape[:2])
+        if offsets is not None:
+            offsets = offsets.view(*query.shape[:2])
+
+        if not is_nope_first:
+            query_rot = query[..., : self.rotary_dim]
+            key_rot = key[..., : self.rotary_dim] if key is not None else None
+        else:
+            query_rot = query[..., -self.rotary_dim :]
+            key_rot = key[..., -self.rotary_dim :] if key is not None else None
+
+        if key_rot is None:
+            if offsets is None:
+                ops.rope_cached_positions_fwd_inplace(
+                    query_rot,
+                    cos,
+                    sin,
+                    positions,
+                    rotate_style,
+                    reuse_freqs_front_part=True,
+                    nope_first=is_nope_first,
+                )
+            else:
+                ops.rope_cached_positions_offsets_fwd_inplace(
+                    query_rot,
+                    cos,
+                    sin,
+                    positions,
+                    offsets,
+                    rotate_style,
+                    reuse_freqs_front_part=True,
+                    nope_first=is_nope_first,
+                )
+            return query.view(query_shape), None
+
+        if offsets is None:
+            ops.rope_cached_positions_2c_fwd_inplace(
+                query_rot,
+                key_rot,
+                cos,
+                sin,
+                positions,
+                rotate_style,
+                reuse_freqs_front_part=True,
+                nope_first=is_nope_first,
+            )
+        else:
+            ops.rope_cached_positions_offsets_2c_fwd_inplace(
+                query_rot,
+                key_rot,
+                cos,
+                sin,
+                positions,
+                offsets,
+                rotate_style,
+                reuse_freqs_front_part=True,
+                nope_first=is_nope_first,
+            )
+
+        return query.view(query_shape), key.view(key_shape) if key is not None else None
+
     def forward_npu(
         self,
         positions: torch.Tensor,
diff --git a/test/srt/test_rope_rocm.py b/test/srt/test_rope_rocm.py
index 5850e7061..a6ee7adfe 100644
--- a/test/srt/test_rope_rocm.py
+++ b/test/srt/test_rope_rocm.py
@@ -111,6 +111,239 @@ class TestRotaryEmbeddingAITer(CustomTestCase):
             with self.subTest(case=case):
                 self._run_case_aiter(*case)
 
+    def test_ops_equivalence_basic(self) -> None:
+        import aiter as ops
+        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding
+
+        (
+            head_size,
+            rotary_dim,
+            max_pos,
+            base,
+            is_neox,
+            dtype,
+            device,
+            bs,
+            seq_len,
+            num_q,
+            num_kv,
+        ) = (
+            128,
+            64,
+            2048,
+            10000,
+            True,
+            torch.bfloat16,
+            "cuda",
+            2,
+            32,
+            4,
+            2,
+        )
+
+        rope = AiterRotaryEmbedding(
+            head_size, rotary_dim, max_pos, base, is_neox, dtype
+        ).to(device)
+
+        positions = torch.arange(seq_len, device=device).repeat(bs)
+        num_tokens = positions.numel()
+
+        q2d = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
+        k2d = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)
+
+        q_ref, k_ref = rope.forward_hip(positions.clone(), q2d.clone(), k2d.clone())
+
+        q_sbhd = q2d.view(1, num_tokens, num_q, head_size)
+        k_sbhd = k2d.view(1, num_tokens, num_kv, head_size)
+
+        cos = rope.cos_cache.to(device=device, dtype=dtype)
+        sin = rope.sin_cache.to(device=device, dtype=dtype)
+        pos_b_s = positions.view(1, num_tokens)
+        rotate_style = 0 if is_neox else 1
+        ops.rope_cached_positions_2c_fwd_inplace(
+            q_sbhd,
+            k_sbhd,
+            cos,
+            sin,
+            pos_b_s,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=False,
+        )
+
+        self.assertTrue(q_ref.shape == q2d.shape)
+        self.assertTrue(k_ref.shape == k2d.shape)
+        torch.testing.assert_close(q_ref, q_sbhd.view_as(q2d), atol=1e-2, rtol=1e-2)
+        torch.testing.assert_close(k_ref, k_sbhd.view_as(k2d), atol=1e-2, rtol=1e-2)
+
+    def test_ops_equivalence_nope_first(self) -> None:
+        import aiter as ops
+        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding
+
+        (
+            head_size,
+            rotary_dim,
+            max_pos,
+            base,
+            is_neox,
+            dtype,
+            device,
+            bs,
+            seq_len,
+            num_q,
+            num_kv,
+        ) = (
+            128,
+            64,
+            2048,
+            10000,
+            True,
+            torch.bfloat16,
+            "cuda",
+            1,
+            16,
+            2,
+            2,
+        )
+
+        rope = AiterRotaryEmbedding(
+            head_size, rotary_dim, max_pos, base, is_neox, dtype
+        ).to(device)
+
+        positions = torch.arange(seq_len, device=device).repeat(bs)
+        num_tokens = positions.numel()
+
+        q2d = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
+        k2d = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)
+
+        q_ref, k_ref = rope.forward_hip(
+            positions.clone(), q2d.clone(), k2d.clone(), is_nope_first=True
+        )
+
+        q_sbhd = q2d.view(1, num_tokens, num_q, head_size)
+        k_sbhd = k2d.view(1, num_tokens, num_kv, head_size)
+
+        cos = rope.cos_cache.to(device=device, dtype=dtype)
+        sin = rope.sin_cache.to(device=device, dtype=dtype)
+        pos_b_s = positions.view(1, num_tokens)
+        rotate_style = 0 if is_neox else 1
+
+        q_rot = q_sbhd[..., -rotary_dim:]
+        k_rot = k_sbhd[..., -rotary_dim:]
+        ops.rope_cached_positions_2c_fwd_inplace(
+            q_rot,
+            k_rot,
+            cos,
+            sin,
+            pos_b_s,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=True,
+        )
+
+        torch.testing.assert_close(q_ref, q_sbhd.view_as(q2d), atol=1e-2, rtol=1e-2)
+        torch.testing.assert_close(k_ref, k_sbhd.view_as(k2d), atol=1e-2, rtol=1e-2)
+
+    def test_sglang_rotary_embedding_forward_hip_matches_native(self) -> None:
+        from sglang.srt.layers.rotary_embedding import (
+            RotaryEmbedding as SglRotaryEmbedding,
+        )
+
+        (
+            head_size,
+            rotary_dim,
+            max_pos,
+            base,
+            is_neox,
+            dtype,
+            device,
+            bs,
+            seq_len,
+            num_q,
+            num_kv,
+        ) = (
+            128,
+            64,
+            2048,
+            10000,
+            True,
+            torch.bfloat16,
+            "cuda",
+            2,
+            64,
+            4,
+            2,
+        )
+
+        rope = SglRotaryEmbedding(
+            head_size, rotary_dim, max_pos, base, is_neox, dtype
+        ).to(device)
+
+        positions = torch.arange(seq_len, device=device).repeat(bs)
+        q = torch.randn(bs * seq_len, num_q * head_size, dtype=dtype, device=device)
+        k = torch.randn(bs * seq_len, num_kv * head_size, dtype=dtype, device=device)
+
+        q_ref, k_ref = rope.forward_native(positions.clone(), q.clone(), k.clone())
+        q_hip, k_hip = rope.forward_hip(positions.clone(), q.clone(), k.clone())
+
+        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
+        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)
+
+    def test_llama3_rotary_embedding_forward_hip_matches_native(self) -> None:
+        from sglang.srt.layers.rotary_embedding import get_rope as sgl_get_rope
+
+        (
+            head_size,
+            rotary_dim,
+            max_pos,
+            base,
+            is_neox,
+            dtype,
+            device,
+            bs,
+            seq_len,
+            num_q,
+            num_kv,
+        ) = (
+            128,
+            128,
+            2048,
+            10000,
+            True,
+            torch.bfloat16,
+            "cuda",
+            2,
+            64,
+            4,
+            2,
+        )
+
+        rope = sgl_get_rope(
+            head_size,
+            rotary_dim,
+            max_pos,
+            base,
+            is_neox,
+            rope_scaling={
+                "rope_type": "llama3",
+                "factor": 1.0,
+                "low_freq_factor": 1.0,
+                "high_freq_factor": 1.0,
+                "original_max_position_embeddings": max_pos,
+            },
+            dtype=dtype,
+        ).to(device)
+
+        positions = torch.arange(seq_len, device=device).repeat(bs)
+        q = torch.randn(bs * seq_len, num_q * head_size, dtype=dtype, device=device)
+        k = torch.randn(bs * seq_len, num_kv * head_size, dtype=dtype, device=device)
+
+        q_ref, k_ref = rope.forward_native(positions.clone(), q.clone(), k.clone())
+        q_hip, k_hip = rope.forward_hip(positions.clone(), q.clone(), k.clone())
+
+        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
+        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)
+
 
 if __name__ == "__main__":
     unittest.main()