# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Rotary Positional Embeddings Base Class."""

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp

from .common import ApplyRotaryEmb


@CustomOp.register("rotary_embedding")
class RotaryEmbeddingBase(CustomOp):
    """Original rotary positional embedding.

    Holds the rotary configuration and a non-persistent ``cos_sin_cache``
    buffer with one row per position: the cosines of the rotation angles in
    the first half of the last dimension and the sines in the second half.
    The cache is produced by ``self._compute_cos_sin_cache()`` (which calls
    ``self._compute_inv_freq``), so a subclass can change the frequencies by
    overriding those hooks.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        """Record the configuration and register the cos/sin cache buffer.

        Args:
            head_size: Per-head feature size.
            rotary_dim: Number of leading features of each head that are
                rotated; the remaining features pass through untouched.
            max_position_embeddings: Number of positions (cache rows).
            base: Base of the inverse-frequency progression.
            is_neox_style: Rotation-layout flag, forwarded unchanged to
                ``ApplyRotaryEmb`` and to the kernels.
            dtype: dtype the cache is cast to (skipped when flashinfer is
                in use).
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        # TODO(mgoin): flashinfer is disabled for now due to failures. The
        # intended gate was: self.enabled(), dtype in (float16, bfloat16),
        # a CUDA platform with flashinfer available, and head_size in
        # (64, 128, 256, 512) -- the only head sizes flashinfer supports:
        # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
        self.use_flashinfer = False

        raw_cache = self._compute_cos_sin_cache()
        # The flashinfer path keeps the cache in whatever dtype
        # _compute_cos_sin_cache produced; every other path stores it in
        # the requested model dtype.
        self.cos_sin_cache: torch.Tensor
        self.register_buffer(
            "cos_sin_cache",
            raw_cache if self.use_flashinfer else raw_cache.to(dtype),
            persistent=False,
        )

        self.is_rocm_triton_rotary_embed_enabled = (
            rocm_aiter_ops.is_triton_rotary_embed_enabled()
        )
        self.apply_rotary_emb = ApplyRotaryEmb(
            is_neox_style=self.is_neox_style,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return ``1 / base ** (k / rotary_dim)`` for ``k = 0, 2, 4, ...``.

        The result is a 1-D float32 tensor with one entry per rotated
        feature pair.
        """
        # NOTE(woosuk): HF computes this table on CPU and then moves it to
        # the GPU. No explicit device is passed below, so the tensors are
        # created on torch's current default device; per the original note
        # the intent is to build the cache on GPU for faster initialization,
        # which may differ slightly numerically from the HF implementation.
        exponents = (
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
        )
        return 1.0 / (base**exponents)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cache of shape ``(max_position_embeddings, 2 * F)``.

        ``F`` is the length of the inverse-frequency vector. Row ``p`` is
        ``cos(p * inv_freq)`` concatenated with ``sin(p * inv_freq)``.
        """
        inv_freq = self._compute_inv_freq(self.base)
        position_ids = torch.arange(self.max_position_embeddings, dtype=torch.float)
        # Outer product: angles[p, j] = position_ids[p] * inv_freq[j].
        angles = torch.einsum("i,j -> ij", position_ids, inv_freq)
        return torch.cat((angles.cos(), angles.sin()), dim=-1)

    def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
        """Move/cast ``cos_sin_cache`` to ``query``'s device and dtype if needed."""
        current = self.cos_sin_cache
        # Assigning self.cos_sin_cache goes through nn.Module.__setattr__,
        # which is expensive, so bail out early when nothing differs.
        if current.device == query.device and current.dtype == query.dtype:
            return
        self.cos_sin_cache = current.to(query.device, dtype=query.dtype)

    def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split the first ``seqlen`` cache rows into their cos and sin halves."""
        cos, sin = self.cos_sin_cache[:seqlen].chunk(2, dim=-1)
        return cos, sin


class RotaryEmbedding(RotaryEmbeddingBase):
    """Rotary embedding with native, CUDA, ROCm and XPU forward paths."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        """Forward every argument unchanged to ``RotaryEmbeddingBase``."""
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    @staticmethod
    def forward_static(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        head_size: int,
        rotary_dim: int,
        cos_sin_cache: torch.Tensor,
        is_neox_style: bool,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """A PyTorch-native implementation of forward().

        ``positions`` is flattened to one index per token. Each input is
        viewed as ``(num_tokens, -1, head_size)``, its first ``rotary_dim``
        features are rotated with the cos/sin rows gathered for those
        positions, and the result is reshaped back to the input's shape.
        ``key`` may be None (e.g. with cross-layer KV sharing), in which case
        None is returned in its place.
        """
        flat_positions = positions.flatten()
        num_tokens = flat_positions.shape[0]
        cos, sin = cos_sin_cache.index_select(0, flat_positions).chunk(2, dim=-1)

        def rotate(tensor: torch.Tensor) -> torch.Tensor:
            # Rotate the leading rotary_dim features of every head and
            # re-attach the untouched tail before restoring the shape.
            original_shape = tensor.shape
            heads = tensor.view(num_tokens, -1, head_size)
            to_rotate = heads[..., :rotary_dim]
            passthrough = heads[..., rotary_dim:]
            rotated = ApplyRotaryEmb.forward_static(
                to_rotate,
                cos,
                sin,
                is_neox_style,
            )
            return torch.cat((rotated, passthrough), dim=-1).reshape(original_shape)

        rotated_query = rotate(query)
        rotated_key = None if key is None else rotate(key)
        return rotated_query, rotated_key

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run ``forward_static`` with this module's configuration and cache."""
        return self.forward_static(
            positions,
            query,
            key,
            self.head_size,
            self.rotary_dim,
            self.cos_sin_cache,
            self.is_neox_style,
        )

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Rotate ``query``/``key`` in place with a fused kernel and return them."""
        if self.use_flashinfer:
            rotary_kernel = torch.ops.vllm.flashinfer_rotary_embedding
        else:
            from vllm import _custom_ops as ops

            self._match_cos_sin_cache_dtype(query)
            rotary_kernel = ops.rotary_embedding

        # Both kernels are in-place: they write the rotated values back into
        # the query and key tensors, which are then returned as-is.
        rotary_kernel(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key

    def forward_hip(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Use the ROCm AITER Triton kernel when enabled, else the CUDA path."""
        if not self.is_rocm_triton_rotary_embed_enabled:
            return self.forward_cuda(positions, query, key)

        self._match_cos_sin_cache_dtype(query)
        # The kernel's return value is unused, so it presumably updates
        # query/key in place.
        rocm_aiter_ops.triton_rotary_embed(
            positions,
            query,
            key,
            self.cos_sin_cache,
            self.head_size,
            self.rotary_dim,
            self.is_neox_style,
        )
        return query, key

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Use the IPEX kernel when a key is given, else the native path."""
        from vllm._ipex_ops import ipex_ops as ops

        self._match_cos_sin_cache_dtype(query)

        if key is None:
            # The XPU kernel cannot handle key=None, so fall back to the
            # native implementation.
            # TODO(sarckk): add support for optional key in
            # ipex.llm.functional.rotary_embedding_batched
            return self.forward_native(positions, query, key)

        # In-place: the kernel updates the query and key tensors.
        ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key

    def extra_repr(self) -> str:
        """Describe the configuration for ``nn.Module.__repr__``."""
        return (
            f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
            f", max_position_embeddings={self.max_position_embeddings}"
            f", base={self.base}, is_neox_style={self.is_neox_style}"
        )