[ROCm] Remove vLLM rope dependency & use AITER impl (#11322)
This commit is contained in:
@@ -124,6 +124,23 @@ class RotaryEmbedding(CustomOp):
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
self._hip_cached_cos: Optional[torch.Tensor] = None
|
||||
self._hip_cached_sin: Optional[torch.Tensor] = None
|
||||
if _use_aiter:
|
||||
half_rotary = cache.shape[-1] // 2
|
||||
cos_cache = (
|
||||
cache[:, :half_rotary]
|
||||
.contiguous()
|
||||
.view(self.max_position_embeddings, 1, 1, half_rotary)
|
||||
)
|
||||
sin_cache = (
|
||||
cache[:, half_rotary:]
|
||||
.contiguous()
|
||||
.view(self.max_position_embeddings, 1, 1, half_rotary)
|
||||
)
|
||||
self.register_buffer("_hip_cos_cache", cos_cache, persistent=False)
|
||||
self.register_buffer("_hip_sin_cache", sin_cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||
@@ -184,6 +201,109 @@ class RotaryEmbedding(CustomOp):
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
def forward_hip(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    offsets: Optional[torch.Tensor] = None,
    fused_set_kv_buffer_arg: Optional["FusedSetKVBufferArg"] = None,
    *,
    is_nope_first: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply rotary embedding on the HIP path using the AITER rope ops.

    When the module-level ``_use_aiter`` flag is falsy this simply
    delegates to ``forward_native``. Otherwise the rotary slice of
    ``query`` (and ``key``, if given) is handed to one of AITER's
    ``rope_cached_positions_*_fwd_inplace`` ops.

    Args:
        positions: Token positions; viewed as ``(1, positions.numel())``.
        query: Query tensor; viewed as ``(1, num_tokens, -1, head_size)``
            for the op and viewed back to its original shape on return.
        key: Optional key tensor, handled the same way as ``query``.
        offsets: Optional position offsets; when given, the ``offsets``
            variants of the ops are used.
        fused_set_kv_buffer_arg: Must be ``None`` on the AITER path.
        is_nope_first: If true, the last ``rotary_dim`` features of each
            head are passed to the op instead of the first ``rotary_dim``;
            the flag is also forwarded to the op as ``nope_first``.

    Returns:
        ``(query, key)`` in their original shapes; ``key`` is ``None``
        when no key was supplied.

    Raises:
        NotImplementedError: If ``fused_set_kv_buffer_arg`` is supplied
            while the AITER path is active.
        RuntimeError: If the ``_hip_cos_cache`` / ``_hip_sin_cache``
            attributes are missing.
    """
    if not _use_aiter:
        return self.forward_native(
            positions, query, key, offsets, fused_set_kv_buffer_arg
        )

    if fused_set_kv_buffer_arg is not None:
        raise NotImplementedError(
            "fused_set_kv_buffer_arg is not supported for HIP path"
        )

    # Imported lazily so the module can load where AITER is not installed.
    import aiter as ops

    caches_ready = hasattr(self, "_hip_cos_cache") and hasattr(
        self, "_hip_sin_cache"
    )
    if not caches_ready:
        raise RuntimeError("HIP caches not initialised")

    # Reuse the previously converted cos/sin tables unless the query's
    # device or dtype differs from what was memoised.
    cos, sin = self._hip_cached_cos, self._hip_cached_sin
    if cos is None or cos.device != query.device or cos.dtype != query.dtype:
        cos = self._hip_cos_cache.to(query.device, dtype=query.dtype)
        sin = self._hip_sin_cache.to(query.device, dtype=query.dtype)
        self._hip_cached_cos, self._hip_cached_sin = cos, sin

    # 0 selects the neox style, 1 the other style.
    rotate_style = int(not self.is_neox_style)

    token_count = positions.numel()
    original_query_shape = query.shape
    query = query.view(1, token_count, -1, self.head_size)

    original_key_shape = None
    if key is not None:
        original_key_shape = key.shape
        key = key.view(1, token_count, -1, self.head_size)

    leading_dims = query.shape[:2]
    positions = positions.view(*leading_dims)
    if offsets is not None:
        offsets = offsets.view(*leading_dims)

    # Select which features of each head are rotated: the trailing
    # ``rotary_dim`` ones when nope comes first, otherwise the leading ones.
    if is_nope_first:
        rotary_slice = slice(-self.rotary_dim, None)
    else:
        rotary_slice = slice(None, self.rotary_dim)
    query_rot = query[..., rotary_slice]
    key_rot = None if key is None else key[..., rotary_slice]

    has_key = key_rot is not None
    has_offsets = offsets is not None

    # NOTE(review): the ops are presumed to mutate ``query_rot`` /
    # ``key_rot`` in place (per their ``_inplace`` names), which is why
    # their return values are ignored — confirm against AITER.
    if has_key and has_offsets:
        ops.rope_cached_positions_offsets_2c_fwd_inplace(
            query_rot,
            key_rot,
            cos,
            sin,
            positions,
            offsets,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=is_nope_first,
        )
    elif has_key:
        ops.rope_cached_positions_2c_fwd_inplace(
            query_rot,
            key_rot,
            cos,
            sin,
            positions,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=is_nope_first,
        )
    elif has_offsets:
        ops.rope_cached_positions_offsets_fwd_inplace(
            query_rot,
            cos,
            sin,
            positions,
            offsets,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=is_nope_first,
        )
    else:
        ops.rope_cached_positions_fwd_inplace(
            query_rot,
            cos,
            sin,
            positions,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=is_nope_first,
        )

    restored_query = query.view(original_query_shape)
    if not has_key:
        return restored_query, None
    return restored_query, key.view(original_key_shape)
|
||||
|
||||
def forward_npu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user