提交vllm0.11.0开发分支

This commit is contained in:
chenyili
2025-12-10 17:51:24 +08:00
parent deab7dd0b6
commit 7c22d621fb
175 changed files with 31856 additions and 8683 deletions

View File

@@ -1,20 +1,3 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun custom op entry"""
import torch_xmlir
import torch
@@ -29,7 +12,6 @@ logger = init_logger(__name__)
try:
import xtorch_ops
logger.info(f"Load custom ops library success!")
except ImportError as e:
logger.warning("Import error msg: %s", e.msg)
@@ -37,15 +19,13 @@ except ImportError as e:
_per_token_smooth_quant = True
def is_per_token_smooth_quant():
"""is per token smooth quant"""
""" is per token smooth quant """
return _per_token_smooth_quant
class KunlunOps:
"""KunlunOps"""
# Attention ops
@staticmethod
def paged_attention_v1(
@@ -70,9 +50,10 @@ class KunlunOps:
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False,
):
"""PagedAttentionV1"""
alibi_sqrt=False
):
""" PagedAttentionV1 """
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
x=query,
k_cache=key_cache,
@@ -83,7 +64,7 @@ class KunlunOps:
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128,
vo_head_dim=128
)
@staticmethod
@@ -112,9 +93,10 @@ class KunlunOps:
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False,
):
"""PagedAttentionV2"""
alibi_sqrt=False
):
""" PagedAttentionV2 """
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
x=query,
k_cache=key_cache,
@@ -125,28 +107,31 @@ class KunlunOps:
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128,
vo_head_dim=128
)
# Activation ops
@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
"""silu and mul"""
def silu_and_mul(out: torch.Tensor,
x: torch.Tensor):
""" silu and mul """
xtorch_ops.silu_and_mul(
x,
axis=-1,
turn=True,
out=out,
)
)
# Activation ops
@staticmethod
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
"""quick gelu"""
def quick_gelu(out: torch.Tensor,
x: torch.Tensor):
""" quick gelu """
xtorch_ops.quick_gelu(
x,
out=out,
)
)
# Layernorm
@staticmethod
@@ -157,7 +142,9 @@ class KunlunOps:
epsilon,
):
"""rms_norm"""
xtorch_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
xtorch_ops.rmsnorm(
x, weight.to(torch.float32), epsilon, out=out
)
@staticmethod
def fused_add_rms_norm(
@@ -175,11 +162,16 @@ class KunlunOps:
residual.copy_(fused_input, non_blocking=True)
x.copy_(output)
# Rotary embedding
@staticmethod
def rotary_embedding(
positions, query, key, head_size, cos_sin_cache, is_neox_style
):
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
"""
refactor RotaryEmbedding forward function
"""
@@ -196,43 +188,66 @@ class KunlunOps:
key_x = key_x.unsqueeze(0)
xtorch_ops.rotary_embedding_gptj(
positions, query_x, key_x, head_size, cos_sin_cache
)
positions,
query_x,
key_x,
head_size,
cos_sin_cache)
query.data = query_x
key.data = key_x
key.data = key_x
if query_x_dim != query_x.dim():
query_x = query_x.unsqueeze(0)
key_x = key_x.unsqueeze(0)
return query, key
# TODO: need opt
if cos_sin_cache.dim() == 4:
max_seq_len = cos_sin_cache.shape[2]
head_dim = cos_sin_cache.shape[3]
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(
0
) # Remove the first two dimensions [1,1,L,D] -> [L,D]
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0) # 移除前两个维度 [1,1,L,D] -> [L,D]
cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)
# Reshape query and key
# 重塑 query key 的形状
num_tokens = query_x.shape[0]
num_heads = query_x.shape[1] // head_size
num_kv_heads = key_x.shape[1] // head_size
# # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size]
# query_x = query_x.view(num_tokens, num_heads, head_size)
# # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size]
# key_x = key_x.view(num_tokens, num_kv_heads, head_size)
# # 确保形状正确
# assert query_x.shape == (num_tokens, num_heads, head_size), \
# f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}"
# assert key_x.shape == (num_tokens, num_kv_heads, head_size), \
# f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}"
torch.ops._C.rotary_embedding(
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
)
positions,
query_x,
key_x,
head_size,
cos_sin_cache,
is_neox_style)
query_x = query_x.view(num_tokens, num_heads * head_size)
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
# query.data = query_x
# key.data = key_x
return query_x, key_x
# Rotary embedding
@staticmethod
def mrotary_embedding(
positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
):
positions,
mrope_section,
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
"""
refactor RotaryEmbedding forward function
"""
@@ -241,21 +256,35 @@ class KunlunOps:
query_x_dim = query_x.dim()
assert is_neox_style
xtorch_ops.mrotary_embedding_neox(
positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
)
positions,
query_x,
key_x,
head_size,
cos_sin_cache,
mrope_section)
query.data = query_x
key.data = key_x
key.data = key_x
return query, key
@staticmethod
def swap_blocks(src, dst, block_mapping):
"""swap_blocks"""
xtorch_ops.swap_blocks(src, dst, block_mapping)
def swap_blocks(
src,
dst,
block_mapping):
""" swap_blocks """
xtorch_ops.swap_blocks(
src,
dst,
block_mapping
)
@staticmethod
def copy_blocks(key_caches, value_caches, block_mapping):
"""copy_blocks"""
def copy_blocks(
key_caches,
value_caches,
block_mapping):
""" copy_blocks """
for i in range(len(key_caches)):
key_caches[i] = key_caches[i].contiguous()
value_caches[i] = value_caches[i].contiguous()
@@ -273,9 +302,16 @@ class KunlunOps:
value_cache,
slot_mapping,
kv_cache_dtype,
):
"""reshape_and_cache"""
xtorch_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
):
""" reshape_and_cache """
# slot_mapping_cast = slot_mapping.to(torch.int32)
xtorch_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping
)
@staticmethod
def multi_query_kv_attention(
@@ -284,7 +320,7 @@ class KunlunOps:
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
**kargs,
**kargs
) -> torch.Tensor:
"""
query: shape = [num_prompt_tokens, num_heads, head_size]
@@ -303,7 +339,7 @@ class KunlunOps:
KVh = key.size(2)
if KVh != Qh:
repeat = Qh // KVh
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
value = value.repeat_interleave(repeat, dim=2)
xtorch_ops.attention(
q=query,
@@ -318,132 +354,85 @@ class KunlunOps:
return output
@staticmethod
def quant_fusedresidual_rmsnorm_op(
x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
def quant_fusedresidual_rmsnorm_op(x,
residual,
weight,
bias,
scale_to_int,
eps,
dyn_scale: bool,
type: int = 1):
"""Quantized fused residual layer normalization"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_fusedresidual_rmsnorm(
x,
residual,
weight,
bias,
eps,
out=out,
out_scale=out_scale,
residual_tensor=residual,
)
xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
out=out, out_scale=out_scale , residual_tensor=residual)
if residual is None:
return out, out_scale
return out, out_scale, residual
@staticmethod
def quant_rmsnorm_op(
x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
def quant_rmsnorm_op(x,
weight,
bias,
scale_to_int,
eps,
dyn_scale : bool,
type: int = 1):
"""Quantized RMSNorm"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
xtorch_ops.quant_rmsnorm(x, weight, bias, eps,
out=out, out_scale=out_scale)
return out, out_scale
@staticmethod
def smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype,
):
def smooth_quant_matmul_column_row_kernels(input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype):
"""smooth_quant_matmul_column_row_kernels"""
input_shape = input_tensor.shape
weight_shape = weight.shape
if input_tensor.dim() == 3:
input_tensor = input_tensor.reshape(-1, input_shape[-1])
out = torch.empty(
(input_shape[0] * input_shape[1], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
out = torch.empty((input_shape[0] * input_shape[1],
weight_shape[0]),
dtype=torch.float16,
device=weight.device)
output_bs_shape = [input_shape[0], input_shape[1]]
elif input_tensor.dim() == 2:
out = torch.empty(
(input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
out = torch.empty((input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device)
output_bs_shape = [-1]
xtorch_ops.smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out,
)
xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
weight, smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out)
out = out.view(*output_bs_shape, weight_shape[0])
return out
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
linear_weights: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""fused_moe"""
output = torch.empty(
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
expert_num = linear_weights.shape[0]
torch.ops._C.moe_ffn_block(
x=hidden_states,
gate_w=linear_weights,
inter_w=w1,
output_w=w2,
expert_num=expert_num,
moe_top_k=topk,
topk_group=topk_group,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
expert_group_num=num_expert_group,
out=output,
)
return output
@staticmethod
def fused_moe_ep(
hidden_states: torch.Tensor,
@@ -460,23 +449,23 @@ class KunlunOps:
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> torch.Tensor:
x = hidden_states
batch, hidden_size = x.shape
batch, hidden_size = x.shape
num_local_experts, up_gate_size, _ = w13_weight.shape
router_logits = x.to(linear_weights.dtype) @ linear_weights.T
topk_weights = torch.empty(
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
)
topk_ids = torch.empty(
batch, top_k, dtype=torch.int32, device=router_logits.device
)
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
torch.ops._C.moe_softmax_topk(
router_logits, topk_weights, topk_ids, block_static
)
router_logits = x.to(linear_weights.dtype)@linear_weights.T
topk_weights = torch.empty(batch,
top_k,
dtype=router_logits.dtype,
device=router_logits.device)
topk_ids = torch.empty(batch,
top_k,
dtype=torch.int32,
device=router_logits.device)
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device)
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
@@ -490,22 +479,50 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
torch.ops._C.swiglu(cur_token @ w13_weight[i].T, up_gate)
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
torch.ops._C.swiglu(cur_token@ w13_weight[i].T, up_gate)
out[selected_token] = up_gate @ w2_weight[i].T
output = (
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
.sum(dim=1)
.to(x.dtype)
)
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
return output
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
linear_weights: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""fused_moe"""
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
device=hidden_states.device)
expert_num = linear_weights.shape[0]
torch.ops._C.moe_ffn_block(
x=hidden_states,
gate_w=linear_weights,
inter_w=w1,
output_w=w2,
expert_num=expert_num,
moe_top_k=topk,
topk_group=topk_group,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
expert_group_num=num_expert_group,
out=output,
)
return output
@staticmethod
def fused_multi_head_latent_page_attention(
hidden_states: torch.Tensor,
@@ -538,11 +555,10 @@ class KunlunOps:
prompt_lods_cpu: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
) -> torch.Tensor:
) -> torch.Tensor:
"""mla pa block"""
output = torch.empty(
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
device=hidden_states.device)
xtorch_ops.xft_multi_head_latent_page_attention_block(
hidden_states,
q_lora_rank,
@@ -579,3 +595,42 @@ class KunlunOps:
v_cache=v_cache,
)
return output
def fused_gdn_gating(
A_log: torch.Tensor,
a: torch.Tensor,
dt_bias: torch.Tensor,
beta: float = 1.0,
threshold: float = 20.0,
) -> torch.Tensor:
"""fused_gdn_gating"""
output = xtorch_ops.fused_gdn_gating(
A_log,
a,
dt_bias,
)
return output
def fused_recurrent_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
h0_source: torch.Tensor,
output_final_state: bool,
use_qk_l2norm_in_kernel: bool,
cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
'''
Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起
1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)。
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。
'''
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
cu_seqlens)
return (o, final_state)