[ROCm] Remove vLLM rope dependency & use AITER impl (#11322)

Author: b8zhong
Date: 2025-10-22 19:17:34 -07:00
Committed by: GitHub
Parent: 99c92ff24b
Commit: 4d4feccbb2

2 changed files with 353 additions and 0 deletions


@@ -124,6 +124,23 @@ class RotaryEmbedding(CustomOp):
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)
        self._hip_cached_cos: Optional[torch.Tensor] = None
        self._hip_cached_sin: Optional[torch.Tensor] = None
        if _use_aiter:
            # Split the fused cos/sin cache into the separate
            # (max_positions, 1, 1, rotary_dim // 2) cos and sin buffers
            # consumed by the AITER cached-rope kernels.
            half_rotary = cache.shape[-1] // 2
            cos_cache = (
                cache[:, :half_rotary]
                .contiguous()
                .view(self.max_position_embeddings, 1, 1, half_rotary)
            )
            sin_cache = (
                cache[:, half_rotary:]
                .contiguous()
                .view(self.max_position_embeddings, 1, 1, half_rotary)
            )
            self.register_buffer("_hip_cos_cache", cos_cache, persistent=False)
            self.register_buffer("_hip_sin_cache", sin_cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
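
For reference, here is a minimal standalone sketch of the split performed above (sizes and variable names are illustrative assumptions, not taken from this commit): the fused cos_sin_cache holds cos in the first half of its last dimension and sin in the second half, and the AITER cached-rope kernels expect them as separate (max_positions, 1, 1, rotary_dim // 2) tensors.

import torch

max_positions, rotary_dim = 8, 16  # illustrative sizes
cache = torch.randn(max_positions, rotary_dim)  # fused [cos | sin] cache

half = rotary_dim // 2
cos = cache[:, :half].contiguous().view(max_positions, 1, 1, half)
sin = cache[:, half:].contiguous().view(max_positions, 1, 1, half)
assert cos.shape == sin.shape == (max_positions, 1, 1, half)
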
@@ -184,6 +201,109 @@ class RotaryEmbedding(CustomOp):
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_hip(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        offsets: Optional[torch.Tensor] = None,
        fused_set_kv_buffer_arg: Optional["FusedSetKVBufferArg"] = None,
        *,
        is_nope_first: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if not _use_aiter:
            return self.forward_native(
                positions, query, key, offsets, fused_set_kv_buffer_arg
            )
        if fused_set_kv_buffer_arg is not None:
            raise NotImplementedError(
                "fused_set_kv_buffer_arg is not supported for HIP path"
            )

        import aiter as ops

        if not hasattr(self, "_hip_cos_cache") or not hasattr(self, "_hip_sin_cache"):
            raise RuntimeError("HIP caches not initialised")

        # Lazily move the cos/sin caches to the query's device and dtype once,
        # then reuse the cached copies on subsequent calls.
        cos = self._hip_cached_cos
        sin = self._hip_cached_sin
        if cos is None or cos.device != query.device or cos.dtype != query.dtype:
            cos = self._hip_cos_cache.to(query.device, dtype=query.dtype)
            sin = self._hip_sin_cache.to(query.device, dtype=query.dtype)
            self._hip_cached_cos = cos
            self._hip_cached_sin = sin

        # rotate_style 0 selects NeoX-style rotation, 1 GPT-J-style.
        rotate_style = 0 if self.is_neox_style else 1

        num_tokens = positions.numel()
        query_shape = query.shape
        query = query.view(1, num_tokens, -1, self.head_size)
        key_shape = key.shape if key is not None else None
        if key is not None:
            key = key.view(1, num_tokens, -1, self.head_size)
        positions = positions.view(*query.shape[:2])
        if offsets is not None:
            offsets = offsets.view(*query.shape[:2])

        # Only the rotary slice of each head is rotated; any NoPE portion
        # passes through unchanged.
        if not is_nope_first:
            query_rot = query[..., : self.rotary_dim]
            key_rot = key[..., : self.rotary_dim] if key is not None else None
        else:
            query_rot = query[..., -self.rotary_dim :]
            key_rot = key[..., -self.rotary_dim :] if key is not None else None

        if key_rot is None:
            # Query-only path.
            if offsets is None:
                ops.rope_cached_positions_fwd_inplace(
                    query_rot,
                    cos,
                    sin,
                    positions,
                    rotate_style,
                    reuse_freqs_front_part=True,
                    nope_first=is_nope_first,
                )
            else:
                ops.rope_cached_positions_offsets_fwd_inplace(
                    query_rot,
                    cos,
                    sin,
                    positions,
                    offsets,
                    rotate_style,
                    reuse_freqs_front_part=True,
                    nope_first=is_nope_first,
                )
            return query.view(query_shape), None

        # Fused query+key ("2c") path: both tensors are rotated in one call.
        if offsets is None:
            ops.rope_cached_positions_2c_fwd_inplace(
                query_rot,
                key_rot,
                cos,
                sin,
                positions,
                rotate_style,
                reuse_freqs_front_part=True,
                nope_first=is_nope_first,
            )
        else:
            ops.rope_cached_positions_offsets_2c_fwd_inplace(
                query_rot,
                key_rot,
                cos,
                sin,
                positions,
                offsets,
                rotate_style,
                reuse_freqs_front_part=True,
                nope_first=is_nope_first,
            )
        return query.view(query_shape), key.view(key_shape) if key is not None else None
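
    # Illustrative usage sketch (an assumption, not part of this commit): on a
    # ROCm build with _use_aiter enabled, forward_hip rotates query/key in
    # place through the AITER cached-rope kernels. The constructor arguments
    # and shapes below are hypothetical.
    #
    #   rope = RotaryEmbedding(
    #       head_size=128, rotary_dim=128, max_position_embeddings=4096,
    #       base=10000, is_neox_style=True, dtype=torch.bfloat16,
    #   )
    #   positions = torch.arange(num_tokens, device="cuda")
    #   q = torch.randn(num_tokens, num_heads * 128, dtype=torch.bfloat16, device="cuda")
    #   k = torch.randn(num_tokens, num_kv_heads * 128, dtype=torch.bfloat16, device="cuda")
    #   q, k = rope.forward_hip(positions, q, k)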

    def forward_npu(
        self,
        positions: torch.Tensor,