Revert "[ROCm] Remove vLLM rope dependency & use AITER impl" (#12028)

Author: b8zhong (committed by GitHub)
Date: 2025-10-23 12:42:59 -07:00
Parent: 1c304aa9bc
Commit: 8ae9d4bb41
2 changed files with 0 additions and 353 deletions


@@ -111,239 +111,6 @@ class TestRotaryEmbeddingAITer(CustomTestCase):
            with self.subTest(case=case):
                self._run_case_aiter(*case)

    def test_ops_equivalence_basic(self) -> None:
        import aiter as ops
        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding

        (
            head_size,
            rotary_dim,
            max_pos,
            base,
            is_neox,
            dtype,
            device,
            bs,
            seq_len,
            num_q,
            num_kv,
        ) = (
            128,
            64,
            2048,
            10000,
            True,
            torch.bfloat16,
            "cuda",
            2,
            32,
            4,
            2,
        )
        rope = AiterRotaryEmbedding(
            head_size, rotary_dim, max_pos, base, is_neox, dtype
        ).to(device)
        positions = torch.arange(seq_len, device=device).repeat(bs)
        num_tokens = positions.numel()
        q2d = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
        k2d = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)
        q_ref, k_ref = rope.forward_hip(positions.clone(), q2d.clone(), k2d.clone())
        q_sbhd = q2d.view(1, num_tokens, num_q, head_size)
        k_sbhd = k2d.view(1, num_tokens, num_kv, head_size)
        cos = rope.cos_cache.to(device=device, dtype=dtype)
        sin = rope.sin_cache.to(device=device, dtype=dtype)
        pos_b_s = positions.view(1, num_tokens)
        rotate_style = 0 if is_neox else 1
        ops.rope_cached_positions_2c_fwd_inplace(
            q_sbhd,
            k_sbhd,
            cos,
            sin,
            pos_b_s,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=False,
        )
        self.assertTrue(q_ref.shape == q2d.shape)
        self.assertTrue(k_ref.shape == k2d.shape)
        torch.testing.assert_close(q_ref, q_sbhd.view_as(q2d), atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_sbhd.view_as(k2d), atol=1e-2, rtol=1e-2)

    def test_ops_equivalence_nope_first(self) -> None:
        import aiter as ops
        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding

        (
            head_size,
            rotary_dim,
            max_pos,
            base,
            is_neox,
            dtype,
            device,
            bs,
            seq_len,
            num_q,
            num_kv,
        ) = (
            128,
            64,
            2048,
            10000,
            True,
            torch.bfloat16,
            "cuda",
            1,
            16,
            2,
            2,
        )
        rope = AiterRotaryEmbedding(
            head_size, rotary_dim, max_pos, base, is_neox, dtype
        ).to(device)
        positions = torch.arange(seq_len, device=device).repeat(bs)
        num_tokens = positions.numel()
        q2d = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
        k2d = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)
        q_ref, k_ref = rope.forward_hip(
            positions.clone(), q2d.clone(), k2d.clone(), is_nope_first=True
        )
        q_sbhd = q2d.view(1, num_tokens, num_q, head_size)
        k_sbhd = k2d.view(1, num_tokens, num_kv, head_size)
        cos = rope.cos_cache.to(device=device, dtype=dtype)
        sin = rope.sin_cache.to(device=device, dtype=dtype)
        pos_b_s = positions.view(1, num_tokens)
        rotate_style = 0 if is_neox else 1
        q_rot = q_sbhd[..., -rotary_dim:]
        k_rot = k_sbhd[..., -rotary_dim:]
        ops.rope_cached_positions_2c_fwd_inplace(
            q_rot,
            k_rot,
            cos,
            sin,
            pos_b_s,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=True,
        )
        torch.testing.assert_close(q_ref, q_sbhd.view_as(q2d), atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_sbhd.view_as(k2d), atol=1e-2, rtol=1e-2)

    def test_sglang_rotary_embedding_forward_hip_matches_native(self) -> None:
        from sglang.srt.layers.rotary_embedding import (
            RotaryEmbedding as SglRotaryEmbedding,
        )

        (
            head_size,
            rotary_dim,
            max_pos,
            base,
            is_neox,
            dtype,
            device,
            bs,
            seq_len,
            num_q,
            num_kv,
        ) = (
            128,
            64,
            2048,
            10000,
            True,
            torch.bfloat16,
            "cuda",
            2,
            64,
            4,
            2,
        )
        rope = SglRotaryEmbedding(
            head_size, rotary_dim, max_pos, base, is_neox, dtype
        ).to(device)
        positions = torch.arange(seq_len, device=device).repeat(bs)
        q = torch.randn(bs * seq_len, num_q * head_size, dtype=dtype, device=device)
        k = torch.randn(bs * seq_len, num_kv * head_size, dtype=dtype, device=device)
        q_ref, k_ref = rope.forward_native(positions.clone(), q.clone(), k.clone())
        q_hip, k_hip = rope.forward_hip(positions.clone(), q.clone(), k.clone())
        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)

    def test_llama3_rotary_embedding_forward_hip_matches_native(self) -> None:
        from sglang.srt.layers.rotary_embedding import get_rope as sgl_get_rope

        (
            head_size,
            rotary_dim,
            max_pos,
            base,
            is_neox,
            dtype,
            device,
            bs,
            seq_len,
            num_q,
            num_kv,
        ) = (
            128,
            128,
            2048,
            10000,
            True,
            torch.bfloat16,
            "cuda",
            2,
            64,
            4,
            2,
        )
        rope = sgl_get_rope(
            head_size,
            rotary_dim,
            max_pos,
            base,
            is_neox,
            rope_scaling={
                "rope_type": "llama3",
                "factor": 1.0,
                "low_freq_factor": 1.0,
                "high_freq_factor": 1.0,
                "original_max_position_embeddings": max_pos,
            },
            dtype=dtype,
        ).to(device)
        positions = torch.arange(seq_len, device=device).repeat(bs)
        q = torch.randn(bs * seq_len, num_q * head_size, dtype=dtype, device=device)
        k = torch.randn(bs * seq_len, num_kv * head_size, dtype=dtype, device=device)
        q_ref, k_ref = rope.forward_native(positions.clone(), q.clone(), k.clone())
        q_hip, k_hip = rope.forward_hip(positions.clone(), q.clone(), k.clone())
        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)

if __name__ == "__main__":
    unittest.main()