Files
xc-llm-kunlun/vllm_kunlun/ops/rotary_embedding.py

190 lines
7.6 KiB
Python
Raw Normal View History

2025-12-10 12:05:39 +08:00
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2025-12-10 17:51:24 +08:00
# This file is a part of the vllm-ascend project.
2025-12-10 12:05:39 +08:00
#
import torch
import xspeedgate_ops
import os
from vllm.model_executor.layers.rotary_embedding import (
RotaryEmbedding, YaRNScalingRotaryEmbedding,
DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding,
DeepseekScalingRotaryEmbedding)
2025-12-10 12:05:39 +08:00
from typing import Optional, Tuple
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the rotary cos/sin lookup table in the layout chosen by env vars.

    Side effect: when the instance has a ``scaling_factor`` attribute,
    ``self.max_position_embeddings`` is overwritten with
    ``int(self.max_position_embeddings * self.scaling_factor)`` before the
    table is built, so every call on such an instance scales it again.

    With ``P`` positions and ``F`` inverse frequencies the result is:

    * ``ROPE_NATIVE_2D == "1"``: ``[P, 2F]`` -- cos and sin concatenated.
    * ``USE_ORI_ROPE == "0"``: ``[2, 1, P, 2F]`` -- index 0 holds cos
      repeated twice along the last dim, index 1 holds sin likewise.
    * otherwise: ``[1, 1, P, 2F]`` -- cos and sin concatenated, with two
      leading singleton dims.

    Returns:
        The float cache tensor in one of the layouts above.
    """
    inv_freq = self._compute_inv_freq(self.base)

    if hasattr(self, 'scaling_factor'):
        # Persist the scaled length on the instance (not just locally).
        self.max_position_embeddings = int(
            self.max_position_embeddings * self.scaling_factor)

    position_ids = torch.arange(self.max_position_embeddings,
                                dtype=torch.float)
    # Outer product: angles[p, f] = position_ids[p] * inv_freq[f].
    angles = torch.einsum("i,j -> ij", position_ids, inv_freq)
    cos_table = angles.cos()
    sin_table = angles.sin()

    # Per the original authors: glm4-9b-chat runs rope through
    # forward_native and qwen2.5-vl runs mrope; both need the plain 2-D
    # cache layout, which is selected through an environment variable.
    # NOTE(review): the original comment named the variable GLM4_CHAT, but
    # the code checks ROPE_NATIVE_2D -- confirm which one deployments set.
    if os.getenv('ROPE_NATIVE_2D') == "1":
        return torch.cat((cos_table, sin_table), dim=-1)

    if os.getenv('USE_ORI_ROPE') != "0":
        # [1, 1, P, 2F]
        return torch.cat((cos_table, sin_table),
                         dim=-1).unsqueeze(0).unsqueeze(1)

    doubled_cos = torch.cat((cos_table, cos_table), dim=-1)
    doubled_sin = torch.cat((sin_table, sin_table), dim=-1)
    # [2, 1, P, 2F]
    return torch.stack((doubled_cos, doubled_sin), dim=0).unsqueeze(1)
def vllm_kunlun_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply rotary embedding to ``query``/``key`` through ``KunlunOps``.

    Args:
        positions: Position tensor forwarded unchanged to the kernel.
        query: Query tensor; its device and dtype decide where
            ``self.cos_sin_cache`` lives.
        key: Optional key tensor forwarded to the kernel.
        offsets: When given, the batched kernel is used instead of the
            plain one.

    Returns:
        ``(query, key)``. Without ``offsets`` these are the values returned
        by ``ops.rotary_embedding``; with ``offsets`` they are the input
        tensors handed to ``ops.batched_rotary_embedding``.
    """
    from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops

    # Move/cast the cached table once so it matches the query tensor.
    cache = self.cos_sin_cache
    if cache.device != query.device or cache.dtype != query.dtype:
        self.cos_sin_cache = cache.to(query.device, dtype=query.dtype)

    # Per the original comment, ops.rotary_embedding() and
    # ops.batched_rotary_embedding() update query and key in place.
    if offsets is None:
        query, key = ops.rotary_embedding(positions, query, key,
                                          self.head_size,
                                          self.cos_sin_cache,
                                          self.is_neox_style)
    else:
        ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                     self.cos_sin_cache, self.is_neox_style,
                                     self.rotary_dim, offsets)
    return query, key
def apply_interleaved_rope(x: torch.Tensor,
                           mrope_section: list[int]) -> torch.Tensor:
    """Merge the three leading slices of ``x`` into one interleaved tensor.

    Starts from a copy of ``x[0]`` and, for ``axis`` in ``(1, 2)``, overwrites
    every third element of the last dim starting at offset ``axis`` and
    stopping before ``mrope_section[axis] * 3`` with the matching elements
    of ``x[axis]``. The original docstring describes this as turning the
    chunked ``[TTT...HHH...WWW]`` layout into an interleaved one.

    Args:
        x: Tensor indexed as ``x[0]``, ``x[1]``, ``x[2]`` on its first dim.
        mrope_section: Section sizes; entries 1 and 2 bound the overwrite.

    Returns:
        A new tensor shaped like ``x[0]``; ``x`` itself is not modified.
    """
    merged = x[0].clone()
    for axis in (1, 2):
        # Positions axis, axis + 3, axis + 6, ... within this axis' section.
        picked = slice(axis, mrope_section[axis] * 3, 3)
        merged[..., picked] = x[axis, ..., picked]
    return merged
2025-12-10 12:05:39 +08:00
2025-12-10 17:51:24 +08:00
def vllm_kunlun_apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                                 is_neox_style: bool) -> torch.Tensor:
    """Rotate pairs of elements of ``x`` by the angles given as cos/sin.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: ``True`` pairs the first half of the last dim with
            the second half (Neox style); ``False`` pairs even-indexed
            elements with odd-indexed ones (GPT-J style).

    Returns:
        The rotated tensor, with the same shape and dtype as ``x``.
    """
    # Add a heads axis so cos/sin broadcast over heads, and match x's dtype.
    cos_b = cos.unsqueeze(-2).to(x.dtype)
    sin_b = sin.unsqueeze(-2).to(x.dtype)

    if is_neox_style:
        lhs, rhs = torch.chunk(x, 2, dim=-1)
    else:
        lhs, rhs = x[..., ::2], x[..., 1::2]

    rot_lhs = lhs * cos_b - rhs * sin_b
    rot_rhs = rhs * cos_b + lhs * sin_b

    if not is_neox_style:
        # Re-interleave so each rotated pair is adjacent again.
        return torch.stack((rot_lhs, rot_rhs), dim=-1).flatten(-2)
    return torch.cat((rot_lhs, rot_rhs), dim=-1)
def vllm_kunlun_mrope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply multimodal rotary embedding via the ``xspeedgate_ops`` kernel.

    Args:
        positions: Must be 2-D (asserted); the original docstring gives
            ``[3, num_tokens]`` (T/H/W positions). It is cast to int32
            before being handed to the kernel.
        query: [num_tokens, num_heads * head_size] per the original
            docstring.
        key: [num_tokens, num_kv_heads * head_size] per the original
            docstring. Must not be ``None`` (asserted).

    Returns:
        The ``(query, key)`` pair returned by
        ``torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0``.

    Raises:
        AssertionError: If ``positions`` is not 2-D or ``key`` is ``None``.
    """
    assert positions.ndim == 2
    assert key is not None

    kernel = torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0
    positions_i32 = positions.to(dtype=torch.int32)
    section = self.mrope_section
    query, key = kernel(
        query,
        key,
        positions_i32,
        self.cos_sin_cache,
        self.mrope_interleaved,
        self.is_neox_style,
        self.head_size,
        self.rotary_dim,
        section[0],
        section[1],
        section[2],
    )
    return query, key
2025-12-10 12:05:39 +08:00
# Capture DeepseekScalingRotaryEmbedding's current forward/forward_cuda
# before RotaryEmbedding is patched below.
DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda
# Route both RotaryEmbedding entry points to the Kunlun implementation.
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
# Pin the captured methods directly on DeepseekScalingRotaryEmbedding so it
# keeps its pre-patch behavior.
# NOTE(review): presumably this class inherits from RotaryEmbedding and would
# otherwise pick up the patch above -- confirm against the vLLM class tree.
DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda
2025-12-10 12:05:39 +08:00
# Route both MRotaryEmbedding entry points to the Kunlun mrope wrapper.
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
def Split_Norm_Rope(
    qkv: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    q_norm_weight: torch.Tensor,
    k_norm_weight: torch.Tensor,
    positions: torch.Tensor,
    max_position_embeddings: int,
    q_head_num: int,
    kv_head_num: int,
    head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate q/k/v outputs and fill them via ``split_norm_rope_neox``.

    Three uninitialized buffers are created with ``qkv``'s dtype and device
    and passed, together with all inputs, to
    ``torch.ops._C.split_norm_rope_neox``. The rotary dimension handed to
    the op is always ``head_dim``.
    NOTE(review): judging by the op name it splits the packed qkv, applies
    q/k normalization and Neox-style rope -- confirm against the kernel.

    Args:
        qkv: Packed input; ``qkv.shape[0]`` is taken as the token count.
        cos_sin_cache: Rotary cache forwarded to the op.
        q_norm_weight: Query norm weight forwarded to the op.
        k_norm_weight: Key norm weight forwarded to the op.
        positions: Position tensor forwarded to the op.
        max_position_embeddings: Forwarded to the op unchanged.
        q_head_num: Number of query heads; sizes the q output.
        kv_head_num: Number of key/value heads; sizes the k and v outputs.
        head_dim: Per-head width; also used as the rotary dimension.

    Returns:
        ``(q_emb_out, k_emb_out, v_out)`` with shapes
        ``[num_tokens, q_head_num * head_dim]``,
        ``[num_tokens, kv_head_num * head_dim]`` and
        ``[num_tokens, kv_head_num * head_dim]``.
    """
    num_tokens = qkv.shape[0]
    rotary_dim = head_dim

    # Every output buffer shares the input's dtype and device.
    like_qkv = dict(dtype=qkv.dtype, device=qkv.device)
    q_width = q_head_num * head_dim
    kv_width = kv_head_num * head_dim

    q_emb_out = torch.empty((num_tokens, q_width), **like_qkv)
    k_emb_out = torch.empty((num_tokens, kv_width), **like_qkv)
    v_out = torch.empty((num_tokens, kv_width), **like_qkv)

    torch.ops._C.split_norm_rope_neox(
        q_emb_out,
        k_emb_out,
        v_out,
        qkv,
        cos_sin_cache,
        q_norm_weight,
        k_norm_weight,
        positions,
        num_tokens,
        max_position_embeddings,
        q_head_num,
        kv_head_num,
        head_dim,
        rotary_dim,
    )
    return q_emb_out, k_emb_out, v_out