181 lines
5.5 KiB
Python
181 lines
5.5 KiB
Python
#
|
|
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-kunlun project.
|
|
#
|
|
|
|
import torch
|
|
import xspeedgate_ops
|
|
import os
|
|
from vllm.model_executor.layers.rotary_embedding import (
|
|
RotaryEmbedding,
|
|
YaRNScalingRotaryEmbedding,
|
|
DynamicNTKScalingRotaryEmbedding,
|
|
MRotaryEmbedding,
|
|
)
|
|
from typing import Optional, Tuple
|
|
import xtorch_ops
|
|
|
|
|
|
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the rotary cos/sin lookup table for this embedding.

    Inverse frequencies come from ``self._compute_inv_freq(self.base)``. If the
    instance has a ``scaling_factor`` attribute, ``self.max_position_embeddings``
    is overwritten in place with ``int(max_position_embeddings * scaling_factor)``
    before the table is built. Note that this is a persistent side effect, so
    calling this method twice scales the attribute twice.

    Layout of the returned tensor, with ``D = inv_freq.numel()`` and ``P`` the
    (possibly scaled) ``max_position_embeddings``:

    * ``FUSED_QK_ROPE_OP=1``: shape ``[2, 1, P, 2 * D]``. Index 0 of the
      leading dim holds ``cat(cos, cos)`` and index 1 holds ``cat(sin, sin)``,
      concatenated along the last dim.
    * otherwise: shape ``[1, 1, P, 2 * D]`` holding ``cat(cos, sin)`` along the
      last dim.
    """
    inv_freq = self._compute_inv_freq(self.base)

    if hasattr(self, "scaling_factor"):
        # Mutates the instance, not just a local length.
        scaled_len = self.max_position_embeddings * self.scaling_factor
        self.max_position_embeddings = int(scaled_len)

    positions = torch.arange(self.max_position_embeddings, dtype=torch.float)
    # Outer product: angles[p, j] = positions[p] * inv_freq[j].
    angles = torch.einsum("i,j -> ij", positions, inv_freq)
    cos, sin = angles.cos(), angles.sin()

    if os.getenv("FUSED_QK_ROPE_OP") != "1":
        # [P, 2 * D] -> [1, 1, P, 2 * D]
        return torch.cat((cos, sin), dim=-1)[None, None]

    # Fused Q/K RoPE layout: each of cos/sin is duplicated to full width,
    # then stacked: [2, P, 2 * D] -> [2, 1, P, 2 * D].
    doubled_cos = torch.cat((cos, cos), dim=-1)
    doubled_sin = torch.cat((sin, sin), dim=-1)
    return torch.stack((doubled_cos, doubled_sin))[:, None]
|
|
|
|
|
|
def vllm_kunlun_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply rotary embedding to ``query``/``key`` through the Kunlun ops.

    Step 1: if ``self.cos_sin_cache`` differs from ``query`` in device or dtype,
    it is converted and stored back on the instance. Subsequent calls with the
    same device/dtype then skip the conversion.

    Step 2 depends on ``offsets``:

    * Given: ``KunlunOps.batched_rotary_embedding`` is called and its return
      value is ignored. The same ``query``/``key`` objects are returned,
      relying on the op updating them in place.
    * ``None``: the two tensors produced by ``KunlunOps.rotary_embedding`` are
      returned.
    """
    # Function-scope import; presumably to avoid an import cycle -- TODO confirm.
    from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops

    cache = self.cos_sin_cache
    if cache.device != query.device or cache.dtype != query.dtype:
        self.cos_sin_cache = cache.to(query.device, dtype=query.dtype)

    if offsets is None:
        rotated_query, rotated_key = ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return rotated_query, rotated_key

    # In-place: the op writes the rotated values into query/key.
    ops.batched_rotary_embedding(
        positions,
        query,
        key,
        self.head_size,
        self.cos_sin_cache,
        self.is_neox_style,
        self.rotary_dim,
        offsets,
    )
    return query, key
|
|
|
|
|
|
def vllm_kunlun_mrope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply multimodal rotary embedding through the xspeedgate kernel.

    This is not a PyTorch-native computation. It delegates to
    ``torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0`` and passes the three
    entries of ``self.mrope_section`` as separate arguments.

    Args:
        positions: Must be 2-D (asserted), so 1-D text-only positions are
            rejected. Presumably ``[3, num_tokens]`` T/H/W positions -- TODO
            confirm. Cast to int32 before the kernel call.
        query: Presumably ``[num_tokens, num_heads * head_size]``.
        key: Presumably ``[num_tokens, num_kv_heads * head_size]``. Required
            (asserted not None) despite the ``None`` default.

    Returns:
        The ``(query, key)`` pair produced by the kernel.
    """
    assert positions.ndim == 2
    assert key is not None

    # The interleaved flag is hard-coded; any self.mrope_interleaved is ignored.
    interleaved = False
    out_query, out_key = torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0(
        query,
        key,
        positions.to(dtype=torch.int32),
        self.cos_sin_cache,
        interleaved,
        self.head_size,
        self.rotary_dim,
        self.mrope_section[0],
        self.mrope_section[1],
        self.mrope_section[2],
    )
    return out_query, out_key
|
|
|
|
|
|
# Route RotaryEmbedding's forward entry points through the Kunlun implementation.
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda

# The custom cos/sin cache builder is installed only when one of these
# environment switches equals "1" at import time.
if any(
    os.getenv(flag) == "1"
    for flag in ("KUNLUN_ENABLE_MULTI_LORA", "FUSED_QK_ROPE_OP")
):
    RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache

# Route MRotaryEmbedding's forward entry points through the xspeedgate kernel.
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda

# NOTE(review): YaRN's own inverse-frequency helper is replaced with the plain
# RotaryEmbedding one. Confirm that the argument YaRN's cache builder passes is
# what the base helper expects (elsewhere in this file it is called with
# self.base). TODO confirm intent.
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
|
|
|
|
|
|
def Split_Norm_Rope(
    qkv: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    q_norm_weight: torch.Tensor,
    k_norm_weight: torch.Tensor,
    positions: torch.Tensor,
    max_position_embeddings: int,
    q_head_num: int,
    kv_head_num: int,
    head_dim: int,
    partial_rotary_factor: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the fused ``torch.ops._C.split_norm_rope_neox`` op on a packed QKV tensor.

    Judging by the op name and the norm-weight arguments, the op presumably
    splits ``qkv``, normalizes Q/K and applies NeoX-style RoPE -- TODO confirm.
    It is handed three freshly allocated output buffers. Those buffers are
    returned unchanged by Python code, so the op is presumably expected to
    fill them.

    Args:
        qkv: Packed input; its first dimension is taken as the token count.
        cos_sin_cache: Rotary table forwarded to the op.
        q_norm_weight: Forwarded to the op.
        k_norm_weight: Forwarded to the op.
        positions: Forwarded to the op.
        max_position_embeddings: Forwarded to the op.
        q_head_num: Number of query heads; sizes the Q output.
        kv_head_num: Number of key/value heads; sizes the K and V outputs.
        head_dim: Per-head width.
        partial_rotary_factor: Values strictly below 1.0 shrink the rotated
            width to ``int(head_dim * partial_rotary_factor)``. Otherwise the
            full ``head_dim`` is rotated.

    Returns:
        ``(q, k, v)`` buffers of shapes ``[num_tokens, q_head_num * head_dim]``,
        ``[num_tokens, kv_head_num * head_dim]`` and
        ``[num_tokens, kv_head_num * head_dim]``, with ``qkv``'s dtype and device.
    """
    n_tokens = qkv.shape[0]
    rot_dim = (
        int(head_dim * partial_rotary_factor)
        if partial_rotary_factor < 1.0
        else head_dim
    )

    def _alloc(n_heads: int) -> torch.Tensor:
        # Uninitialized output buffer shaped [n_tokens, n_heads * head_dim].
        return torch.empty(
            (n_tokens, n_heads * head_dim), dtype=qkv.dtype, device=qkv.device
        )

    q_out = _alloc(q_head_num)
    k_out = _alloc(kv_head_num)
    v_out = _alloc(kv_head_num)

    torch.ops._C.split_norm_rope_neox(
        q_out,
        k_out,
        v_out,
        qkv,
        cos_sin_cache,
        q_norm_weight,
        k_norm_weight,
        positions,
        n_tokens,
        max_position_embeddings,
        q_head_num,
        kv_head_num,
        head_dim,
        rot_dim,
    )
    return q_out, k_out, v_out
|