# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch

from vllm.model_executor.layers.rotary_embedding.common import yarn_get_mscale
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLUMRotaryEmbedding(MLURotaryEmbedding, MRotaryEmbedding):
    """Rotary Embedding with Multimodal Sections (MLU backend).

    Combines the MLU rotary-embedding kernel path with vLLM's multimodal
    M-RoPE sectioning. For 2-D (multimodal, 3xN) position ids, the per-section
    cos/sin tables must be materialized via ``precompute_sin_cos_cache``
    before ``forward_oot`` is called; 1-D positions fall through to the plain
    MLU rotary-embedding path.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: list[int] | None = None,
        mrope_interleaved: bool = False,
        # YaRN parameters.
        *,
        scaling_factor: float | None = None,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        if self.scaling_factor is not None:
            # Get n-d magnitude scaling corrected for interpolation
            self.mscale = float(
                yarn_get_mscale(self.scaling_factor) * attn_factor)
        else:
            self.mscale = 1.0
        # In Qwen2.5-VL, the maximum index value is related to the duration of
        # the input video. We enlarge max_position_embeddings to 4 times to get
        # a larger cos and sin cache.
        self.cache_max_position_num = max_position_embeddings * 4
        MLURotaryEmbedding.__init__(
            self,
            head_size,
            rotary_dim,
            self.cache_max_position_num,
            base,
            is_neox_style,
            dtype,
        )
        self.mrope_section = mrope_section
        self.mrope_interleaved = mrope_interleaved
        if self.mrope_section:
            assert sum(self.mrope_section) == rotary_dim // 2
        # Fix: initialize the M-RoPE caches so that forward_oot's guard raises
        # the intended AssertionError ("call precompute_sin_cos_cache first!")
        # rather than an AttributeError when precompute_sin_cos_cache has not
        # run (or returned early on 1-D positions).
        self.mrope_cos_cache: torch.Tensor | None = None
        self.mrope_sin_cache: torch.Tensor | None = None
        self.mrope_cu_seq_lens: torch.Tensor | None = None

    def _apply_mrope(self, positions):
        # positions: 2-D multimodal position ids (one row per T/H/W section).
        # Gather cached cos/sin rows, then stitch per-section slices so each
        # frequency chunk comes from its corresponding positional section.
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        num_section = len(self.mrope_section)
        # Doubled because the last dim holds two concatenated halves
        # (rotate-half layout), each split by mrope_section.
        mrope_section = self.mrope_section * 2

        def _apply(x):
            # For chunk i, take the row of the section it belongs to
            # (i % num_section cycles T/H/W over both halves).
            x = torch.cat([
                m[i % num_section]
                for i, m in enumerate(x.split(mrope_section, dim=-1))
            ], dim=-1)
            return x

        return _apply(cos), _apply(sin)

    def _apply_interleaved_mrope(self, positions):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW]
        to interleaved [THTHWHTHW...TT], preserving frequency continuity.
        """
        mrope_section = self.mrope_section
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)

        def _apply(x):
            # Start from the temporal (T) row, then overwrite every 3rd slot
            # with the H (offset 1) and W (offset 2) rows, for both halves of
            # the rotate-half layout (second half shifted by rotary_dim // 2).
            x_t = x[0].clone()
            x_t[..., 1:mrope_section[1] * 3:3] = \
                x[1, ..., 1:mrope_section[1] * 3:3]
            x_t[..., 2:mrope_section[2] * 3:3] = \
                x[2, ..., 2:mrope_section[2] * 3:3]
            offset = self.rotary_dim // 2
            x_t[..., 1 + offset:mrope_section[1] * 3 + offset:3] = \
                x[1, ..., 1 + offset:mrope_section[1] * 3 + offset:3]
            x_t[..., 2 + offset:mrope_section[2] * 3 + offset:3] = \
                x[2, ..., 2 + offset:mrope_section[2] * 3 + offset:3]
            return x_t

        return _apply(cos), _apply(sin)

    def precompute_sin_cos_cache(self, positions: torch.Tensor):
        '''
        call this function before forward decoder layers
        precompute sin/cos cache for mrope
        '''
        if positions.ndim == 1:
            # Plain (non-multimodal) positions: nothing to precompute; the
            # 1-D path in forward_oot uses the base class directly.
            return
        assert positions.ndim == 2
        assert self.mrope_section
        if self.mrope_interleaved:
            cos, sin = self._apply_interleaved_mrope(positions)
        else:
            cos, sin = self._apply_mrope(positions)
        self.mrope_cos_cache = cos
        self.mrope_sin_cache = sin
        # cu_seq_lens = [0, num_tokens]: a single-sequence prefix-sum layout
        # expected by the MLU rotary kernel.
        self.mrope_cu_seq_lens = torch.zeros(2, dtype=torch.int32,
                                             device=positions.device)
        num_tokens = positions.shape[-1]
        self.mrope_cu_seq_lens[1] = num_tokens

    def forward_oot(
        self,
        positions: torch.Tensor,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        assert positions.ndim == 1 or positions.ndim == 2
        if positions.ndim == 1:
            # Non-multimodal path: delegate to the base MLU implementation.
            return MLURotaryEmbedding.forward_oot(self, positions, x)
        assert self.mrope_cos_cache is not None and \
            self.mrope_sin_cache is not None, \
            "call precompute_sin_cos_cache first!"
        num_tokens = positions.shape[-1]
        x = mlu_ops.rotary_embedding(x, self.mrope_sin_cache,
                                     self.mrope_cos_cache, None,
                                     self.mrope_cu_seq_lens,
                                     not self.is_neox_style, False, False,
                                     num_tokens)
        return x

    forward = MLURotaryEmbedding.forward