# Source listing: xc-llm-kunlun/vllm_kunlun/ops/_kunlun_ops.py
# Snapshot: 2026-03-02 15:49:24 +08:00 (793 lines, 24 KiB, Python)

#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun custom op entry"""
from typing import Optional
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
# The native Kunlun extensions are treated as optional at import time: when
# either import fails only a warning is logged and execution continues. In that
# case the name `kunlun_ops` is never bound, so any wrapper below that
# references it will raise NameError when it is called.
try:
    import cocopod  # noqa
    import kunlun_ops

    logger.info("Load custom ops library success!")
except ImportError as e:
    # ImportError exposes the human-readable reason through its `msg` attribute.
    logger.warning("Import error msg: %s", e.msg)
_per_token_smooth_quant = True


def is_per_token_smooth_quant():
    """Report whether per-token smooth quantization is enabled.

    Returns:
        The current value of the module-level ``_per_token_smooth_quant``
        flag.
    """
    enabled = _per_token_smooth_quant
    return enabled
class KunlunOps:
    """Namespace of thin wrappers around the Kunlun custom-op libraries.

    The wrappers dispatch to ``kunlun_ops``, ``torch.ops._C`` and
    ``torch.ops.xspeedgate_ops``. Most of them are declared as static
    methods and are meant to be called on the class itself.
    """
# Attention ops
@staticmethod
def paged_attention_v1(
    output,
    query,
    key_cache,
    value_cache,
    num_kv_heads,
    scale,
    block_tables,
    context_lens,
    context_lens_cpu,
    is_context,
    block_size,
    max_context_len,
    alibi_slopes,
    kv_cache_dtype,
    k_scale,
    v_scale,
    tp_rank,
    blocksparse_local_blocks,
    blocksparse_vert_stride,
    blocksparse_block_size,
    blocksparse_head_sliding_step,
    alibi_sqrt=False,
):
    """Dispatch paged attention (v1 entry point) to ``kunlun_ops``.

    Only ``query``, ``key_cache``, ``value_cache``, ``block_tables``,
    ``context_lens``, ``context_lens_cpu``, ``is_context`` and ``output``
    are forwarded to ``kunlun_ops.paged_attention``. Every other parameter
    is accepted for signature compatibility and is not used by this wrapper.

    Returns:
        None. ``output`` is handed to the kernel as its ``out`` argument.
    """
    # NOTE(review): causal masking and a value/output head dim of 128 are
    # hard-coded; confirm 128 is valid for every model using this path.
    kernel_kwargs = {
        "x": query,
        "k_cache": key_cache,
        "v_cache": value_cache,
        "block_tables": block_tables,
        "context_lens_cpu": context_lens_cpu,
        "context_lens_xpu": context_lens,
        "is_context": is_context,
        "is_causal": True,
        "out": output,
        "vo_head_dim": 128,
    }
    kunlun_ops.paged_attention(**kernel_kwargs)
@staticmethod
def paged_attention_v2(
    output,
    exp_sums,
    max_logits,
    tmp_output,
    query,
    key_cache,
    value_cache,
    num_kv_heads,
    scale,
    block_tables,
    context_lens,
    context_lens_cpu,
    is_context,
    block_size,
    max_context_len,
    alibi_slopes,
    kv_cache_dtype,
    k_scale,
    v_scale,
    tp_rank,
    blocksparse_local_blocks,
    blocksparse_vert_stride,
    blocksparse_block_size,
    blocksparse_head_sliding_step,
    alibi_sqrt=False,
):
    """Dispatch paged attention (v2 entry point) to ``kunlun_ops``.

    This wrapper issues the same ``kunlun_ops.paged_attention`` call as the
    v1 entry point. The v2-specific scratch arguments (``exp_sums``,
    ``max_logits``, ``tmp_output``) are accepted but not used, as are all
    parameters other than ``query``, ``key_cache``, ``value_cache``,
    ``block_tables``, ``context_lens``, ``context_lens_cpu``,
    ``is_context`` and ``output``.

    Returns:
        None. ``output`` is handed to the kernel as its ``out`` argument.
    """
    # NOTE(review): causal masking and a value/output head dim of 128 are
    # hard-coded; confirm 128 is valid for every model using this path.
    kernel_kwargs = {
        "x": query,
        "k_cache": key_cache,
        "v_cache": value_cache,
        "block_tables": block_tables,
        "context_lens_cpu": context_lens_cpu,
        "context_lens_xpu": context_lens,
        "is_context": is_context,
        "is_causal": True,
        "out": output,
        "vo_head_dim": 128,
    }
    kunlun_ops.paged_attention(**kernel_kwargs)
# Activation ops
@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the Kunlun ``silu_and_mul`` kernel on ``x``.

    Args:
        out: Tensor handed to the kernel as its ``out`` argument.
        x: Input tensor.
    """
    # The kernel is asked to operate along the last axis (axis=-1).
    # NOTE(review): the meaning of ``turn=True`` is not visible from this
    # file; presumably it selects which half acts as the gate. Confirm
    # against the kunlun_ops documentation.
    kunlun_ops.silu_and_mul(
        x,
        axis=-1,
        turn=True,
        out=out,
    )
# Activation ops
@staticmethod
def quick_gelu(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the Kunlun ``quick_gelu`` kernel on ``x``.

    Args:
        out: Tensor handed to the kernel as its ``out`` argument.
        x: Input tensor.
    """
    kunlun_ops.quick_gelu(
        x,
        out=out,
    )
# Layernorm
@staticmethod
def rms_norm(
    out,
    x,
    weight,
    epsilon,
):
    """Run the Kunlun ``rmsnorm`` kernel on ``x``.

    Args:
        out: Tensor handed to the kernel as its ``out`` argument.
        x: Input tensor.
        weight: Norm weight; it is converted to ``torch.float32`` before
            being passed to the kernel.
        epsilon: Epsilon value forwarded to the kernel.

    Returns:
        None.
    """
    weight_fp32 = weight.to(torch.float32)
    kunlun_ops.rmsnorm(x, weight_fp32, epsilon, out=out)
@staticmethod
def fused_add_rms_norm(
    x,
    residual,
    weight,
    epsilon,
):
    """Fused residual-add + RMSNorm that updates ``x`` and ``residual`` in place.

    The ``add_rmsnorm`` kernel writes its result into a scratch tensor.
    Afterwards ``residual`` is overwritten with ``x + residual`` and ``x``
    is overwritten with the kernel output.

    Args:
        x: Input tensor; replaced in place by the kernel output.
        residual: Residual tensor; replaced in place by ``x + residual``.
        weight: Norm weight, converted to ``torch.float32`` for the kernel.
        epsilon: Epsilon value forwarded to the kernel.

    Returns:
        None.
    """
    normed = torch.empty_like(x)
    weight_fp32 = weight.to(torch.float32)
    kunlun_ops.add_rmsnorm(x, residual, weight_fp32, epsilon, out=normed)
    # The sum must be taken before x is overwritten with the normalized
    # values below.
    residual.copy_(x + residual, non_blocking=True)
    x.copy_(normed)
# Rotary embedding
@staticmethod
def rotary_embedding(
    positions, query, key, head_size, cos_sin_cache, is_neox_style
):
    """Apply ``torch.ops._C.rotary_embedding`` to contiguous query/key tensors.

    ``query`` and ``key`` are made contiguous first, and those contiguous
    tensors are the ones passed to the kernel and returned.

    Returns:
        Tuple ``(query, key)`` of the contiguous tensors given to the kernel.
    """
    q_contig, k_contig = (t.contiguous() for t in (query, key))
    torch.ops._C.rotary_embedding(
        positions, q_contig, k_contig, head_size, cos_sin_cache, is_neox_style
    )
    return q_contig, k_contig
# Rotary embedding
@staticmethod
def mrotary_embedding(
    positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
):
    """Apply the Kunlun multi-section rotary embedding (NeoX layout only).

    ``query`` and ``key`` are made contiguous, passed to
    ``kunlun_ops.mrotary_embedding_neox`` and then rebound in place through
    their ``.data`` attribute, so the caller's tensor objects see the
    contiguous storage used by the kernel.

    Args:
        positions: Position tensor forwarded to the kernel.
        mrope_section: Section specification forwarded to the kernel.
        query: Query tensor; its ``.data`` is replaced.
        key: Key tensor; its ``.data`` is replaced.
        head_size: Head size forwarded to the kernel.
        cos_sin_cache: Cos/sin cache forwarded to the kernel.
        is_neox_style: Must be truthy; only the NeoX kernel is wired up.

    Returns:
        Tuple ``(query, key)`` — the same objects that were passed in.

    Raises:
        NotImplementedError: If ``is_neox_style`` is falsy.
    """
    # Explicit guard instead of ``assert``: assertions are stripped under
    # ``python -O``, which would silently run the NeoX kernel on non-NeoX
    # input.
    if not is_neox_style:
        raise NotImplementedError(
            "mrotary_embedding only supports is_neox_style=True"
        )
    query_x = query.contiguous()
    key_x = key.contiguous()
    kunlun_ops.mrotary_embedding_neox(
        positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
    )
    query.data = query_x
    key.data = key_x
    return query, key
@staticmethod
def swap_blocks(src, dst, block_mapping):
    """Forward a block-swap request to ``kunlun_ops.swap_blocks``.

    Args:
        src: Source cache, passed through unchanged.
        dst: Destination cache, passed through unchanged.
        block_mapping: Block mapping, passed through unchanged.
            NOTE(review): presumably source-to-destination block indices;
            confirm the expected layout against the kernel.
    """
    kunlun_ops.swap_blocks(src, dst, block_mapping)
@staticmethod
def copy_blocks(key_caches, value_caches, block_mapping):
    """Make every cache contiguous, then call ``kunlun_ops.copy_blocks``.

    The entries of ``key_caches`` and ``value_caches`` are replaced in place
    (element by element) with their contiguous versions before the kernel
    is invoked, so the caller's sequences are mutated.

    Args:
        key_caches: Mutable sequence of key cache tensors.
        value_caches: Mutable sequence of value cache tensors, indexed in
            lockstep with ``key_caches``.
        block_mapping: Block mapping forwarded to the kernel.
    """
    for idx, key_block in enumerate(key_caches):
        key_caches[idx] = key_block.contiguous()
        value_caches[idx] = value_caches[idx].contiguous()
    kunlun_ops.copy_blocks(key_caches, value_caches, block_mapping)
@staticmethod
def reshape_and_cache(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    kv_cache_dtype,
):
    """Forward a KV-cache write to ``kunlun_ops.reshape_and_cache``.

    Args:
        key: Key tensor, passed through unchanged.
        value: Value tensor, passed through unchanged.
        key_cache: Key cache, passed through unchanged.
        value_cache: Value cache, passed through unchanged.
        slot_mapping: Slot mapping, passed through as-is (no dtype cast is
            applied here).
        kv_cache_dtype: Accepted but not used by this wrapper.
    """
    kunlun_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def multi_query_kv_attention(
    usual_seq_lod_xpu: torch.Tensor,
    usual_seq_lod_cpu: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    **kargs,
) -> torch.Tensor:
    """Run causal prefill attention through ``kunlun_ops.attention``.

    A 3-D ``query`` (documented by the original author as
    ``[num_prompt_tokens, num_heads, head_size]``) gets a leading dimension
    of size 1 added, together with ``key`` and ``value``. When the head
    counts on dim 2 differ, ``key`` and ``value`` are expanded with
    ``repeat_interleave`` so they match the query head count.

    Args:
        usual_seq_lod_xpu: Passed to the kernel as ``context_seq_lod_xpu``.
        usual_seq_lod_cpu: Passed to the kernel as ``context_seq_lod_cpu``.
        query: Query tensor, 3-D or 4-D.
        key: Key tensor with the same rank as ``query``.
        value: Value tensor with the same rank as ``query``.
        **kargs: Accepted but not used.

    Returns:
        A tensor shaped like the (possibly unsqueezed) 4-D query, handed to
        the kernel as its ``out`` argument.
    """
    if query.dim() == 3:
        query, key, value = (t.unsqueeze(0) for t in (query, key, value))
    attn_out = torch.empty_like(query)
    # The four-way unpack also enforces that query is 4-D at this point.
    _, _, num_q_heads, _ = query.shape
    num_kv_heads = key.size(2)
    if num_kv_heads != num_q_heads:
        group = num_q_heads // num_kv_heads
        key = key.repeat_interleave(group, dim=2)
        value = value.repeat_interleave(group, dim=2)
    kunlun_ops.attention(
        q=query,
        k_cache=key,
        v_cache=value,
        out=attn_out,
        is_causal=True,
        is_prefill=True,
        context_seq_lod_cpu=usual_seq_lod_cpu,
        context_seq_lod_xpu=usual_seq_lod_xpu,
    )
    return attn_out
@staticmethod
def quant_fusedresidual_rmsnorm_op(
    x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
    """Fused residual RMSNorm with int8 output via ``kunlun_ops``.

    Allocates an int8 output shaped like ``x`` and a float scale buffer.
    The scale buffer has shape ``x.shape[:-1] + (1,)`` when per-token smooth
    quantization is enabled, otherwise it is a flat buffer of 12 elements.
    ``residual`` is passed to the kernel both as an input and as
    ``residual_tensor``.

    ``scale_to_int``, ``dyn_scale`` and ``type`` are accepted but not used.

    Returns:
        ``(out, out_scale)`` when ``residual`` is ``None``, otherwise
        ``(out, out_scale, residual)``.
    """
    quantized = torch.empty_like(x, dtype=torch.int8)
    if not is_per_token_smooth_quant():
        scales = torch.empty(12, device=x.device, dtype=torch.float)
    else:
        per_token = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float)
        scales = per_token.unsqueeze(-1)
    kunlun_ops.quant_fusedresidual_rmsnorm(
        x,
        residual,
        weight,
        bias,
        eps,
        out=quantized,
        out_scale=scales,
        residual_tensor=residual,
    )
    if residual is not None:
        return quantized, scales, residual
    return quantized, scales
@staticmethod
def quant_rmsnorm_op(
    x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
    """RMSNorm with int8 output via ``kunlun_ops.quant_rmsnorm``.

    Allocates an int8 output shaped like ``x`` and a float scale buffer.
    The scale buffer has shape ``x.shape[:-1] + (1,)`` when per-token smooth
    quantization is enabled, otherwise it is a flat buffer of 12 elements.

    ``scale_to_int``, ``dyn_scale`` and ``type`` are accepted but not used.

    Returns:
        Tuple ``(out, out_scale)``.
    """
    quantized = torch.empty_like(x, dtype=torch.int8)
    if not is_per_token_smooth_quant():
        scales = torch.empty(12, device=x.device, dtype=torch.float)
    else:
        per_token = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float)
        scales = per_token.unsqueeze(-1)
    kunlun_ops.quant_rmsnorm(x, weight, bias, eps, out=quantized, out_scale=scales)
    return quantized, scales
@staticmethod
def smooth_quant_matmul_column_row_kernels(
    input_tensor,
    weight,
    smoother,
    input_scale,
    weight_scale,
    perTokenScaling,
    perChannelScaling,
    otype,
):
    """Smooth-quant matmul through ``kunlun_ops``, for 2-D or 3-D input.

    A 3-D input is flattened to 2-D over its first two dimensions before
    the kernel call, and the result is viewed back to
    ``(dim0, dim1, weight.shape[0])``. A 2-D input yields a
    ``(rows, weight.shape[0])`` result. The output buffer is always
    allocated as ``torch.float16`` on ``weight.device``.

    Args:
        input_tensor: 2-D or 3-D input tensor.
        weight: Weight tensor; ``weight.shape[0]`` is the output width.
        smoother: Forwarded to the kernel unchanged.
        input_scale: Forwarded to the kernel unchanged.
        weight_scale: Forwarded to the kernel unchanged.
        perTokenScaling: Forwarded to the kernel unchanged.
        perChannelScaling: Forwarded to the kernel unchanged.
        otype: Accepted but not used by this wrapper.

    Returns:
        The float16 output tensor described above.

    Raises:
        ValueError: If ``input_tensor`` is neither 2-D nor 3-D.
    """
    input_shape = input_tensor.shape
    out_features = weight.shape[0]
    ndim = input_tensor.dim()
    if ndim == 3:
        input_tensor = input_tensor.reshape(-1, input_shape[-1])
        rows = input_shape[0] * input_shape[1]
        leading_shape = [input_shape[0], input_shape[1]]
    elif ndim == 2:
        rows = input_shape[0]
        leading_shape = [-1]
    else:
        # Previously this fell through with the output buffer unbound and
        # surfaced as an opaque UnboundLocalError at the kernel call.
        raise ValueError(
            "smooth_quant_matmul_column_row_kernels expects a 2-D or 3-D "
            f"input_tensor, got {ndim}-D"
        )
    out = torch.empty(
        (rows, out_features),
        dtype=torch.float16,
        device=weight.device,
    )
    kunlun_ops.smooth_quant_matmul_column_row_kernels(
        input_tensor,
        weight,
        smoother,
        input_scale,
        weight_scale,
        perTokenScaling,
        perChannelScaling,
        out=out,
    )
    return out.view(*leading_shape, out_features)
def _dbg(x):
    """Summarize ``x`` as a tuple for debug output.

    Returns:
        ``(type, device, dtype, shape, is_contiguous)`` for a tensor,
        otherwise ``(type, x)``.
    """
    kind = type(x)
    if not torch.is_tensor(x):
        return (kind, x)
    return (kind, x.device, x.dtype, x.shape, x.is_contiguous())
@staticmethod
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    router_logits: torch.Tensor,
    ep_rank: int,
    moe_top_k: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Mixture-of-experts forward pass built from Kunlun custom ops.

    Routing (top-k expert ids and normalized scores) is computed by a
    softmax or sigmoid/group top-k kernel. The expert MLP then runs through
    one of three paths:

    * Bias path (``w1_bias`` or ``w2_bias`` given): a per-expert Python
      loop using plain matmuls and ``swigluoai_and_mul``.
    * Small path (``M * moe_top_k < 400``): the ``xspeedgate_ops`` fused
      kernel.
    * General path: ``moe_fc`` -> activation -> ``moe_fc`` -> ``moe_post``,
      with token sorting done by ``moe_pre_sorted`` when
      ``M * moe_top_k > 768`` and by ``moe_pre_small`` otherwise.

    Args:
        hidden_states: 2-D tensor ``(M, N)``.
        w1: 3-D tensor whose first dimension is the number of experts.
        w2: Tensor indexed per expert; ``w2.shape[1]`` is used as the
            hidden width.
        router_logits: Router logits; converted to ``torch.float``.
        ep_rank: Expert-parallel rank, used to offset expert ids in the
            bias path.
        moe_top_k: Number of experts selected per token.
        renormalize: Accepted but not consulted by this implementation.
        inplace: Accepted but not used.
        use_grouped_topk: Accepted but not used.
        num_expert_group: Forwarded as ``n_group`` for sigmoid scoring.
        topk_group: Forwarded as ``topk_group`` for sigmoid scoring.
        w1_bias: Optional per-expert bias added after the first matmul.
        w2_bias: Optional per-expert bias added after the second matmul.
        scoring_func: ``"softmax"`` or ``"sigmoid"``.
        e_score_correction_bias: Forwarded as ``bias`` for sigmoid scoring.

    Returns:
        The combined expert output, one row per input token.

    Raises:
        ValueError: If ``scoring_func`` is not ``"softmax"`` or
            ``"sigmoid"``.
    """
    # Reject unknown scoring functions up front. Without this guard neither
    # routing kernel runs and the uninitialized ``torch.empty`` buffers
    # below would be silently used as expert ids and weights.
    if scoring_func not in ("softmax", "sigmoid"):
        raise ValueError(
            f"Unsupported scoring_func {scoring_func!r}; "
            "expected 'softmax' or 'sigmoid'"
        )

    # The three-way unpack also enforces that w1 is 3-D.
    global_num_experts, _, _ = w1.shape
    M, N = hidden_states.shape
    hidden_dim = w2.shape[1]
    normed_score = torch.empty(
        M, moe_top_k, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(
        M, moe_top_k, dtype=torch.int32, device=hidden_states.device
    )
    # NOTE(review): 12 presumably matches a hardware block/cluster count;
    # confirm against the kernel documentation.
    num_blocks = 12
    block_statistic = torch.zeros(
        num_blocks,
        global_num_experts,
        dtype=torch.int32,
        device=hidden_states.device,
    )

    router_logits = router_logits.to(torch.float)
    if scoring_func == "softmax":
        torch.ops._C.moe_softmax_topk_norm(
            x=router_logits,
            normed_score=normed_score,
            topk_index=topk_ids,
            block_statistic=None,
            stable=False,
        )
    else:  # "sigmoid"
        torch.ops._C.moe_sigmoid_group_topk_norm(
            x=router_logits,
            topk_index=topk_ids,
            norm_score=normed_score,
            block_static=block_statistic,
            bias=e_score_correction_bias,
            scale=1.0,
            n_group=num_expert_group,
            topk_group=topk_group,
        )

    if w1_bias is not None or w2_bias is not None:
        # Right now this branch is for gpt oss
        # TODO (@xyDong23): faster here using moe_fc kernel
        normed_score = normed_score.to(hidden_states.dtype)
        out = torch.zeros(
            M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device
        )
        repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
        topk_ids_flat = topk_ids.flatten()
        for i in range(global_num_experts):
            experts_id = ep_rank * global_num_experts + i
            selected_token = topk_ids_flat == experts_id
            if selected_token.sum():
                cur_token = repeat_x[selected_token]
                groupgemm1 = cur_token @ w1[i].T
                # Add w13 bias
                if w1_bias is not None:
                    groupgemm1 = groupgemm1 + w1_bias[i]
                up_gate = torch.ops._C.swigluoai_and_mul(groupgemm1)
                groupgemm2 = up_gate @ w2[i].T
                # Add w2 bias
                if w2_bias is not None:
                    groupgemm2 = groupgemm2 + w2_bias[i]
                out[selected_token] = groupgemm2
        output = (
            (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2))
            .sum(dim=1)
            .to(hidden_states.dtype)
        )
        return output

    if M * moe_top_k < 400:
        # Small workloads: a single fused kernel handles both expert
        # matmuls. The expanded-token tensor returned by moe_pre_small is
        # not needed on this path.
        sorted_tokens_idx, sorted_tokens_num_lod, _ = (
            torch.ops.xspeedgate_ops.moe_pre_small(
                topk_ids, global_num_experts, False, False, hidden_states
            )
        )
        experts_num_lod = torch.ops.xspeedgate_ops.moe_active_expert_balance(
            topk_ids, global_num_experts, False
        )
        out = torch.ops.xspeedgate_ops.fused_moe(
            hidden_states,
            w1,
            w2,
            normed_score.to(hidden_states.dtype),
            sorted_tokens_num_lod,
            sorted_tokens_idx,
            experts_num_lod,
        )
        return out.sum(1)

    if M * moe_top_k > 768:
        moe_expand = torch.empty(
            (M * moe_top_k, N),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )  # [M*top_k, N], float
        expert_m = torch.zeros(
            global_num_experts, dtype=torch.int32, device=hidden_states.device
        )  # [E]
        sorted_tokens_num_lod = torch.zeros(
            global_num_experts + 1,
            dtype=torch.int32,
            device=hidden_states.device,
        )  # [E+1]
        sorted_tokens_idx = torch.zeros(
            M * moe_top_k, dtype=torch.int32, device=hidden_states.device
        )
        torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
        torch.ops._C.moe_pre_sorted(
            x=hidden_states,
            topk_index=topk_ids,
            block_statistic=block_statistic,
            moe_expand=moe_expand,
            moe_index=sorted_tokens_idx,
            expert_m=expert_m,
            sorted_tokens_num_lod=sorted_tokens_num_lod,
        )
    else:
        sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
            torch.ops.xspeedgate_ops.moe_pre_small(
                topk_ids,
                global_num_experts,
                index_have_neg=False,
                sort_mode=True,
                x=hidden_states,
            )
        )

    y = torch.empty(
        M,
        moe_top_k,
        w1.shape[1],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)

    if M < 1024:
        # Separate activation kernel after the first expert matmul.
        torch.ops._C.moe_fc(
            x=moe_expand,
            weight=w1,
            sorted_tokens_num_lod=sorted_tokens_num_lod,
            sorted_tokens_idx=sorted_tokens_idx,
            moe_topk=moe_top_k,
            y=y,
        )
        d = y.shape[-1] // 2
        output_shape = y.shape[:-1] + (d,)
        out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
        torch.ops._C.silu_and_mul(out1, y)
        out1 = out1.reshape(-1, out1.shape[-1])
    else:
        # Activation is requested from moe_fc itself; only the first half
        # of the last dimension is kept afterwards.
        torch.ops._C.moe_fc(
            x=moe_expand,
            weight=w1,
            sorted_tokens_num_lod=sorted_tokens_num_lod,
            sorted_tokens_idx=sorted_tokens_idx,
            moe_topk=moe_top_k,
            y=y,
            act="SWISH_GLU",
        )
        y = y[..., : y.shape[-1] // 2]
        out1 = y.reshape(-1, y.shape[-1])

    out = torch.empty(
        M,
        moe_top_k,
        w2.shape[1],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    torch.ops._C.moe_fc(
        x=out1,
        weight=w2,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=sorted_tokens_idx,
        moe_topk=moe_top_k,
        y=out,
    )

    # All-ones dequant scale: no dequantization is applied in moe_post.
    dequant_scale = torch.ones(
        [M, moe_top_k], dtype=torch.float32, device=out.device
    )
    output = torch.empty(
        [M, N], dtype=hidden_states.dtype, device=hidden_states.device
    )
    sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
    torch.ops._C.moe_post(
        x=out,
        moe_index=sorted_tokens_idx,
        normed_scale=normed_score,
        dequant_scale=dequant_scale,
        y=output,
    )
    return output
@staticmethod
def fused_moe_ep(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    gating_output: torch.Tensor,
    linear_weights: torch.Tensor,
    ep_rank: int,
    top_k: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Expert-parallel MoE forward pass using a per-expert Python loop.

    Router logits are computed here as ``hidden_states @ linear_weights.T``
    (in ``linear_weights.dtype``), top-k selection is done by
    ``torch.ops._C.moe_softmax_topk``, and each local expert then processes
    the tokens routed to it with ``silu_and_mul`` between the two matmuls.

    Args:
        hidden_states: 2-D tensor ``(batch, hidden_size)``.
        w13_weight: 3-D tensor; the first dimension is the number of local
            experts and the second is the fused up/gate width.
        w2_weight: Per-expert down-projection weights.
        gating_output: Accepted but not used; logits are recomputed.
        linear_weights: Router weight matrix.
        ep_rank: Expert-parallel rank; local expert ``i`` corresponds to
            global id ``ep_rank * num_local_experts + i``.
        top_k: Number of experts selected per token.
        renormalize: When truthy, top-k weights are divided by their
            per-token sum.
        inplace: Accepted but not used.
        use_grouped_topk: Accepted but not used.
        num_expert_group: Accepted but not used.
        topk_group: Accepted but not used.
        w1_bias: Accepted but not used.
        w2_bias: Accepted but not used.

    Returns:
        Tensor ``(batch, hidden_size)`` in ``hidden_states.dtype``.
    """
    num_tokens, hidden_size = hidden_states.shape
    num_local_experts, up_gate_size, _ = w13_weight.shape

    router_logits = hidden_states.to(linear_weights.dtype) @ linear_weights.T
    route_device = router_logits.device
    topk_weights = torch.empty(
        num_tokens, top_k, dtype=router_logits.dtype, device=route_device
    )
    topk_ids = torch.empty(num_tokens, top_k, dtype=torch.int32, device=route_device)
    block_static = torch.empty(0, dtype=torch.int32, device=route_device)
    torch.ops._C.moe_softmax_topk(
        router_logits, topk_weights, topk_ids, block_static
    )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
    topk_weights = topk_weights.to(hidden_states.dtype)

    expert_out = torch.zeros(
        num_tokens * top_k,
        hidden_size,
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    expanded = hidden_states.repeat_interleave(top_k, dim=0)
    flat_ids = topk_ids.flatten()
    first_expert = ep_rank * num_local_experts

    for local_idx in range(num_local_experts):
        mask = flat_ids == first_expert + local_idx
        num_selected = mask.sum()
        if not num_selected:
            # No token was routed to this expert.
            continue
        tokens = expanded[mask]
        gated = torch.empty(
            num_selected,
            up_gate_size // 2,
            dtype=tokens.dtype,
            device=tokens.device,
        )
        torch.ops._C.silu_and_mul(gated, tokens @ w13_weight[local_idx].T)
        expert_out[mask] = gated @ w2_weight[local_idx].T

    weighted = expert_out.view(num_tokens, top_k, hidden_size) * topk_weights.unsqueeze(2)
    return weighted.sum(dim=1).to(hidden_states.dtype)
@staticmethod
def fused_multi_head_latent_page_attention(
    hidden_states: torch.Tensor,
    q_lora_rank: int,
    kv_lora_rank: int,
    q_a_proj_w: torch.Tensor,
    q_a_layernorm_w: torch.Tensor,
    q_b_proj_w: torch.Tensor,
    q_proj_w: torch.Tensor,
    kv_a_proj_w: torch.Tensor,
    kv_a_layernorm_w: torch.Tensor,
    kv_b_proj_w: torch.Tensor,
    o_proj_w: torch.Tensor,
    head_num: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    max_context_len: int,
    layernorm_eps: float,
    scale: float,
    is_causal: bool,
    is_context: bool,
    mp_size: int,
    local_rank: int,
    rotary_pos_embedding: torch.Tensor,
    pa_block_tables: torch.Tensor,
    position: torch.Tensor,
    context_lens_cpu: torch.Tensor,
    slot_mapping: torch.Tensor,
    prompt_lods_cpu: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
) -> torch.Tensor:
    """Run the fused multi-head latent (MLA) paged-attention block.

    Allocates an output tensor with the same shape, dtype and device as
    ``hidden_states`` and forwards every argument, in declaration order, to
    ``kunlun_ops.xft_multi_head_latent_page_attention_block``. Two extra
    positional ``None`` values are inserted in the kernel call (see the
    inline notes); ``k_cache`` and ``v_cache`` are passed by keyword.

    Returns:
        The output tensor handed to the kernel as its ``out`` argument.
    """
    output = torch.empty(
        hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
    )
    kunlun_ops.xft_multi_head_latent_page_attention_block(
        hidden_states,
        q_lora_rank,
        kv_lora_rank,
        q_a_proj_w,
        q_a_layernorm_w,
        q_b_proj_w,
        q_proj_w,
        kv_a_proj_w,
        kv_a_layernorm_w,
        kv_b_proj_w,
        o_proj_w,
        head_num,
        qk_nope_head_dim,
        qk_rope_head_dim,
        v_head_dim,
        max_context_len,
        layernorm_eps,
        scale,
        is_causal,
        is_context,
        mp_size,
        local_rank,
        rotary_pos_embedding,
        pa_block_tables,
        position,
        # NOTE(review): placeholder for a kernel argument this wrapper does
        # not expose; presumably a device-side context-lengths tensor.
        # Confirm against the kernel signature.
        None,
        context_lens_cpu,
        slot_mapping,
        # NOTE(review): second unexposed kernel argument; presumably the
        # device-side counterpart of prompt_lods_cpu. Confirm.
        None,
        prompt_lods_cpu,
        out=output,
        k_cache=k_cache,
        v_cache=v_cache,
    )
    return output
def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """Return the result of ``kunlun_ops.fused_gdn_gating(A_log, a, dt_bias)``.

    NOTE(review): ``beta`` and ``threshold`` are accepted but are not
    forwarded to the kernel, so non-default values have no effect here.
    Confirm the kernel's built-in values match these defaults.
    """
    return kunlun_ops.fused_gdn_gating(A_log, a, dt_bias)
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    h0_source: torch.Tensor,
    output_final_state: bool,
    use_qk_l2norm_in_kernel: bool,
    cu_seqlens: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Core Gated DeltaNet operator of the Qwen3-NEXT model.

    Per the original author's description, the kernel fuses two steps:

    1. Sigmoid gating: gates the input, similar to a GLU (Gated Linear Unit).
    2. Delta rule update: performs a parallel state-space-model (SSM)
       recurrent update combined with a local attention mechanism.

    All arguments are forwarded positionally, in declaration order, to
    ``kunlun_ops.fused_recurrent_gated_delta_rule_fwd``.

    Returns:
        Tuple ``(o, final_state)`` as produced by the kernel.
    """
    kernel_args = (
        q,
        k,
        v,
        g,
        beta,
        scale,
        h0_source,
        output_final_state,
        use_qk_l2norm_in_kernel,
        cu_seqlens,
    )
    # Unpacking enforces that the kernel returned exactly two values.
    out, last_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(*kernel_args)
    return (out, last_state)