Files
xc-llm-kunlun/vllm_kunlun/ops/_kunlun_ops.py
2025-12-10 15:52:23 +08:00

582 lines
17 KiB
Python

#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun custom op entry"""
import torch_xmlir
import torch
import os
from typing import Optional, List, Dict
import vllm.envs as envs
import os
import ctypes
from vllm.logger import init_logger
logger = init_logger(__name__)

# The compiled Kunlun kernel library is optional at import time: a failed
# import is only logged, so this module itself still imports. Any later call
# that reaches ``xtorch_ops`` will then fail with NameError.
try:
    import xtorch_ops

    logger.info("Load custom ops library success!")
except ImportError as e:
    logger.warning("Import error msg: %s", e.msg)
# Module-wide switch read by the quantized RMSNorm wrappers to decide how the
# output scale buffer is allocated (per token vs. a fixed-size buffer).
_per_token_smooth_quant = True


def is_per_token_smooth_quant() -> bool:
    """Report whether per-token smooth quantization is enabled.

    Returns:
        The current value of the module-level ``_per_token_smooth_quant`` flag.
    """
    enabled = _per_token_smooth_quant
    return enabled
class KunlunOps:
    """Static entry points that route vLLM custom ops to Kunlun kernels.

    Most methods forward to ``xtorch_ops`` (imported at module load). A few
    call ``torch.ops._C`` kernels instead. Several methods keep the vLLM op
    signature for compatibility and ignore some of their arguments; each
    method's docstring lists those.
    """

    # ------------------------------------------------------------------ #
    # Attention ops
    # ------------------------------------------------------------------ #
    @staticmethod
    def _paged_attention(
        output,
        query,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        context_lens_cpu,
        is_context,
    ):
        """Shared body of paged_attention_v1/v2: one xtorch paged-attention call.

        The result is written into ``output``.
        """
        # NOTE(review): vo_head_dim is hard-coded to 128, so other value head
        # sizes are presumably unsupported on this path -- confirm.
        xtorch_ops.paged_attention(
            x=query,
            k_cache=key_cache,
            v_cache=value_cache,
            block_tables=block_tables,
            context_lens_cpu=context_lens_cpu,
            context_lens_xpu=context_lens,
            is_context=is_context,
            is_causal=True,
            out=output,
            vo_head_dim=128,
        )

    @staticmethod
    def paged_attention_v1(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        context_lens,
        context_lens_cpu,
        is_context,
        block_size,
        max_context_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
        alibi_sqrt=False,
    ):
        """PagedAttentionV1 (causal); the result is written into ``output``.

        Only query, key/value caches, block tables, context lengths (device
        and CPU copies) and ``is_context`` reach the kernel. The remaining
        parameters (num_kv_heads, scale, block_size, max_context_len,
        alibi_slopes, kv_cache_dtype, k_scale/v_scale, tp_rank, blocksparse_*,
        alibi_sqrt) are accepted for vLLM signature compatibility and ignored.
        """
        KunlunOps._paged_attention(
            output,
            query,
            key_cache,
            value_cache,
            block_tables,
            context_lens,
            context_lens_cpu,
            is_context,
        )

    @staticmethod
    def paged_attention_v2(
        output,
        exp_sums,
        max_logits,
        tmp_output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        context_lens,
        context_lens_cpu,
        is_context,
        block_size,
        max_context_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
        alibi_sqrt=False,
    ):
        """PagedAttentionV2; on Kunlun this runs the same kernel as V1.

        The V2 scratch buffers (exp_sums, max_logits, tmp_output) and the
        other parameters listed on paged_attention_v1 are accepted for vLLM
        signature compatibility and ignored. The result is written into
        ``output``.
        """
        KunlunOps._paged_attention(
            output,
            query,
            key_cache,
            value_cache,
            block_tables,
            context_lens,
            context_lens_cpu,
            is_context,
        )

    # ------------------------------------------------------------------ #
    # Activation ops
    # ------------------------------------------------------------------ #
    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
        """silu and mul over the last axis of ``x``; the result goes to ``out``."""
        xtorch_ops.silu_and_mul(
            x,
            axis=-1,
            turn=True,
            out=out,
        )

    @staticmethod
    def quick_gelu(out: torch.Tensor, x: torch.Tensor):
        """quick gelu of ``x``, written into ``out``."""
        xtorch_ops.quick_gelu(
            x,
            out=out,
        )

    # ------------------------------------------------------------------ #
    # Layernorm
    # ------------------------------------------------------------------ #
    @staticmethod
    def rms_norm(
        out,
        x,
        weight,
        epsilon,
    ):
        """RMSNorm of ``x`` into ``out``; the kernel gets a float32 weight."""
        xtorch_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)

    @staticmethod
    def fused_add_rms_norm(
        x,
        residual,
        weight,
        epsilon,
    ):
        """Fused residual-add + RMSNorm, updating both inputs in place.

        After the call, ``x`` holds the ``xtorch_ops.add_rmsnorm`` result and
        ``residual`` holds ``x + residual`` computed from the original values.
        """
        output = torch.empty_like(x)
        xtorch_ops.add_rmsnorm(
            x, residual, weight.to(torch.float32), epsilon, out=output
        )
        # Compute the sum before x is overwritten with the normalized output.
        fused_input = x + residual
        residual.copy_(fused_input, non_blocking=True)
        x.copy_(output)

    # ------------------------------------------------------------------ #
    # Rotary embedding
    # ------------------------------------------------------------------ #
    @staticmethod
    def rotary_embedding(
        positions, query, key, head_size, cos_sin_cache, is_neox_style
    ):
        """Apply rotary position embedding to ``query`` and ``key``.

        Non-NeoX (GPT-J) style: this branch
            - runs ``xtorch_ops.rotary_embedding_gptj`` on contiguous versions
              of query/key;
            - rebinds ``query.data`` / ``key.data`` to those tensors;
            - returns the original ``query`` and ``key`` objects.
        If ``positions`` is 1-D, a leading dim of size 1 is added to
        positions, query and key before the kernel call and is kept
        afterwards.

        NeoX style: runs ``torch.ops._C.rotary_embedding`` on contiguous
        query/key and returns them viewed as
        ``(num_tokens, num_heads * head_size)``.
        """
        query_x = query.contiguous()
        key_x = key.contiguous()
        if not is_neox_style:
            # The GPT-J kernel is given a float32 cache when the cache is fp16.
            if cos_sin_cache.dtype == torch.float16:
                cos_sin_cache = cos_sin_cache.to(torch.float32)
            positions = positions.to(torch.int)
            if positions.dim() == 1:
                positions = positions.unsqueeze(0)
                query_x = query_x.unsqueeze(0)
                key_x = key_x.unsqueeze(0)
            xtorch_ops.rotary_embedding_gptj(
                positions, query_x, key_x, head_size, cos_sin_cache
            )
            # Rebinding .data also propagates results when query/key were
            # non-contiguous (query_x/key_x are copies in that case).
            # NOTE(review): when positions was 1-D, query/key keep the added
            # leading dim of size 1 -- confirm callers expect this.
            query.data = query_x
            key.data = key_x
            return query, key

        # TODO: need opt
        if cos_sin_cache.dim() == 4:
            # [1, 1, L, D] -> [L, D] -> [L, 1, D]
            max_seq_len = cos_sin_cache.shape[2]
            head_dim = cos_sin_cache.shape[3]
            cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0)
            cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)
        # Presumably query/key arrive as [num_tokens, num_heads * head_size].
        num_tokens = query_x.shape[0]
        num_heads = query_x.shape[1] // head_size
        num_kv_heads = key_x.shape[1] // head_size
        torch.ops._C.rotary_embedding(
            positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
        )
        query_x = query_x.view(num_tokens, num_heads * head_size)
        key_x = key_x.view(num_tokens, num_kv_heads * head_size)
        return query_x, key_x

    @staticmethod
    def mrotary_embedding(
        positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
    ):
        """Multimodal (sectioned) rotary embedding; only NeoX style is supported.

        Runs ``xtorch_ops.mrotary_embedding_neox`` on contiguous query/key.
        It then rebinds ``query.data`` / ``key.data`` to those tensors and
        returns the original ``query`` and ``key`` objects.
        """
        assert is_neox_style, "mrotary_embedding only supports NeoX-style rotary"
        query_x = query.contiguous()
        key_x = key.contiguous()
        xtorch_ops.mrotary_embedding_neox(
            positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
        )
        query.data = query_x
        key.data = key_x
        return query, key

    # ------------------------------------------------------------------ #
    # KV-cache management
    # ------------------------------------------------------------------ #
    @staticmethod
    def swap_blocks(src, dst, block_mapping):
        """swap_blocks: forwards directly to ``xtorch_ops.swap_blocks``."""
        xtorch_ops.swap_blocks(src, dst, block_mapping)

    @staticmethod
    def copy_blocks(key_caches, value_caches, block_mapping):
        """copy_blocks across every layer's key/value cache.

        Each list entry is replaced in place by its contiguous version before
        the kernel runs. Callers that read the caches back through these
        lists therefore see the copied blocks.
        """
        key_caches[:] = [cache.contiguous() for cache in key_caches]
        value_caches[:] = [cache.contiguous() for cache in value_caches]
        xtorch_ops.copy_blocks(
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
    ):
        """reshape_and_cache; ``kv_cache_dtype`` is accepted but not forwarded."""
        xtorch_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def multi_query_kv_attention(
        usual_seq_lod_xpu: torch.Tensor,
        usual_seq_lod_cpu: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Prefill attention over variable-length sequences.

        query: shape = [num_prompt_tokens, num_heads, head_size]. A leading
        dim of size 1 is added to 3-D inputs, and the returned tensor keeps
        that 4-D shape (``empty_like`` of the unsqueezed query).
        key/value: same layout with ``num_kv_heads`` heads. They are
        repeat-interleaved along the head dim to match ``num_heads``.

        Keyword args:
            is_causal (bool, default True): forwarded to the kernel.
            alibi_slopes, mask, is_lvsl: accepted for compatibility, ignored.

        Raises:
            ValueError: if num_heads is not a multiple of num_kv_heads.
        """
        if query.dim() == 3:
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
        output = torch.empty_like(query)
        is_causal = kwargs.get("is_causal", True)
        _, _, num_q_heads, _ = query.shape
        num_kv_heads = key.size(2)
        if num_kv_heads != num_q_heads:
            if num_q_heads % num_kv_heads != 0:
                raise ValueError(
                    f"num_heads ({num_q_heads}) must be a multiple of "
                    f"num_kv_heads ({num_kv_heads})"
                )
            repeat = num_q_heads // num_kv_heads
            key = key.repeat_interleave(repeat, dim=2)  # [B, T, Qh, Hd]
            value = value.repeat_interleave(repeat, dim=2)
        xtorch_ops.attention(
            q=query,
            k_cache=key,
            v_cache=value,
            out=output,
            is_causal=is_causal,
            is_prefill=True,
            context_seq_lod_cpu=usual_seq_lod_cpu,
            context_seq_lod_xpu=usual_seq_lod_xpu,
        )
        return output

    # ------------------------------------------------------------------ #
    # Quantization
    # ------------------------------------------------------------------ #
    @staticmethod
    def _alloc_quant_outputs(x):
        """Allocate the int8 output and float32 scale buffer for quantized norms."""
        out = torch.empty_like(x, dtype=torch.int8)
        if is_per_token_smooth_quant():
            # One scale per token: shape x.shape[:-1] + (1,).
            out_scale = torch.empty(
                x.shape[:-1], device=x.device, dtype=torch.float
            ).unsqueeze(-1)
        else:
            # NOTE(review): fixed 12-element scale buffer on the non-per-token
            # path; the meaning of 12 is not visible here -- confirm.
            out_scale = torch.empty(12, device=x.device, dtype=torch.float)
        return out, out_scale

    @staticmethod
    def quant_fusedresidual_rmsnorm_op(
        x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
    ):
        """Quantized fused residual layer normalization.

        Returns ``(out, out_scale)`` when ``residual`` is None, otherwise
        ``(out, out_scale, residual)``. ``scale_to_int``, ``dyn_scale`` and
        ``type`` are accepted but not forwarded to the kernel.
        """
        out, out_scale = KunlunOps._alloc_quant_outputs(x)
        xtorch_ops.quant_fusedresidual_rmsnorm(
            x,
            residual,
            weight,
            bias,
            eps,
            out=out,
            out_scale=out_scale,
            residual_tensor=residual,
        )
        if residual is None:
            return out, out_scale
        return out, out_scale, residual

    @staticmethod
    def quant_rmsnorm_op(
        x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
    ):
        """Quantized RMSNorm; returns ``(int8 out, float scale)``.

        ``scale_to_int``, ``dyn_scale`` and ``type`` are accepted but not
        forwarded to the kernel.
        """
        out, out_scale = KunlunOps._alloc_quant_outputs(x)
        xtorch_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
        return out, out_scale

    @staticmethod
    def smooth_quant_matmul_column_row_kernels(
        input_tensor,
        weight,
        smoother,
        input_scale,
        weight_scale,
        perTokenScaling,
        perChannelScaling,
        otype,
    ):
        """Smooth-quant matmul of ``input_tensor`` with ``weight`` (out dim = weight.shape[0]).

        A 3-D input ``[B, T, K]`` is flattened to ``[B*T, K]``, and the result
        is returned as ``[B, T, N]``. A 2-D input ``[M, K]`` returns
        ``[M, N]``. The output is allocated as float16.

        NOTE(review): ``otype`` is accepted but the output dtype is always
        float16 -- confirm whether it should be honoured.

        Raises:
            ValueError: if ``input_tensor`` is neither 2-D nor 3-D.
        """
        input_shape = input_tensor.shape
        weight_shape = weight.shape
        if input_tensor.dim() == 3:
            input_tensor = input_tensor.reshape(-1, input_shape[-1])
            out = torch.empty(
                (input_shape[0] * input_shape[1], weight_shape[0]),
                dtype=torch.float16,
                device=weight.device,
            )
            output_bs_shape = [input_shape[0], input_shape[1]]
        elif input_tensor.dim() == 2:
            out = torch.empty(
                (input_shape[0], weight_shape[0]),
                dtype=torch.float16,
                device=weight.device,
            )
            output_bs_shape = [-1]
        else:
            # Previously fell through to an UnboundLocalError on `out`.
            raise ValueError(
                "smooth_quant_matmul_column_row_kernels expects a 2-D or 3-D "
                f"input, got {input_tensor.dim()}-D"
            )
        xtorch_ops.smooth_quant_matmul_column_row_kernels(
            input_tensor,
            weight,
            smoother,
            input_scale,
            weight_scale,
            perTokenScaling,
            perChannelScaling,
            out=out,
        )
        return out.view(*output_bs_shape, weight_shape[0])

    # ------------------------------------------------------------------ #
    # Mixture of experts
    # ------------------------------------------------------------------ #
    @staticmethod
    def fused_moe(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        gating_output: torch.Tensor,
        linear_weights: torch.Tensor,
        topk: int,
        renormalize: bool,
        inplace: bool = False,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        w1_bias: Optional[torch.Tensor] = None,
        w2_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Fused MoE block via ``torch.ops._C.moe_ffn_block``.

        ``linear_weights`` is passed as the gate weight. The number of experts
        is taken from ``linear_weights.shape[0]``. ``gating_output``,
        ``inplace``, ``w1_bias`` and ``w2_bias`` are accepted but not used.
        Returns a new tensor shaped like ``hidden_states``.
        """
        output = torch.empty(
            hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
        )
        expert_num = linear_weights.shape[0]
        torch.ops._C.moe_ffn_block(
            x=hidden_states,
            gate_w=linear_weights,
            inter_w=w1,
            output_w=w2,
            expert_num=expert_num,
            moe_top_k=topk,
            topk_group=topk_group,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            expert_group_num=num_expert_group,
            out=output,
        )
        return output

    @staticmethod
    def fused_moe_ep(
        hidden_states: torch.Tensor,
        w13_weight: torch.Tensor,
        w2_weight: torch.Tensor,
        gating_output: torch.Tensor,
        linear_weights: torch.Tensor,
        ep_rank: int,
        top_k: int,
        renormalize: bool,
        inplace: bool = False,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        w1_bias: Optional[torch.Tensor] = None,
        w2_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Expert-parallel MoE computed only for this rank's local experts.

        Steps:
            1. Routing logits are ``hidden_states @ linear_weights.T``.
               ``gating_output`` is not used.
            2. Top-k selection uses ``torch.ops._C.moe_softmax_topk``; weights
               are optionally renormalized to sum to 1.
            3. Each local expert ``i`` has global id
               ``ep_rank * num_local_experts + i``. It runs
               ``swiglu(x @ w13[i].T) @ w2[i].T`` on the tokens routed to it.
            4. Slots routed to non-local experts contribute zeros to the
               weighted sum.

        ``hidden_states`` must be 2-D ``[batch, hidden_size]``. ``inplace``,
        ``use_grouped_topk``, ``num_expert_group``, ``topk_group`` and the
        bias arguments are accepted but not used.
        """
        x = hidden_states
        batch, hidden_size = x.shape
        num_local_experts, up_gate_size, _ = w13_weight.shape
        router_logits = x.to(linear_weights.dtype) @ linear_weights.T
        topk_weights = torch.empty(
            batch, top_k, dtype=router_logits.dtype, device=router_logits.device
        )
        topk_ids = torch.empty(
            batch, top_k, dtype=torch.int32, device=router_logits.device
        )
        block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
        torch.ops._C.moe_softmax_topk(
            router_logits, topk_weights, topk_ids, block_static
        )
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
        topk_weights = topk_weights.to(x.dtype)

        # One row per (token, k) slot; rows for non-local experts stay zero.
        out = torch.zeros(batch * top_k, hidden_size, dtype=x.dtype, device=x.device)
        repeat_x = x.repeat_interleave(top_k, dim=0)
        topk_ids_flat = topk_ids.flatten()
        for i in range(num_local_experts):
            expert_id = ep_rank * num_local_experts + i
            selected_token = topk_ids_flat == expert_id
            # Count once (single host sync) and reuse for the buffer size.
            num_selected = int(selected_token.sum())
            if num_selected:
                cur_token = repeat_x[selected_token]
                up_gate = torch.empty(
                    num_selected,
                    up_gate_size // 2,
                    dtype=cur_token.dtype,
                    device=cur_token.device,
                )
                torch.ops._C.swiglu(cur_token @ w13_weight[i].T, up_gate)
                out[selected_token] = up_gate @ w2_weight[i].T
        output = (
            (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
            .sum(dim=1)
            .to(x.dtype)
        )
        return output

    # ------------------------------------------------------------------ #
    # Multi-head latent attention
    # ------------------------------------------------------------------ #
    @staticmethod
    def fused_multi_head_latent_page_attention(
        hidden_states: torch.Tensor,
        q_lora_rank: int,
        kv_lora_rank: int,
        q_a_proj_w: torch.Tensor,
        q_a_layernorm_w: torch.Tensor,
        q_b_proj_w: torch.Tensor,
        q_proj_w: torch.Tensor,
        kv_a_proj_w: torch.Tensor,
        kv_a_layernorm_w: torch.Tensor,
        kv_b_proj_w: torch.Tensor,
        o_proj_w: torch.Tensor,
        head_num: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        max_context_len: int,
        layernorm_eps: float,
        scale: float,
        is_causal: bool,
        is_context: bool,
        mp_size: int,
        local_rank: int,
        rotary_pos_embedding: torch.Tensor,
        pa_block_tables: torch.Tensor,
        position: torch.Tensor,
        context_lens_cpu: torch.Tensor,
        slot_mapping: torch.Tensor,
        prompt_lods_cpu: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
    ) -> torch.Tensor:
        """MLA paged-attention block via ``xtorch_ops.xft_multi_head_latent_page_attention_block``.

        Returns a new tensor shaped like ``hidden_states``. Two positional
        kernel arguments are passed as ``None``. Their meaning is not visible
        here; they sit after ``position`` and after ``slot_mapping``.
        """
        output = torch.empty(
            hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
        )
        xtorch_ops.xft_multi_head_latent_page_attention_block(
            hidden_states,
            q_lora_rank,
            kv_lora_rank,
            q_a_proj_w,
            q_a_layernorm_w,
            q_b_proj_w,
            q_proj_w,
            kv_a_proj_w,
            kv_a_layernorm_w,
            kv_b_proj_w,
            o_proj_w,
            head_num,
            qk_nope_head_dim,
            qk_rope_head_dim,
            v_head_dim,
            max_context_len,
            layernorm_eps,
            scale,
            is_causal,
            is_context,
            mp_size,
            local_rank,
            rotary_pos_embedding,
            pa_block_tables,
            position,
            None,
            context_lens_cpu,
            slot_mapping,
            None,
            prompt_lods_cpu,
            out=output,
            k_cache=k_cache,
            v_cache=v_cache,
        )
        return output