# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Rotary Positional Embeddings Base Class."""

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp

from .common import ApplyRotaryEmb


# --8<-- [start:rotary_embedding]
@CustomOp.register("rotary_embedding")
class RotaryEmbeddingBase(CustomOp):
    """Original rotary positional embedding."""

    # --8<-- [end:rotary_embedding]

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        init_cache: bool = True,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        # TODO(mgoin): disabled for now due to failures
        # Flashinfer only supports head_size=64, 128, 256, 512.
        # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
        # self.use_flashinfer = (self.enabled()
        #                        and dtype in (torch.float16, torch.bfloat16)
        #                        and current_platform.is_cuda()
        #                        and has_flashinfer()
        #                        and self.head_size in [64, 128, 256, 512])

        # Check if use_flashinfer is already set
        if not hasattr(self, "use_flashinfer"):
            self.use_flashinfer = False

        self.use_aiter = (
            self.enabled() and rocm_aiter_ops.is_triton_rotary_embed_enabled()
        )
        if self.use_aiter:
            self.rocm_aiter_triton_rotary_embedding = (
                rocm_aiter_ops.get_triton_rotary_embedding_op()
            )

        if init_cache:
            cache = self._compute_cos_sin_cache()
            if not self.use_flashinfer:
                cache = cache.to(dtype)
            self.cos_sin_cache: torch.Tensor
            self.register_buffer("cos_sin_cache", cache, persistent=False)

        self.apply_rotary_emb = ApplyRotaryEmb(
            is_neox_style=self.is_neox_style,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
                / self.rotary_dim
            )
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        # Cache layout: [max_position_embeddings, rotary_dim], with cos in the
        # first half and sin in the second half of the last dimension.
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> torch.Tensor:
        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
        # is expensive, so avoid calling it if possible
        cos_sin_cache = self.cos_sin_cache
        if (
            cos_sin_cache.device == query.device
            and cos_sin_cache.dtype == query.dtype
        ):
            return cos_sin_cache
        cos_sin_cache = cos_sin_cache.to(query.device, dtype=query.dtype)
        # Avoid mutating buffers during torch.compile (cudagraph) tracing.
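        # While tracing, use the converted cache for this call only; otherwise
        # store it back on the module so later calls take the fast path above.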
        if torch.compiler.is_compiling():
            return cos_sin_cache
        self.cos_sin_cache = cos_sin_cache
        return cos_sin_cache

    def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
        cos_sin = self.cos_sin_cache[:seqlen]
        cos, sin = cos_sin.chunk(2, dim=-1)
        return cos, sin


class RotaryEmbedding(RotaryEmbeddingBase):
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        init_cache: bool = True,
    ) -> None:
        super().__init__(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
            init_cache=init_cache,
        )

    @staticmethod
    def forward_static(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        head_size: int,
        rotary_dim: int,
        cos_sin_cache: torch.Tensor,
        is_neox_style: bool,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """A PyTorch-native implementation of forward()."""
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, head_size)
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = ApplyRotaryEmb.forward_static(
            query_rot,
            cos,
            sin,
            is_neox_style,
        )
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # key may be None in some cases, e.g. cross-layer KV sharing
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, head_size)
            key_rot = key[..., :rotary_dim]
            key_pass = key[..., rotary_dim:]
            key_rot = ApplyRotaryEmb.forward_static(
                key_rot,
                cos,
                sin,
                is_neox_style,
            )
            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """A PyTorch-native implementation of forward()."""
        cos_sin_cache = self._match_cos_sin_cache_dtype(query)
        return self.forward_static(
            positions,
            query,
            key,
            self.head_size,
            self.rotary_dim,
            cos_sin_cache,
            self.is_neox_style,
        )

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        if self.use_flashinfer:
            torch.ops.vllm.flashinfer_rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                self.is_neox_style,
            )
            return query, key

        from vllm import _custom_ops as ops

        cos_sin_cache = self._match_cos_sin_cache_dtype(query)
        # ops.rotary_embedding() is an in-place operation
        # that updates the query and key tensors.
        ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            cos_sin_cache,
            self.is_neox_style,
            self.rotary_dim,
        )
        return query, key

    def forward_hip(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        if self.use_aiter:
            cos_sin_cache = self._match_cos_sin_cache_dtype(query)
            self.rocm_aiter_triton_rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                cos_sin_cache,
                self.is_neox_style,
            )
            return query, key
        return self.forward_cuda(positions, query, key)

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        self._match_cos_sin_cache_dtype(query)
        # ops.rotary_embedding() is an in-place operation
        # that updates the query and key tensors.
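        # Fall back to the native implementation when key is None
        # (e.g. cross-layer KV sharing).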
        if key is None:
            return self.forward_native(positions, query, key)
        else:
            from vllm import _custom_ops as ops

            cos_sin_cache = self._match_cos_sin_cache_dtype(query)
            ops.rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                cos_sin_cache,
                self.is_neox_style,
            )
        return query, key

    def forward_cpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        from vllm import _custom_ops as ops

        cos_sin_cache = self._match_cos_sin_cache_dtype(query)
        # ops.rotary_embedding() is an in-place operation
        # that updates the query and key tensors.
        ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            cos_sin_cache,
            self.is_neox_style,
        )
        return query, key

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s
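

# Illustrative usage sketch (not part of the vLLM API surface). The head
# count, token count, and dtype below are arbitrary assumptions chosen only
# to show the expected tensor shapes; forward_native() is called directly so
# the sketch does not depend on compiled custom ops.
if __name__ == "__main__":
    num_tokens, num_heads, head_size = 16, 8, 128
    rope = RotaryEmbedding(
        head_size=head_size,
        rotary_dim=head_size,
        max_position_embeddings=4096,
        base=10000.0,
        is_neox_style=True,
        dtype=torch.float32,
    )
    positions = torch.arange(num_tokens)
    # Query/key are flattened over heads: [num_tokens, num_heads * head_size].
    query = torch.randn(num_tokens, num_heads * head_size)
    key = torch.randn(num_tokens, num_heads * head_size)
    query, key = rope.forward_native(positions, query, key)
    print(query.shape, key.shape)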