[ROCm] Remove vLLM rope dependency & use AITER impl (#11322)
This commit is contained in:
@@ -124,6 +124,23 @@ class RotaryEmbedding(CustomOp):
|
|||||||
self.cos_sin_cache: torch.Tensor
|
self.cos_sin_cache: torch.Tensor
|
||||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||||
|
|
||||||
|
self._hip_cached_cos: Optional[torch.Tensor] = None
|
||||||
|
self._hip_cached_sin: Optional[torch.Tensor] = None
|
||||||
|
if _use_aiter:
|
||||||
|
half_rotary = cache.shape[-1] // 2
|
||||||
|
cos_cache = (
|
||||||
|
cache[:, :half_rotary]
|
||||||
|
.contiguous()
|
||||||
|
.view(self.max_position_embeddings, 1, 1, half_rotary)
|
||||||
|
)
|
||||||
|
sin_cache = (
|
||||||
|
cache[:, half_rotary:]
|
||||||
|
.contiguous()
|
||||||
|
.view(self.max_position_embeddings, 1, 1, half_rotary)
|
||||||
|
)
|
||||||
|
self.register_buffer("_hip_cos_cache", cos_cache, persistent=False)
|
||||||
|
self.register_buffer("_hip_sin_cache", sin_cache, persistent=False)
|
||||||
|
|
||||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||||
"""Compute the inverse frequency."""
|
"""Compute the inverse frequency."""
|
||||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||||
@@ -184,6 +201,109 @@ class RotaryEmbedding(CustomOp):
|
|||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
def forward_hip(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    offsets: Optional[torch.Tensor] = None,
    fused_set_kv_buffer_arg: Optional["FusedSetKVBufferArg"] = None,
    *,
    is_nope_first: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply rotary embedding on the HIP path through the AITER rope ops.

    When ``_use_aiter`` is false the call is delegated to
    ``forward_native`` (note that ``is_nope_first`` is not forwarded in
    that case).  Otherwise ``query`` (and ``key`` when given) are viewed
    as ``(1, num_tokens, -1, head_size)``, the rotary slice of the last
    dimension is handed to one of the ``aiter.rope_cached_positions_*``
    ops, and the tensors are viewed back to their incoming shapes.

    Args:
        positions: Token positions; ``positions.numel()`` defines
            ``num_tokens``.
        query: Query tensor, viewable as
            ``(1, num_tokens, -1, head_size)``.
        key: Optional key tensor with the same layout constraints.
        offsets: Optional per-token offsets; selects the ``*_offsets_*``
            op variants when provided.
        fused_set_kv_buffer_arg: Must be ``None`` on the AITER path.
        is_nope_first: When true the *last* ``rotary_dim`` elements of
            each head are rotated instead of the first ones; the flag is
            also forwarded to the op as ``nope_first``.

    Returns:
        ``(query, key)`` viewed back to their original shapes; ``key`` is
        ``None`` when no key was supplied.

    Raises:
        NotImplementedError: If ``fused_set_kv_buffer_arg`` is given on
            the AITER path.
        RuntimeError: If the ``_hip_cos_cache`` / ``_hip_sin_cache``
            buffers were never registered.
    """
    # Without AITER there is nothing HIP-specific to do.
    if not _use_aiter:
        return self.forward_native(
            positions, query, key, offsets, fused_set_kv_buffer_arg
        )
    if fused_set_kv_buffer_arg is not None:
        raise NotImplementedError(
            "fused_set_kv_buffer_arg is not supported for HIP path"
        )

    import aiter as ops

    if not (hasattr(self, "_hip_cos_cache") and hasattr(self, "_hip_sin_cache")):
        raise RuntimeError("HIP caches not initialised")

    # Reuse the device/dtype-converted tables from the previous call when
    # they still match the incoming query; otherwise convert and memoise.
    cos_table = self._hip_cached_cos
    sin_table = self._hip_cached_sin
    if (
        cos_table is None
        or cos_table.device != query.device
        or cos_table.dtype != query.dtype
    ):
        cos_table = self._hip_cos_cache.to(query.device, dtype=query.dtype)
        sin_table = self._hip_sin_cache.to(query.device, dtype=query.dtype)
        self._hip_cached_cos = cos_table
        self._hip_cached_sin = sin_table

    if self.is_neox_style:
        rotate_style = 0
    else:
        rotate_style = 1

    token_count = positions.numel()

    incoming_query_shape = query.shape
    query = query.view(1, token_count, -1, self.head_size)

    incoming_key_shape = None
    if key is not None:
        incoming_key_shape = key.shape
        key = key.view(1, token_count, -1, self.head_size)

    lead_dims = query.shape[:2]
    positions = positions.view(*lead_dims)
    if offsets is not None:
        offsets = offsets.view(*lead_dims)

    # Rotary part sits at the tail of each head when the "nope" part
    # comes first, and at the front otherwise.
    if is_nope_first:
        rotary_span = slice(-self.rotary_dim, None)
    else:
        rotary_span = slice(None, self.rotary_dim)
    query_rot = query[..., rotary_span]

    op_flags = {"reuse_freqs_front_part": True, "nope_first": is_nope_first}

    # Query-only variants.
    if key is None:
        if offsets is None:
            ops.rope_cached_positions_fwd_inplace(
                query_rot, cos_table, sin_table, positions, rotate_style, **op_flags
            )
        else:
            ops.rope_cached_positions_offsets_fwd_inplace(
                query_rot,
                cos_table,
                sin_table,
                positions,
                offsets,
                rotate_style,
                **op_flags,
            )
        return query.view(incoming_query_shape), None

    # Query + key ("2c") variants.
    key_rot = key[..., rotary_span]
    if offsets is None:
        ops.rope_cached_positions_2c_fwd_inplace(
            query_rot,
            key_rot,
            cos_table,
            sin_table,
            positions,
            rotate_style,
            **op_flags,
        )
    else:
        ops.rope_cached_positions_offsets_2c_fwd_inplace(
            query_rot,
            key_rot,
            cos_table,
            sin_table,
            positions,
            offsets,
            rotate_style,
            **op_flags,
        )

    return query.view(incoming_query_shape), key.view(incoming_key_shape)
|
||||||
|
|
||||||
def forward_npu(
|
def forward_npu(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
|
|||||||
@@ -111,6 +111,239 @@ class TestRotaryEmbeddingAITer(CustomTestCase):
|
|||||||
with self.subTest(case=case):
|
with self.subTest(case=case):
|
||||||
self._run_case_aiter(*case)
|
self._run_case_aiter(*case)
|
||||||
|
|
||||||
|
def test_ops_equivalence_basic(self) -> None:
    # Compares aiter's RotaryEmbedding.forward_hip against a direct call to
    # aiter.rope_cached_positions_2c_fwd_inplace on 4-D views of the same
    # random inputs.
    import aiter as ops
    from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding

    # Fixed configuration for this case.
    head_size, rotary_dim = 128, 64
    max_pos, base = 2048, 10000
    is_neox = True
    dtype = torch.bfloat16
    device = "cuda"
    bs, seq_len = 2, 32
    num_q, num_kv = 4, 2

    rope = AiterRotaryEmbedding(
        head_size, rotary_dim, max_pos, base, is_neox, dtype
    )
    rope = rope.to(device)

    positions = torch.arange(seq_len, device=device).repeat(bs)
    num_tokens = positions.numel()

    q2d = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
    k2d = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)

    # Module path works on clones, so q2d / k2d stay untouched here.
    q_ref, k_ref = rope.forward_hip(positions.clone(), q2d.clone(), k2d.clone())

    # Direct op path operates on views of q2d / k2d themselves.
    q_sbhd = q2d.view(1, num_tokens, num_q, head_size)
    k_sbhd = k2d.view(1, num_tokens, num_kv, head_size)

    cos = rope.cos_cache.to(device=device, dtype=dtype)
    sin = rope.sin_cache.to(device=device, dtype=dtype)
    pos_b_s = positions.view(1, num_tokens)
    if is_neox:
        rotate_style = 0
    else:
        rotate_style = 1

    ops.rope_cached_positions_2c_fwd_inplace(
        q_sbhd,
        k_sbhd,
        cos,
        sin,
        pos_b_s,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=False,
    )

    self.assertTrue(q_ref.shape == q2d.shape)
    self.assertTrue(k_ref.shape == k2d.shape)

    tolerances = {"atol": 1e-2, "rtol": 1e-2}
    torch.testing.assert_close(q_ref, q_sbhd.view_as(q2d), **tolerances)
    torch.testing.assert_close(k_ref, k_sbhd.view_as(k2d), **tolerances)
|
||||||
|
|
||||||
|
def test_ops_equivalence_nope_first(self) -> None:
    # Same comparison as the basic case, but with is_nope_first=True on the
    # module path and the trailing rotary_dim slice passed to the op with
    # nope_first=True on the direct path.
    import aiter as ops
    from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding

    # Fixed configuration for this case.
    head_size, rotary_dim = 128, 64
    max_pos, base = 2048, 10000
    is_neox = True
    dtype = torch.bfloat16
    device = "cuda"
    bs, seq_len = 1, 16
    num_q, num_kv = 2, 2

    rope = AiterRotaryEmbedding(
        head_size, rotary_dim, max_pos, base, is_neox, dtype
    )
    rope = rope.to(device)

    positions = torch.arange(seq_len, device=device).repeat(bs)
    num_tokens = positions.numel()

    q2d = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
    k2d = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)

    # Module path works on clones, so q2d / k2d stay untouched here.
    q_ref, k_ref = rope.forward_hip(
        positions.clone(),
        q2d.clone(),
        k2d.clone(),
        is_nope_first=True,
    )

    # Direct op path operates on views of q2d / k2d themselves.
    q_sbhd = q2d.view(1, num_tokens, num_q, head_size)
    k_sbhd = k2d.view(1, num_tokens, num_kv, head_size)

    cos = rope.cos_cache.to(device=device, dtype=dtype)
    sin = rope.sin_cache.to(device=device, dtype=dtype)
    pos_b_s = positions.view(1, num_tokens)
    if is_neox:
        rotate_style = 0
    else:
        rotate_style = 1

    # Only the last rotary_dim elements of each head are rotated.
    tail = slice(-rotary_dim, None)
    q_rot = q_sbhd[..., tail]
    k_rot = k_sbhd[..., tail]
    ops.rope_cached_positions_2c_fwd_inplace(
        q_rot,
        k_rot,
        cos,
        sin,
        pos_b_s,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=True,
    )

    tolerances = {"atol": 1e-2, "rtol": 1e-2}
    torch.testing.assert_close(q_ref, q_sbhd.view_as(q2d), **tolerances)
    torch.testing.assert_close(k_ref, k_sbhd.view_as(k2d), **tolerances)
|
||||||
|
|
||||||
|
def test_sglang_rotary_embedding_forward_hip_matches_native(self) -> None:
    # sglang's RotaryEmbedding must give the same q/k from forward_hip as
    # from forward_native (within bf16-friendly tolerances).
    from sglang.srt.layers.rotary_embedding import (
        RotaryEmbedding as SglRotaryEmbedding,
    )

    # Fixed configuration for this case.
    head_size, rotary_dim = 128, 64
    max_pos, base = 2048, 10000
    is_neox = True
    dtype = torch.bfloat16
    device = "cuda"
    bs, seq_len = 2, 64
    num_q, num_kv = 4, 2

    rope = SglRotaryEmbedding(
        head_size, rotary_dim, max_pos, base, is_neox, dtype
    )
    rope = rope.to(device)

    total_tokens = bs * seq_len
    positions = torch.arange(seq_len, device=device).repeat(bs)
    q = torch.randn(total_tokens, num_q * head_size, dtype=dtype, device=device)
    k = torch.randn(total_tokens, num_kv * head_size, dtype=dtype, device=device)

    # Each path receives its own clones of the inputs.
    q_ref, k_ref = rope.forward_native(positions.clone(), q.clone(), k.clone())
    q_hip, k_hip = rope.forward_hip(positions.clone(), q.clone(), k.clone())

    tolerances = {"atol": 1e-2, "rtol": 1e-2}
    torch.testing.assert_close(q_ref, q_hip, **tolerances)
    torch.testing.assert_close(k_ref, k_hip, **tolerances)
|
||||||
|
|
||||||
|
def test_llama3_rotary_embedding_forward_hip_matches_native(self) -> None:
    # Builds a rope via get_rope with a "llama3" rope_scaling config and
    # checks that forward_hip agrees with forward_native.
    from sglang.srt.layers.rotary_embedding import get_rope as sgl_get_rope

    # Fixed configuration for this case (rotary_dim equals head_size).
    head_size, rotary_dim = 128, 128
    max_pos, base = 2048, 10000
    is_neox = True
    dtype = torch.bfloat16
    device = "cuda"
    bs, seq_len = 2, 64
    num_q, num_kv = 4, 2

    scaling_config = {
        "rope_type": "llama3",
        "factor": 1.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 1.0,
        "original_max_position_embeddings": max_pos,
    }
    rope = sgl_get_rope(
        head_size,
        rotary_dim,
        max_pos,
        base,
        is_neox,
        rope_scaling=scaling_config,
        dtype=dtype,
    )
    rope = rope.to(device)

    total_tokens = bs * seq_len
    positions = torch.arange(seq_len, device=device).repeat(bs)
    q = torch.randn(total_tokens, num_q * head_size, dtype=dtype, device=device)
    k = torch.randn(total_tokens, num_kv * head_size, dtype=dtype, device=device)

    # Each path receives its own clones of the inputs.
    q_ref, k_ref = rope.forward_native(positions.clone(), q.clone(), k.clone())
    q_hip, k_hip = rope.forward_hip(positions.clone(), q.clone(), k.clone())

    tolerances = {"atol": 1e-2, "rtol": 1e-2}
    torch.testing.assert_close(q_ref, q_hip, **tolerances)
    torch.testing.assert_close(k_ref, k_hip, **tolerances)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user