################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Monkey-patches for vLLM rotary embeddings on Biren SUPA hardware.

Replaces the cos/sin cache layout and the forward paths of several
``RotaryEmbedding`` variants with implementations backed by the
``torch_br.supa_rope_infer_v2`` fused kernel.
"""
import itertools
from typing import Any, Optional, Tuple, Union

import torch
import torch_br
from fastcore.basics import patch_to
from transformers import PretrainedConfig

import vllm.model_executor.layers.rotary_embedding
import vllm.model_executor.models.chatglm
import vllm.model_executor.models.deepseek_v2
import vllm_br.envs as br_envs
from vllm.logger import logger
from vllm.model_executor.layers.rotary_embedding import (
    _ROPE_DICT, DeepseekScalingRotaryEmbedding, DualChunkRotaryEmbedding,
    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
    Llama3RotaryEmbedding, Llama4VisionRotaryEmbedding, MRotaryEmbedding,
    NTKScalingRotaryEmbedding, Phi3LongRoPEScaledRotaryEmbedding,
    RotaryEmbedding, YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.common import (
    rotate_gptj, rotate_neox, yarn_find_correction_range,
    yarn_linear_ramp_mask)
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
    yarn_get_mscale)
from vllm.model_executor.layers.rotary_embedding.mrope import (
    apply_interleaved_rope)


@patch_to(RotaryEmbedding)
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    op_type: str = "Half",  # FIXME: other op type not supported yet
) -> None:
    """Patched constructor: build the RoPE cache in the layout the SUPA
    fused kernel expects.

    Three layouts are produced depending on the concrete subclass:
    - ``MRotaryEmbedding``: single ``cos_sin_cache`` buffer (cos|sin
      concatenated on the last dim).
    - ``DeepseekScalingRotaryEmbedding``: same single-buffer layout, but
      placed on the SUPA device.
    - everything else: separate float32 ``sin_cache`` / ``cos_cache``
      buffers consumed by ``torch_br.supa_rope_infer_v2``.
    """
    logger.info('[Patch] RotaryEmbedding use SUPA RoPE')
    super(RotaryEmbedding, self).__init__()  # type: ignore
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    self.is_neox_style = is_neox_style
    self.dtype = dtype
    self.op_type = op_type  # FIXME: other op type not supported yet
    if isinstance(self, MRotaryEmbedding):
        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        device = torch.cuda.current_device()
        cache = cache.to(device)
        self.cos_sin_cache: torch.Tensor  # type: ignore
        self.register_buffer("cos_sin_cache", cache, persistent=False)
    elif isinstance(self, DeepseekScalingRotaryEmbedding):
        # NOTE(review): these attributes were already assigned above; the
        # re-assignment is redundant but kept byte-identical here.
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype
        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        # NOTE(review): this branch uses torch.supa.current_device() while
        # the other branches use torch.cuda.current_device() — presumably
        # both resolve to the SUPA device under torch_br; confirm.
        device = torch.supa.current_device()
        cache = cache.to(device)
        self.cos_sin_cache: torch.Tensor  # type: ignore
        self.register_buffer("cos_sin_cache", cache, persistent=False)
    else:
        # Split caches, always float32: the fused kernel reads sin and cos
        # as separate tensors.
        sin_cache, cos_cache = self._compute_cos_sin_cache()
        sin_cache = sin_cache.to(torch.float32)
        cos_cache = cos_cache.to(torch.float32)
        device = torch.cuda.current_device()
        sin_cache = sin_cache.to(device)
        cos_cache = cos_cache.to(device)
        self.register_buffer("sin_cache", sin_cache, persistent=False)
        self.register_buffer("cos_cache", cos_cache, persistent=False)


@patch_to(RotaryEmbedding)
def _compute_cos_sin_cache(
        self) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Compute the cos and sin cache.

    Returns a single concatenated ``(cos|sin)`` tensor for
    ``MRotaryEmbedding``, otherwise a ``(sin, cos)`` pair laid out for the
    requested ``op_type``.
    """
    # Built on CPU; __init__ moves the result to the accelerator.
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        if isinstance(self, MRotaryEmbedding):
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            return cache
        else:
            if self.op_type == "Half" or self.op_type == "TeleChat":
                # Neox-style halves: duplicate the frequency table side by
                # side so each half of the head rotates with the same angle.
                freqs = freqs.repeat(1, 2)
                cos = freqs.cos()
                sin = freqs.sin()
            else:
                # Interleaved (GPT-J style): repeat each frequency twice and
                # alternate the sign of sin via the (-1, +1, -1, ...) scales.
                cos_freqs = freqs.repeat_interleave(2, dim=-1)
                cos = cos_freqs.cos()
                scales = torch.arange(cos_freqs.numel()) % 2 * 2 - 1
                sin_freqs = cos_freqs * scales.reshape_as(cos_freqs)
                sin = sin_freqs.sin()
            return sin, cos


@patch_to(RotaryEmbedding)
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Out-of-tree forward: apply RoPE via the SUPA fused kernel.

    ``offsets`` is accepted for interface compatibility but not forwarded
    to the kernel.
    """
    query_, key_ = torch_br.supa_rope_infer_v2(query,
                                               key,
                                               self.sin_cache,
                                               self.cos_cache,
                                               positions,
                                               self.head_size,
                                               rope_type=self.op_type,
                                               rotary_size=self.rotary_dim)
    return query_, key_


@patch_to(RotaryEmbedding)
def enabled(cls) -> bool:
    # Always report the custom (SUPA) op path as enabled.
    return True


class SupaDeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """DeepSeek YaRN-scaled RoPE producing the split sin/cos cache layout
    required by the SUPA kernel (selected via rope_type
    ``deepseek_yarn_supa`` in :func:`get_rope`)."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        # Scaling attributes must exist before super().__init__, which
        # triggers _compute_cos_sin_cache().
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """YaRN inverse frequencies: blend interpolated and extrapolated
        frequencies with a linear ramp between beta_fast/beta_slow."""
        with torch.device('cpu'):
            pos_freqs = self.base**(torch.arange(
                0, self.rotary_dim, 2, dtype=torch.float, device="cpu") /
                                    self.rotary_dim)
            inv_freq_extrapolation = 1.0 / pos_freqs
            inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
            low, high = yarn_find_correction_range(
                self.beta_fast, self.beta_slow, self.rotary_dim, self.base,
                self.max_position_embeddings)
            # Get n-d rotational scaling corrected for extrapolation
            inv_freq_mask = (1 - yarn_linear_ramp_mask(
                low, high, self.rotary_dim // 2,
                dtype=torch.float)) * self.extrapolation_factor
            inv_freq = inv_freq_interpolation * (
                1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
            return inv_freq

    def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(sin, cos)`` in the interleaved layout, scaled by the
        YaRN magnitude correction ``mscale``."""
        with torch.device('cpu'):
            inv_freq = self._compute_inv_freq(self.scaling_factor)
            t = torch.arange(self.max_position_embeddings *
                             self.scaling_factor,
                             dtype=torch.float)
            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos_freqs = freqs.repeat_interleave(2, dim=-1)
            cos = (cos_freqs.cos() * self.mscale)
            # Alternate signs (-1, +1, ...) give the GPT-J interleaved sin.
            scales = torch.arange(cos_freqs.numel()) % 2 * 2 - 1
            sin_freqs = cos_freqs * scales.reshape_as(cos_freqs)
            sin = (sin_freqs.sin() * self.mscale)
            return sin, cos


@patch_to(DeepseekScalingRotaryEmbedding)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """YaRN inverse frequencies (same math as the Supa subclass above),
    forced onto CPU so the cache build never touches the accelerator."""
    with torch.device('cpu'):
        pos_freqs = self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cpu") /
                                self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
        low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq


@patch_to(DeepseekScalingRotaryEmbedding)
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Single concatenated ``(cos|sin)`` cache, mscale-corrected, built on
    CPU (moved to the device by the patched ``__init__``)."""
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache


@patch_to(DeepseekScalingRotaryEmbedding)
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward().

    Only the first ``rotary_dim`` channels are rotated; any remaining
    channels are passed through. Some steps are detoured through CPU —
    presumably to work around SUPA op limits (repeat_interleave, and large
    batches > 1024 rows); confirm against the kernel's constraints.
    """
    assert key is not None
    self._match_cos_sin_cache_dtype(query)
    query_rot = query[..., :self.rotary_dim]
    key_rot = key[..., :self.rotary_dim]
    if self.rotary_dim < self.head_size:
        query_pass = query[..., self.rotary_dim:]
        key_pass = key[..., self.rotary_dim:]
    cos_sin = self.cos_sin_cache[
        torch.add(positions, offsets) if offsets is not None else positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    if self.is_neox_style:
        # NOTE(woosuk): Here we assume that the positions tensor has the
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
    else:
        # repeat_interleave is done on CPU, then results are moved back.
        device = torch.supa.current_device()
        cos = cos.to('cpu')
        sin = sin.to('cpu')
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        cos = cos.to(device)
        sin = sin.to(device)
    rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
    device = query_rot.device
    # Large batches: compute the rotation on CPU, then move back.
    if query.shape[0] > 1024:
        query_rot = query_rot.to('cpu')
        key_rot = key_rot.to('cpu')
        cos = cos.to('cpu')
        sin = sin.to('cpu')
    query_rot = query_rot * cos + rotate_fn(query_rot) * sin
    key_rot = key_rot * cos + rotate_fn(key_rot) * sin
    if query.shape[0] > 1024:
        query_rot = query_rot.to(device)
        key_rot = key_rot.to(device)
    if self.rotary_dim < self.head_size:
        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)
    else:
        query = query_rot
        key = key_rot
    return query, key


@patch_to(DeepseekScalingRotaryEmbedding)
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # No fused kernel for the DeepSeek variant: fall back to the native path.
    query, key = self.forward_native(positions, query, key, offsets)
    return query, key


@patch_to(YaRNScalingRotaryEmbedding)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """YaRN inverse frequencies for the generic YaRN embedding, computed
    on CPU (same blend of interpolation/extrapolation as above)."""
    with torch.device('cpu'):
        pos_freqs = self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
        low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq


# Decorates the YaRN _compute_cos_sin_cache patch defined next.
@patch_to(YaRNScalingRotaryEmbedding)
def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the ``(sin, cos)`` cache pair for YaRN, mscale-corrected.

    Built on CPU in the Neox-style (side-by-side halves) layout; the
    patched ``RotaryEmbedding.__init__`` registers the pair as split
    ``sin_cache`` / ``cos_cache`` buffers.
    """
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        freqs = freqs.repeat(1, 2)
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale
        return sin, cos


def dynamicNTK_compute_cos_sin_cache(
        self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the cos and sin cache for DynamicNTK scaling.

    Mirrors the patched ``RotaryEmbedding._compute_cos_sin_cache`` non-mrope
    path: Neox layout for ``op_type`` "Half"/"TeleChat", otherwise the
    interleaved (GPT-J) layout with alternating sin signs.
    """
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        if self.op_type == "Half" or self.op_type == "TeleChat":
            freqs = freqs.repeat(1, 2)
            cos = freqs.cos()
            sin = freqs.sin()
        else:
            cos_freqs = freqs.repeat_interleave(2, dim=-1)
            cos = cos_freqs.cos()
            scales = torch.arange(cos_freqs.numel()) % 2 * 2 - 1
            sin_freqs = cos_freqs * scales.reshape_as(cos_freqs)
            sin = sin_freqs.sin()
        return sin, cos


# Backward-compat alias for the original (misspelled) public name.
dtnamicNTK_compute_cos_sin_cache = dynamicNTK_compute_cos_sin_cache


def dynamicNTKScaling_rope_forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused-kernel forward for DynamicNTK scaling.

    A query/key last-dim mismatch is used as the signal to take the MRope
    kernel path; otherwise the configured ``op_type`` is used.
    ``offsets`` is accepted for interface compatibility only.
    """
    if query.shape[-1] != key.shape[-1]:
        query_, key_ = torch_br.supa_rope_infer_v2(query,
                                                   key,
                                                   self.sin_cache,
                                                   self.cos_cache,
                                                   positions,
                                                   self.head_size,
                                                   rope_type="MRope")
    else:
        query_, key_ = torch_br.supa_rope_infer_v2(query,
                                                   key,
                                                   self.sin_cache,
                                                   self.cos_cache,
                                                   positions,
                                                   self.head_size,
                                                   rope_type=self.op_type)
    return query_, key_


DynamicNTKScalingRotaryEmbedding._compute_cos_sin_cache = \
    dynamicNTK_compute_cos_sin_cache
DynamicNTKScalingRotaryEmbedding.forward = dynamicNTKScaling_rope_forward


def _apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """Reference (non-fused) rotary application on one tensor.

    Neox style rotates the two contiguous halves; GPT-J style rotates
    even/odd interleaved channels and re-interleaves the result.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)


def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                      is_neox_style: bool) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)


def forward_MRotaryEmbedding_0_9_2(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward().

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    assert positions.ndim == 1 or positions.ndim == 2
    assert key is not None
    num_tokens = positions.shape[-1]
    cos_sin = self.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    if positions.ndim == 2:
        # Multimodal: pick the T/H/W slice of each mrope section.
        assert self.mrope_section
        cos = torch.cat([
            m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
        ],
                        dim=-1)
        sin = torch.cat([
            m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
        ],
                        dim=-1)
    query_shape = query.shape
    query = query.view(num_tokens, -1, self.head_size)
    query_rot = query[..., :self.rotary_dim]
    query_pass = query[..., self.rotary_dim:]
    query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
    query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
    key_shape = key.shape
    key = key.view(num_tokens, -1, self.head_size)
    key_rot = key[..., :self.rotary_dim]
    key_pass = key[..., self.rotary_dim:]
    key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
    key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    return query, key


def forward_supa(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """MRotaryEmbedding forward using the SUPA fused MRope kernel.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    if br_envs.VLLM_BR_USE_MROPE_0_9_2:
        # Env-gated fallback to the 0.9.2-era native implementation.
        return forward_MRotaryEmbedding_0_9_2(self, positions, query, key)
    assert positions.ndim == 1 or positions.ndim == 2
    data_in_supa = lambda t: str(t.device).startswith('supa')
    data_in_cpu = lambda t: t.device == torch.device('cpu')
    if positions.ndim == 2:
        # use bypass for decode stage
        if (positions.shape[1] == 1):
            cos_sin = self.cos_sin_cache[positions]
            cos, sin = cos_sin.chunk(2, dim=-1)
            cos = cos[0]
            sin = sin[0]
        else:
            cos_sin = self.cos_sin_cache[positions.to(torch.int64)]
            cos, sin = cos_sin.chunk(2, dim=-1)
            assert self.mrope_section
            if self.mrope_interleaved:
                cos = apply_interleaved_rope(cos, self.mrope_section)
                sin = apply_interleaved_rope(sin, self.mrope_section)
            else:
                cos = torch.cat([
                    m[i] for i, m in enumerate(
                        cos.split(self.mrope_section, dim=-1))
                ],
                                dim=-1)
                sin = torch.cat([
                    m[i] for i, m in enumerate(
                        sin.split(self.mrope_section, dim=-1))
                ],
                                dim=-1)
    else:
        # BUGFIX: 1-D (text-only) positions previously fell through with
        # cos/sin undefined, raising NameError. Plain cache lookup, no
        # mrope-section selection needed (matches the 0_9_2 path).
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
    if data_in_supa(query) and data_in_supa(key):
        # Ensure kernel operands live on the SUPA device.
        sin = sin.supa() if data_in_cpu(sin) else sin
        cos = cos.supa() if data_in_cpu(cos) else cos
        positions = positions.supa() if data_in_cpu(positions) else positions
    query, key = torch_br.supa_rope_infer_v2(query,
                                             key,
                                             sin.to(torch.float32),
                                             cos.to(torch.float32),
                                             positions.to(torch.int32),
                                             self.head_size,
                                             rope_type="MRope")
    return query, key


MRotaryEmbedding.forward = forward_supa


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    op_type: str = "Half",
) -> RotaryEmbedding:
    """Construct (and cache in ``_ROPE_DICT``) the RoPE implementation for
    the given configuration.

    Dispatches on ``rope_scaling["rope_type"]``; ``op_type`` selects the
    SUPA kernel flavour and is threaded through to the patched
    ``RotaryEmbedding.__init__`` for the unscaled case.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None
    if dual_chunk_attention_config is not None:
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None
    if partial_rotary_factor < 1.0:
        # NOTE(review): upstream vLLM derives this from head_size
        # (head_size * partial_rotary_factor); here it scales rotary_dim —
        # confirm this divergence is intentional.
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dual_chunk_attention_args, dtype)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]
    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style, dtype,
                                              **extra_kwargs)
    elif not rope_scaling:
        rotary_emb = RotaryEmbedding(head_size,
                                     rotary_dim,
                                     max_position,
                                     base,
                                     is_neox_style,
                                     dtype,
                                     op_type=op_type)
    else:
        scaling_type = rope_scaling["rope_type"]
        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor,
                                               low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
                                                     max_position, base,
                                                     is_neox_style, dtype)
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                # MRope cache is kept in float32 regardless of model dtype.
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype=torch.float32,
                    mrope_section=rope_scaling["mrope_section"],
                    mrope_interleaved=rope_scaling.get(
                        "mrope_interleaved", False),
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "ntk":
            scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get('mixed_b', None)
            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
                                                   max_position, base,
                                                   is_neox_style,
                                                   scaling_factor, dtype,
                                                   mixed_b)
        elif scaling_type == "dynamic":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor, dtype)
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "deepseek_yarn_supa":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = SupaDeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb


def deepseek_get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """get_rope wrapper for DeepSeek models: forces the "DeepSeek" op_type."""
    return get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                    rope_scaling, dtype, partial_rotary_factor,
                    dual_chunk_attention_config, "DeepSeek")


def chatglm2_get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """get_rope wrapper for ChatGLM models.

    NOTE(review): this also forwards op_type "DeepSeek" — identical to
    deepseek_get_rope. Possibly a copy-paste; confirm ChatGLM is meant to
    share the DeepSeek kernel layout.
    """
    return get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                    rope_scaling, dtype, partial_rotary_factor,
                    dual_chunk_attention_config, "DeepSeek")
# Rebind vLLM's RoPE factories to the SUPA-aware versions defined above.
vllm.model_executor.layers.rotary_embedding.get_rope = get_rope
vllm.model_executor.models.deepseek_v2.get_rope = deepseek_get_rope
vllm.model_executor.models.chatglm.get_rope = chatglm2_get_rope


@patch_to(MRotaryEmbedding)
def _glm4v_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value for GLM4V.

    Builds 3-row (T/H/W) position ids: text runs advance all three rows
    together; image/video runs expand over their grid. Returns the
    positions sliced to [context_len, seq_len) plus the mrope delta
    (max position + 1 - number of input tokens).
    """
    image_token_id = hf_config.image_token_id
    video_start_token_id = hf_config.video_start_token_id
    video_end_token_id = hf_config.video_end_token_id
    spatial_merge_size = hf_config.vision_config.spatial_merge_size
    llm_pos_ids_list: list = []
    if not (image_grid_thw is None and video_grid_thw is None):
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()
        # Classify each token: image placeholders between video start/end
        # markers count as "video" frames.
        input_token_type: list[str] = []
        video_check_flg = False
        for token in input_tokens:
            if token == video_start_token_id:
                video_check_flg = True
            elif token == video_end_token_id:
                video_check_flg = False
            if (token == image_token_id) and (video_check_flg is False):
                input_token_type.append("image")
            elif (token == image_token_id) and (video_check_flg is True):
                input_token_type.append("video")
            else:
                input_token_type.append("text")
        # Collapse consecutive same-type tokens into (type, start, end) runs.
        input_type_group: list[tuple[str, int, int]] = []
        for key, group_iter in itertools.groupby(enumerate(input_token_type),
                                                 lambda x: x[1]):
            group_list = list(group_iter)
            start_index = group_list[0][0]
            end_index = group_list[-1][0] + 1
            input_type_group.append((key, start_index, end_index))
        video_frame_num = 1
        mm_data_idx = 0
        for modality_type, start_idx, end_idx in input_type_group:
            # Each run continues numbering after the previous run's maximum.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            if modality_type == "image":
                t, h, w = (
                    image_grid_thw[mm_data_idx][0],
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = \
                    t, h // spatial_merge_size, w // spatial_merge_size
                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                    -1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                    llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                    llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index]) + st_idx)
                mm_data_idx += 1
            elif modality_type == "video":
                # NOTE(review): reads image_grid_thw (not video_grid_thw)
                # for the H/W of video frames — presumably GLM4V stores
                # video frame grids in the same list; confirm.
                t, h, w = (
                    video_frame_num,
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = \
                    t, h // spatial_merge_size, w // spatial_merge_size
                # One positions block per frame, each restarting T at t_idx.
                for t_idx in range(llm_grid_t):
                    t_index = torch.tensor(t_idx).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                        1, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                        1, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)
                mm_data_idx += 1
                video_frame_num += 1
            else:
                # Text run: all three rows advance together.
                text_len = end_idx - start_idx
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) +
                    st_idx)
                video_frame_num = 1
    else:
        # Text-only prompt: simple 0..N-1 positions on all three rows.
        text_len = len(input_tokens)
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1))
    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    llm_positions = llm_positions[:, context_len:seq_len]
    mrope_position_delta = (llm_positions.max() + 1 -
                            len(input_tokens)).item()
    return llm_positions, mrope_position_delta


@patch_to(MRotaryEmbedding)
def get_input_positions_tensor_for_glm(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Dispatch mrope position computation by model family.

    Omni (thinker) models go to ``_omni_get_input_positions_tensor``,
    GLM4V to the patched ``_glm4v_get_input_positions_tensor``, everything
    else to the generic VL path.
    """
    from vllm.transformers_utils.config import thinker_uses_mrope
    if thinker_uses_mrope(hf_config):
        return cls._omni_get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
        )
    elif "glm4v" in hf_config.model_type:
        # The glm4v patch is a plain function on the class, so cls is
        # passed explicitly as its first argument.
        return cls._glm4v_get_input_positions_tensor(
            cls,
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            context_len=context_len,
            seq_len=seq_len,
        )
    else:
        return cls._vl_get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
        )