xc-llm-kunlun/vllm_kunlun/ops/rotary_embedding.py
baoqian426 ee0f50e68f [Feature] support deepseek v3/r1/v3.2 (#78)
* [Feature] support deepseek v3/r1/v3.2

* fix gpt_oss

* update readme

* update readme

---------

Co-authored-by: hanhaowen <hanhaowen@baidu.com>
2026-01-05 22:55:35 +08:00


#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import os
from typing import Optional, Tuple

import torch
import xspeedgate_ops
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding,
    MRotaryEmbedding, RotaryEmbedding, YaRNScalingRotaryEmbedding)


def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
    """Compute the cos and sin cache."""
    inv_freq = self._compute_inv_freq(self.base)
    if hasattr(self, 'scaling_factor'):
        self.max_position_embeddings = int(self.max_position_embeddings *
                                           self.scaling_factor)
    t = torch.arange(self.max_position_embeddings, dtype=torch.float)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos()
    sin = freqs.sin()
    # glm4-9b-chat runs rope through forward_native, so the cache must keep a
    # specific shape; this is controlled by an environment variable.
    # qwen2.5-vl runs rope through mrope, which also needs this cache shape.
    # That is, running glm4-9b-chat or qwen2.5-vl requires setting the
    # GLM4_CHAT environment variable to 1.
    if os.getenv('ROPE_NATIVE_2D') == "1":
        cache = torch.cat((cos, sin), dim=-1)
        return cache
    if os.getenv('USE_ORI_ROPE') == "0":
        cache_cos = torch.cat((cos, cos), dim=-1)
        cache_sin = torch.cat((sin, sin), dim=-1)
        # [2, 1, self.max_position_embeddings, self.rotary_dim]
        cache = torch.stack((cache_cos, cache_sin), dim=0).unsqueeze(1)
    else:
        cache = torch.cat((cos, sin), dim=-1).unsqueeze(0).unsqueeze(1)
    return cache
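

def _demo_cos_sin_cache_layouts() -> None:
    # Illustrative sketch only (hypothetical sizes, plain torch): reproduces
    # the three cache layouts built above, using the standard RoPE inverse
    # frequencies that RotaryEmbedding._compute_inv_freq would return.
    rotary_dim, max_pos, base = 8, 16, 10000.0
    inv_freq = 1.0 / (base ** (
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    t = torch.arange(max_pos, dtype=torch.float)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos, sin = freqs.cos(), freqs.sin()
    # ROPE_NATIVE_2D=1: 2-D cache [max_pos, rotary_dim]
    assert torch.cat((cos, sin), dim=-1).shape == (max_pos, rotary_dim)
    # USE_ORI_ROPE=0: duplicated halves, [2, 1, max_pos, rotary_dim]
    dup = torch.stack((torch.cat((cos, cos), dim=-1),
                       torch.cat((sin, sin), dim=-1)), dim=0).unsqueeze(1)
    assert dup.shape == (2, 1, max_pos, rotary_dim)
    # default: [1, 1, max_pos, rotary_dim]
    default = torch.cat((cos, sin), dim=-1).unsqueeze(0).unsqueeze(1)
    assert default.shape == (1, 1, max_pos, rotary_dim)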


def vllm_kunlun_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply rotary embeddings to query/key using the Kunlun kernels."""
    from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
    if self.cos_sin_cache.device != query.device or \
            self.cos_sin_cache.dtype != query.dtype:
        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                   dtype=query.dtype)
    # ops.rotary_embedding()/batched_rotary_embedding()
    # are in-place operations that update the query and key tensors.
    if offsets is not None:
        ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                     self.cos_sin_cache, self.is_neox_style,
                                     self.rotary_dim, offsets)
    else:
        query, key = ops.rotary_embedding(positions, query, key,
                                          self.head_size, self.cos_sin_cache,
                                          self.is_neox_style)
    return query, key


def apply_interleaved_rope(x: torch.Tensor,
                           mrope_section: list[int]) -> torch.Tensor:
    """Apply interleaved MRoPE to 3D rotary embeddings.

    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
    interleaved [THTHWHTHW...TT], preserving frequency continuity.
    """
    x_t = x[0].clone()
    x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
    x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
    return x_t
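

def _demo_apply_interleaved_rope() -> None:
    # Illustrative sketch only (hypothetical sizes): with mrope_section
    # [2, 2, 2] and constant T/H/W channels (0/1/2), the chunked layout is
    # rewritten into the interleaved T, H, W, T, H, W pattern.
    mrope_section = [2, 2, 2]
    num_tokens, half_dim = 1, sum(mrope_section)
    x = torch.stack((torch.full((num_tokens, half_dim), 0.0),
                     torch.full((num_tokens, half_dim), 1.0),
                     torch.full((num_tokens, half_dim), 2.0)), dim=0)
    out = apply_interleaved_rope(x, mrope_section)
    assert torch.equal(out[0], torch.tensor([0.0, 1.0, 2.0, 0.0, 1.0, 2.0]))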


def vllm_kunlun_apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor,
                                 sin: torch.Tensor,
                                 is_neox_style: bool) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)
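

def _demo_apply_rotary_emb() -> None:
    # Illustrative sketch only (hypothetical sizes): a zero rotation angle
    # (cos=1, sin=0) leaves the input unchanged in both styles, and a real
    # rotation preserves the per-token, per-head norm.
    num_tokens, num_heads, head_size = 4, 2, 8
    x = torch.randn(num_tokens, num_heads, head_size)
    ones = torch.ones(num_tokens, head_size // 2)
    zeros = torch.zeros(num_tokens, head_size // 2)
    assert torch.allclose(vllm_kunlun_apply_rotary_emb(x, ones, zeros, True), x)
    assert torch.allclose(vllm_kunlun_apply_rotary_emb(x, ones, zeros, False), x)
    theta = torch.rand(num_tokens, head_size // 2)
    out = vllm_kunlun_apply_rotary_emb(x, theta.cos(), theta.sin(), True)
    assert torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-5)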


def vllm_kunlun_mrope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """MRoPE forward using the xspeedgate kernel.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    assert positions.ndim == 2
    assert key is not None
    query, key = torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0(
        query,
        key,
        positions.to(dtype=torch.int32),
        self.cos_sin_cache,
        self.mrope_interleaved,
        self.is_neox_style,
        self.head_size,
        self.rotary_dim,
        self.mrope_section[0],
        self.mrope_section[1],
        self.mrope_section[2],
    )
    return query, key


# Monkey-patch vLLM's rotary embedding layers with the Kunlun implementations.
# DeepseekScalingRotaryEmbedding keeps its original forward/forward_cuda: they
# are saved before patching the RotaryEmbedding base class and re-assigned on
# the subclass afterwards, so the Deepseek path is left untouched.
DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda

RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda

MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda


def Split_Norm_Rope(
    qkv: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    q_norm_weight: torch.Tensor,
    k_norm_weight: torch.Tensor,
    positions: torch.Tensor,
    max_position_embeddings: int,
    q_head_num: int,
    kv_head_num: int,
    head_dim: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fused split + Q/K norm + neox-style RoPE on a packed QKV tensor,
    backed by the torch.ops._C.split_norm_rope_neox kernel."""
    num_tokens = qkv.shape[0]
    rotary_dim = head_dim
    q_emb_out = torch.empty((num_tokens, q_head_num * head_dim),
                            dtype=qkv.dtype, device=qkv.device)
    k_emb_out = torch.empty((num_tokens, kv_head_num * head_dim),
                            dtype=qkv.dtype, device=qkv.device)
    v_out = torch.empty((num_tokens, kv_head_num * head_dim),
                        dtype=qkv.dtype, device=qkv.device)
    torch.ops._C.split_norm_rope_neox(
        q_emb_out,
        k_emb_out,
        v_out,
        qkv,
        cos_sin_cache,
        q_norm_weight,
        k_norm_weight,
        positions,
        num_tokens,
        max_position_embeddings,
        q_head_num,
        kv_head_num,
        head_dim,
        rotary_dim,
    )
    return q_emb_out, k_emb_out, v_out
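

# Minimal usage sketch for Split_Norm_Rope (illustrative only; qkv, rope,
# q_norm, k_norm and positions are hypothetical caller-side tensors/modules,
# and the packed layout is assumed to be
# [num_tokens, (q_head_num + 2 * kv_head_num) * head_dim] with the _C
# extension providing split_norm_rope_neox already loaded):
#
#   q, k, v = Split_Norm_Rope(
#       qkv, rope.cos_sin_cache, q_norm.weight, k_norm.weight, positions,
#       max_position_embeddings=4096, q_head_num=32, kv_head_num=8,
#       head_dim=128)
#   # q: [num_tokens, 32 * 128], k and v: [num_tokens, 8 * 128]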