[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)

Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
chanzhennan
2026-02-28 11:15:50 +08:00
committed by GitHub
parent 153093d3b3
commit 82544aa0cc
17 changed files with 2668 additions and 1532 deletions

View File

@@ -16,33 +16,33 @@
# limitations under the License.
"""kunlun custom op entry"""
import torch_xmlir
from typing import Optional
import torch
import os
from typing import Optional, List, Dict
import vllm.envs as envs
import os
import ctypes
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import kunlun_ops
logger.info(f"Load custom ops library success!")
logger.info("Load custom ops library success!")
except ImportError as e:
logger.warning("Import error msg: %s", e.msg)
_per_token_smooth_quant = True
def is_per_token_smooth_quant():
""" is per token smooth quant """
"""is per token smooth quant"""
return _per_token_smooth_quant
class KunlunOps:
"""KunlunOps"""
# Attention ops
@staticmethod
def paged_attention_v1(
@@ -67,9 +67,9 @@ class KunlunOps:
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False
):
""" PagedAttentionV1 """
alibi_sqrt=False,
):
"""PagedAttentionV1"""
# block_size = value_cache.shape[2]
kunlun_ops.paged_attention(
x=query,
@@ -81,7 +81,7 @@ class KunlunOps:
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128
vo_head_dim=128,
)
@staticmethod
@@ -110,9 +110,9 @@ class KunlunOps:
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False
):
""" PagedAttentionV2 """
alibi_sqrt=False,
):
"""PagedAttentionV2"""
# block_size = value_cache.shape[2]
kunlun_ops.paged_attention(
x=query,
@@ -124,31 +124,28 @@ class KunlunOps:
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128
vo_head_dim=128,
)
# Activation ops
@staticmethod
def silu_and_mul(out: torch.Tensor,
x: torch.Tensor):
""" silu and mul """
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
"""silu and mul"""
kunlun_ops.silu_and_mul(
x,
axis=-1,
turn=True,
out=out,
)
)
# Activation ops
@staticmethod
def quick_gelu(out: torch.Tensor,
x: torch.Tensor):
""" quick gelu """
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
"""quick gelu"""
kunlun_ops.quick_gelu(
x,
out=out,
)
)
# Layernorm
@staticmethod
@@ -159,9 +156,7 @@ class KunlunOps:
epsilon,
):
"""rms_norm"""
kunlun_ops.rmsnorm(
x, weight.to(torch.float32), epsilon, out=out
)
kunlun_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
@staticmethod
def fused_add_rms_norm(
@@ -179,16 +174,11 @@ class KunlunOps:
residual.copy_(fused_input, non_blocking=True)
x.copy_(output)
# Rotary embedding
@staticmethod
def rotary_embedding(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
positions, query, key, head_size, cos_sin_cache, is_neox_style
):
"""
refactor RotaryEmbedding forward function
"""
@@ -196,62 +186,38 @@ class KunlunOps:
key_x = key.contiguous()
torch.ops._C.rotary_embedding(
positions,
query_x,
key_x,
head_size,
cos_sin_cache,
is_neox_style)
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
)
return query_x, key_x
# Rotary embedding
@staticmethod
def mrotary_embedding(
positions,
mrope_section,
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
):
"""
refactor RotaryEmbedding forward function
"""
query_x = query.contiguous()
key_x = key.contiguous()
query_x_dim = query_x.dim()
assert is_neox_style
kunlun_ops.mrotary_embedding_neox(
positions,
query_x,
key_x,
head_size,
cos_sin_cache,
mrope_section)
positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
)
query.data = query_x
key.data = key_x
key.data = key_x
return query, key
@staticmethod
def swap_blocks(
src,
dst,
block_mapping):
""" swap_blocks """
kunlun_ops.swap_blocks(
src,
dst,
block_mapping
)
def swap_blocks(src, dst, block_mapping):
"""swap_blocks"""
kunlun_ops.swap_blocks(src, dst, block_mapping)
@staticmethod
def copy_blocks(
key_caches,
value_caches,
block_mapping):
""" copy_blocks """
def copy_blocks(key_caches, value_caches, block_mapping):
"""copy_blocks"""
for i in range(len(key_caches)):
key_caches[i] = key_caches[i].contiguous()
value_caches[i] = value_caches[i].contiguous()
@@ -269,16 +235,10 @@ class KunlunOps:
value_cache,
slot_mapping,
kv_cache_dtype,
):
""" reshape_and_cache """
):
"""reshape_and_cache"""
# slot_mapping_cast = slot_mapping.to(torch.int32)
kunlun_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping
)
kunlun_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def multi_query_kv_attention(
@@ -287,7 +247,7 @@ class KunlunOps:
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
**kargs
**kargs,
) -> torch.Tensor:
"""
query: shape = [num_prompt_tokens, num_heads, head_size]
@@ -297,16 +257,12 @@ class KunlunOps:
key = key.unsqueeze(0)
value = value.unsqueeze(0)
output = torch.empty_like(query)
alibi_slopes = kargs.get("alibi_slopes", None)
mask = kargs.get("mask", None)
is_causal = kargs.get("is_causal", True)
is_lvsl = kargs.get("is_lvsl", True)
B, T, Qh, Hd = query.shape
KVh = key.size(2)
if KVh != Qh:
repeat = Qh // KVh
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
value = value.repeat_interleave(repeat, dim=2)
kunlun_ops.attention(
q=query,
@@ -321,80 +277,90 @@ class KunlunOps:
return output
@staticmethod
def quant_fusedresidual_rmsnorm_op(x,
residual,
weight,
bias,
scale_to_int,
eps,
dyn_scale: bool,
type: int = 1):
def quant_fusedresidual_rmsnorm_op(
x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
"""Quantized fused residual layer normalization"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
kunlun_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
out=out, out_scale=out_scale , residual_tensor=residual)
kunlun_ops.quant_fusedresidual_rmsnorm(
x,
residual,
weight,
bias,
eps,
out=out,
out_scale=out_scale,
residual_tensor=residual,
)
if residual is None:
return out, out_scale
return out, out_scale, residual
@staticmethod
def quant_rmsnorm_op(x,
weight,
bias,
scale_to_int,
eps,
dyn_scale : bool,
type: int = 1):
def quant_rmsnorm_op(
x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
"""Quantized RMSNorm"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
kunlun_ops.quant_rmsnorm(x, weight, bias, eps,
out=out, out_scale=out_scale)
kunlun_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
return out, out_scale
@staticmethod
def smooth_quant_matmul_column_row_kernels(input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype):
def smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype,
):
"""smooth_quant_matmul_column_row_kernels"""
input_shape = input_tensor.shape
weight_shape = weight.shape
if input_tensor.dim() == 3:
input_tensor = input_tensor.reshape(-1, input_shape[-1])
out = torch.empty((input_shape[0] * input_shape[1],
weight_shape[0]),
dtype=torch.float16,
device=weight.device)
out = torch.empty(
(input_shape[0] * input_shape[1], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [input_shape[0], input_shape[1]]
elif input_tensor.dim() == 2:
out = torch.empty((input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device)
out = torch.empty(
(input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [-1]
kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
weight, smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out)
kunlun_ops.smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out,
)
out = out.view(*output_bs_shape, weight_shape[0])
@@ -404,6 +370,7 @@ class KunlunOps:
if torch.is_tensor(x):
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
return (type(x), x)
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
@@ -420,23 +387,24 @@ class KunlunOps:
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""fused_moe"""
global_num_experts, up_gate_size, _ = w1.shape
M, N = hidden_states.shape
hidden_dim = w2.shape[1]
normed_score = torch.empty(M,
moe_top_k,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
moe_top_k,
dtype=torch.int32,
device=hidden_states.device)
normed_score = torch.empty(
M, moe_top_k, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M, moe_top_k, dtype=torch.int32, device=hidden_states.device
)
num_blocks = 12
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
num_blocks,
global_num_experts,
dtype=torch.int32,
device=hidden_states.device,
)
router_logits = router_logits.to(torch.float)
if scoring_func == "softmax":
@@ -445,24 +413,27 @@ class KunlunOps:
normed_score=normed_score,
topk_index=topk_ids,
block_statistic=None,
stable=True)
stable=False,
)
elif scoring_func == "sigmoid":
torch.ops._C.moe_sigmoid_group_topk_norm(
x=router_logits,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=topk_group,
)
x=router_logits,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=topk_group,
)
if w1_bias is not None or w2_bias is not None:
if w1_bias is not None or w2_bias is not None:
# Rignt now this branch is for gpt oss
# TODO (@xyDong23): faster here using moe_fc kernel
normed_score = normed_score.to(hidden_states.dtype)
out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device)
out = torch.zeros(
M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device
)
repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
topk_ids_flat = topk_ids.flatten()
for i in range(global_num_experts):
@@ -470,9 +441,13 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
groupgemm1 = cur_token@ w1[i].T
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
groupgemm1 = cur_token @ w1[i].T
# Add w13 bias
if w1_bias is not None:
groupgemm1 = groupgemm1 + w1_bias[i]
@@ -482,53 +457,129 @@ class KunlunOps:
if w2_bias is not None:
groupgemm2 = groupgemm2 + w2_bias[i]
out[selected_token] = groupgemm2
ouput = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype)
ouput = (
(out.view(M, moe_top_k, N) * normed_score.unsqueeze(2))
.sum(dim=1)
.to(hidden_states.dtype)
)
return ouput
else:
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
# from vllm.forward_context import get_forward_context
# forward_context = get_forward_context()
# attn_metadata: AttentionMetadata = forward_context.attn_metadata
# prefix = "model.layers.0.linear_attn"
# if attn_metadata is not None:
# attn_metadata = attn_metadata[prefix]
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
# if attn_metadata is None or attn_metadata.num_prefills > 0 or :
if M * moe_top_k < 400:
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
torch.ops.xspeedgate_ops.moe_pre_small(
topk_ids, global_num_experts, False, False, hidden_states
)
)
experts_num_lod = torch.ops.xspeedgate_ops.moe_active_expert_balance(
topk_ids, global_num_experts, False
)
out = torch.ops.xspeedgate_ops.fused_moe(
hidden_states,
w1,
w2,
normed_score.to(hidden_states.dtype),
sorted_tokens_num_lod,
sorted_tokens_idx,
experts_num_lod,
)
return out.sum(1)
y = torch.empty(M,moe_top_k,
w1.shape[1],
if M * moe_top_k > 768:
moe_expand = torch.empty(
(M * moe_top_k, N),
dtype=hidden_states.dtype,
device=hidden_states.device)
device=hidden_states.device,
) # [M*top_k, N], float
expert_m = torch.zeros(
global_num_experts, dtype=torch.int32, device=hidden_states.device
) # [E]
sorted_tokens_num_lod = torch.zeros(
global_num_experts + 1,
dtype=torch.int32,
device=hidden_states.device,
) # [E+1]
sorted_tokens_idx = torch.zeros(
M * moe_top_k, dtype=torch.int32, device=hidden_states.device
)
torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod,
)
else:
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
torch.ops.xspeedgate_ops.moe_pre_small(
topk_ids,
global_num_experts,
index_have_neg=False,
sort_mode=True,
x=hidden_states,
)
)
y = torch.empty(
M,
moe_top_k,
w1.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
if M < 1024:
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
)
d = y.shape[-1] // 2
output_shape = y.shape[:-1] + (d,)
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out1 = out1.reshape(-1, out1.shape[-1])
else:
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
act="SWISH_GLU",
)
y = y[..., : y.shape[-1] // 2]
out1 = y.reshape(-1, y.shape[-1])
out = torch.empty(
M,
moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out = torch.empty(M,moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
out1 = out1.reshape(-1, out1.shape[-1])
torch.ops._C.moe_fc(
x=out1,
weight=w2,
@@ -538,8 +589,12 @@ class KunlunOps:
y=out,
)
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
dequant_scale = torch.ones(
[M, moe_top_k], dtype=torch.float32, device=out.device
)
output = torch.empty(
[M, N], dtype=hidden_states.dtype, device=hidden_states.device
)
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
torch.ops._C.moe_post(
@@ -547,9 +602,9 @@ class KunlunOps:
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
y=output,
)
return output
@staticmethod
@@ -568,23 +623,23 @@ class KunlunOps:
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> torch.Tensor:
x = hidden_states
batch, hidden_size = x.shape
batch, hidden_size = x.shape
num_local_experts, up_gate_size, _ = w13_weight.shape
router_logits = x.to(linear_weights.dtype)@linear_weights.T
topk_weights = torch.empty(batch,
top_k,
dtype=router_logits.dtype,
device=router_logits.device)
topk_ids = torch.empty(batch,
top_k,
dtype=torch.int32,
device=router_logits.device)
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device)
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static)
router_logits = x.to(linear_weights.dtype) @ linear_weights.T
topk_weights = torch.empty(
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
)
topk_ids = torch.empty(
batch, top_k, dtype=torch.int32, device=router_logits.device
)
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
torch.ops._C.moe_softmax_topk(
router_logits, topk_weights, topk_ids, block_static
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
@@ -598,11 +653,19 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
torch.ops._C.silu_and_mul(up_gate, cur_token@ w13_weight[i].T)
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
torch.ops._C.silu_and_mul(up_gate, cur_token @ w13_weight[i].T)
out[selected_token] = up_gate @ w2_weight[i].T
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
output = (
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
.sum(dim=1)
.to(x.dtype)
)
return output
@@ -638,10 +701,11 @@ class KunlunOps:
prompt_lods_cpu: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
) -> torch.Tensor:
) -> torch.Tensor:
"""mla pa block"""
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
device=hidden_states.device)
output = torch.empty(
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
kunlun_ops.xft_multi_head_latent_page_attention_block(
hidden_states,
q_lora_rank,
@@ -679,7 +743,6 @@ class KunlunOps:
)
return output
def fused_gdn_gating(
A_log: torch.Tensor,
a: torch.Tensor,
@@ -695,25 +758,34 @@ class KunlunOps:
)
return output
def fused_recurrent_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
h0_source: torch.Tensor,
output_final_state: bool,
use_qk_l2norm_in_kernel: bool,
cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
'''
Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起
1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)。
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制
'''
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
h0_source: torch.Tensor,
output_final_state: bool,
use_qk_l2norm_in_kernel: bool,
cu_seqlens: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起
1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。
"""
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
cu_seqlens)
return (o, final_state)
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
q,
k,
v,
g,
beta,
scale,
h0_source,
output_final_state,
use_qk_l2norm_in_kernel,
cu_seqlens,
)
return (o, final_state)

View File

@@ -9,60 +9,196 @@
# ruff: noqa: E501
import warnings
from typing import Optional
import torch.nn.functional as F
import cocopod # noqa
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
from .index import prepare_chunk_indices
import xspeedgate_ops
import cocopod
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
chunk_size=64
A = -A.transpose(1,2)
def torch_solve_tril(
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
output_dtype: torch.dtype = torch.float,
):
chunk_size = 64
A = -A.transpose(1, 2)
sequence_length = A.shape[-2]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
# mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
# A = A.masked_fill(mask, 0)
for i in range(1, chunk_size):
row = A[..., i, :i].clone()
sub = A[..., :i, :i].clone()
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[
:, :, :sequence_length, :
].transpose(1, 2)
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
A = chunk_scaled_dot_kkt_fwd(k=k,
beta=beta,
g_cumsum=g,
cu_seqlens=cu_seqlens,
output_dtype=q.dtype)
#kernel版
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
chunk_indices = prepare_chunk_indices(
cu_seqlens, 64) if cu_seqlens is not None else None
def recompute_w_u_fwd_torch(
k: torch.Tensor, # [B, T, H, K]
v: torch.Tensor, # [B, T, H, V]
beta: torch.Tensor, # [B, T, H]
g: torch.Tensor, # [B, T, H]
A: torch.Tensor, # [B, H, T, T]
):
"""
最简单版本假设等长序列key和value头数相同
"""
chunk_size = 64
num_v_heads, num_k_heads = v.shape[2], k.shape[2]
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k, v, beta, g, A = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
]
batch_size, num_heads, sequence_length, k_head_dim = k.shape
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
k = F.pad(k, (0, 0, 0, pad_size))
v = F.pad(v, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
v_beta = v * beta.unsqueeze(-1)
k_beta = k * beta.unsqueeze(-1)
k, v, k_beta, v_beta = [
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
for x in (k, v, k_beta, v_beta)
]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
u = A @ v_beta
w = A @ (k_beta * g.exp().unsqueeze(-1))
w = (
w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
u = (
u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
return w, u
def split_by_value(tensor, chunk_size=64):
indices = tensor.tolist()
result = set(indices) # 使用集合避免重复
for i in range(len(indices) - 1):
start = indices[i]
end = indices[i + 1]
# 计算第一个对齐边界
# 我们要找的是 start + n*chunk_size其中n是使结果大于start的最小整数
first_boundary = start + chunk_size
# 在(start, end)范围内插入所有对齐边界
boundary = first_boundary
while boundary < end:
result.add(boundary)
boundary += chunk_size
return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device)
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
chunk_size = 64
chunk_indices = (
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None
)
chunk_offsets = (
prepare_chunk_offsets(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
# !
# g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
g,
chunk_size=64,
reverse=False,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
head_first=False,
)
# !
# A = chunk_scaled_dot_kkt_fwd(k=k,
# beta=beta,
# g_cumsum=g,
# cu_seqlens=cu_seqlens,
# output_dtype=q.dtype)
A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
k, beta, g, cu_seqlens, chunk_indices, chunk_size
)
# torch版
# if get_tensor_model_parallel_rank() == 0:
# torch.save(A, "A_in")
# torch.save(cu_seqlens, "cu_seqlens")
# A2 = A.clone()
torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
# !
# torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
# if get_tensor_model_parallel_rank() == 0:
# err = torch.max(torch.abs(A - A2))
# print("err", err)
# if err > 1e-3:
# raise
# A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
# for i in range(len(cu_seqlens)-1):
# A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
# A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
"""
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
for i in range(len(cu_seqlens)-1):
k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :]
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :]
w_i, u_i = recompute_w_u_fwd_torch(
k=k_i,
v=v_i,
beta=beta_i,
A=A_i,
g=g_i,
)
w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i
"""
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k=k,
v=v,
@@ -71,17 +207,63 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
g_cumsum=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
chunk_size=64,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
"""
w, u = recompute_w_u_fwd(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=output_final_state,
v=v,
beta=beta,
A=A,
g_cumsum=g,
cu_seqlens=cu_seqlens,
)
"""
# i
# import os
# if not os.path.exists("/qwen-next/in"):
# os.makedirs("/qwen-next/in")
# torch.save(k, "/qwen-next/in/k.pt")
# torch.save(u, "/qwen-next/in/u.pt")
# torch.save(w, "/qwen-next/in/w.pt")
# torch.save(g, "/qwen-next/in/g.pt")
# torch.save(initial_state, "/qwen-next/in/initial_state.pt")
# torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt")
# torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt")
# torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt")
# torch.save(chunk_size, "/qwen-next/in/chunk_size.pt")
# torch.save(output_final_state, "/qwen-next/in/output_final_state.pt")
h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
k,
u,
w,
g,
initial_state,
cu_seqlens,
chunk_indices,
chunk_offsets.to(torch.int32),
chunk_size,
output_final_state,
True,
)
# h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
# k=k,
# w=w,
# u=u,
# g=g,
# initial_state=initial_state,
# output_final_state=output_final_state,
# cu_seqlens=cu_seqlens,
# )
# if not os.path.exists("/qwen-next/out"):
# os.makedirs("/qwen-next/out")
# torch.save(h, "/qwen-next/out/h.pt")
# torch.save(v_new, "/qwen-next/out/v_new.pt")
# torch.save(final_state, "/qwen-next/out/final_state.pt")
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
q=q,
k=k,
@@ -91,8 +273,19 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
chunk_size=64,
)
"""
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
)
"""
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
elif SUPPRESS_LEVEL >= 3:
@@ -103,18 +296,20 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@torch.amp.custom_fwd(device_type='cuda')
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False):
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
@@ -136,17 +331,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False):
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Args:
q (torch.Tensor):
@@ -211,42 +408,85 @@ def chunk_gated_delta_rule(q: torch.Tensor,
)
"""
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert len(
beta.shape
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
assert (
q.dtype != torch.float32
), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert (
len(beta.shape) == 3
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
if head_first:
raise DeprecationWarning(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead.",
stacklevel=2)
stacklevel=2,
)
q, k, v, beta, g = map(
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
(q, k, v, beta, g))
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2)
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
if initial_state is not None and initial_state.shape[0] != len(
cu_seqlens) - 1:
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1]**-0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
use_qk_l2norm_in_kernel)
scale = k.shape[-1] ** -0.5
if False:
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
g = g.contiguous()
beta = beta.contiguous()
initial_state = initial_state.contiguous()
o = torch.empty_like(v)
final_state = torch.empty_like(initial_state)
import kunlun_ops
kunlun_ops.gated_delta_rule(
q,
k,
v,
initial_state,
g,
beta,
final_state,
o,
scale,
cu_seqlens.cpu(),
cu_seqlens,
cu_seqlens.cpu(),
cu_seqlens,
use_qk_l2norm_in_kernel=True,
)
else:
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
)
if head_first:
o = rearrange(o, 'b t h ... -> b h t ...')
o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state

View File

@@ -12,21 +12,21 @@
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
@triton.heuristics({
'USE_G': lambda args: args['g'] is not None,
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
# @triton.autotune(
# configs=[
# triton.Config({
@@ -40,7 +40,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
# ],
# key=['H', 'K', 'V', 'BT'],
# )
@triton.jit(do_not_specialize=['T'])
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
q,
k,
@@ -67,10 +67,12 @@ def chunk_fwd_kernel_o(
if IS_VARLEN:
i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
chunk_indices + i_t * 2 + 1
).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
cu_seqlens + i_n + 1
).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
else:
@@ -89,12 +91,15 @@ def chunk_fwd_kernel_o(
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
(BT, BK), (1, 0))
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
(BK, BT), (0, 1))
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
(BK, BV), (1, 0))
p_q = tl.make_block_ptr(
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
)
p_h = tl.make_block_ptr(
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
)
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
@@ -109,8 +114,8 @@ def chunk_fwd_kernel_o(
if USE_G:
g += bos * H + i_h
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * tl.exp(b_g)[:, None]
b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])
@@ -120,10 +125,12 @@ def chunk_fwd_kernel_o(
# b_A = tl.where(m_A, b_A, 0)
b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_v = tl.make_block_ptr(
v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_o = tl.make_block_ptr(
o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion
@@ -133,48 +140,29 @@ def chunk_fwd_kernel_o(
def chunk_fwd_o(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None, # cumsum of log decay
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None, # cumsum of log decay
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
_, T, _, _, _ = *q.shape, v.shape[-1]
if FLA_GDN_FIX_BT:
BT = 64
else:
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
o = torch.empty_like(v)
def grid(meta):
return (triton.cdiv(V, meta['BV']), NT, B * H)
chunk_fwd_kernel_o[grid](
q,
k,
v,
h,
g,
o,
cu_seqlens,
chunk_indices,
scale,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=64,
BV=32
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
q, k, v, h, g, scale, cu_seqlens, chunk_indices, chunk_size
)
return o

View File

@@ -9,28 +9,28 @@
# ruff: noqa: E501
from typing import Optional
import torch
import kunlun_ops
import torch
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False):
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
q.contiguous(),
k.contiguous(),
@@ -44,7 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
h0_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
is_h0_transposed=True
is_h0_transposed=True,
)
return o, final_state
@@ -130,9 +130,10 @@ def fused_recurrent_gated_delta_rule(
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
f"Please flatten variable-length inputs before processing."
)
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
else:
assert scale > 0, "scale must be positive"
if beta is None:

View File

@@ -10,22 +10,21 @@
import os
from typing import Optional
import kunlun_ops
import torch
from vllm.triton_utils import tl, triton
import kunlun_ops
BT_LIST = [8, 16, 32, 64, 128]
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
@triton.autotune(configs=[
triton.Config({}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16, 32]
],
key=['D'])
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
],
key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
x,
@@ -49,11 +48,14 @@ def l2norm_fwd_kernel1(
tl.store(y + cols, b_y, mask=mask)
@triton.autotune(configs=[
triton.Config({'BT': BT}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
],
key=['D'])
@triton.autotune(
configs=[
triton.Config({"BT": BT}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16]
for BT in BT_LIST
],
key=["D"],
)
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
x,
@@ -87,67 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
def l2norm_fwd_triton(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: Optional[torch.dtype] = None):
x_shape_og = x.shape
x = x.view(-1, x.shape[-1])
# allocate output
if output_dtype is None:
y = torch.empty_like(x)
else:
y = torch.empty_like(x, dtype=output_dtype)
assert y.stride(-1) == 1
T, D = x.shape[0], x.shape[-1]
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
if D > BD:
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
if not USE_DEFAULT_FLA_NORM:
MBLOCK = 32
# M, N = x.shape
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
x,
y,
eps,
T,
D,
MBLOCK,
)
else:
if D <= 512:
NB = triton.cdiv(T, 2048)
def grid(meta):
return (triton.cdiv(T, meta['BT']), )
l2norm_fwd_kernel[grid](
x,
y,
eps,
NB=NB,
T=T,
D=D,
BD=BD,
)
else:
l2norm_fwd_kernel1[(T, )](
x,
y,
eps=eps,
D=D,
BD=BD,
)
return y.view(x_shape_og)
def l2norm_fwd(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: Optional[torch.dtype] = None):
def l2norm_fwd(
x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
):
out = torch.empty_like(x)
kunlun_ops.l2norm(x, out, eps)
kunlun_ops.l2norm(x, out, eps)
return out

View File

@@ -19,20 +19,21 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vllm.triton_utils import tl, triton
from .utils import input_guard
def rms_norm_ref(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
upcast=True):
def rms_norm_ref(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
upcast=True,
):
dtype = x.dtype
weight = weight.float()
bias = bias.float() if bias is not None else None
@@ -43,12 +44,10 @@ def rms_norm_ref(x,
x = x * F.silu(z)
if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
weight)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
eps)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
if bias is not None:
out = out + bias
@@ -57,10 +56,12 @@ def rms_norm_ref(x,
return out.to(dtype)
@triton.heuristics({
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
})
@triton.heuristics(
{
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
}
)
@triton.jit
def layer_norm_fwd_kernel(
X, # pointer to the input
@@ -97,17 +98,17 @@ def layer_norm_fwd_kernel(
B += group * N
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
x *= z * tl.sigmoid(z)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.)
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
@@ -149,46 +150,50 @@ def layer_norm_fwd(
# weight = weight.reshape(N)
# print("weight",weight.shape)
# print("x",x.shape)
assert weight.shape == (N, )
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N, )
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
device=x.device) if not is_rms_norm else None
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
mean = (
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
layer_norm_fwd_kernel[grid](x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps)
layer_norm_fwd_kernel[grid](
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
return out, mean, rstd
@@ -196,17 +201,18 @@ class LayerNormFn(torch.autograd.Function):
@input_guard
@staticmethod
def forward(ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
def forward(
ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
@@ -223,16 +229,15 @@ class LayerNormFn(torch.autograd.Function):
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
# y, mean, rstd = torch.ops.xspeedgate_ops.rms_norm_gated_fwd(x, weight, bias, eps, z, group_size, norm_before_gate, is_rms_norm)
y = torch.empty_like(x)
mean, rstd = None, None
import kunlun_ops
kunlun_ops.rms_norm_gated(
x, y, z, weight, eps, group_size, norm_before_gate, is_rms_norm
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
@@ -242,27 +247,27 @@ class LayerNormFn(torch.autograd.Function):
return y.reshape(x_shape_og)
def layernorm_fn(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
norm_before_gate, is_rms_norm)
def layernorm_fn(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
)
def rmsnorm_fn(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True):
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
norm_before_gate, True)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
)
class LayerNormGated(nn.Module):
@@ -294,15 +299,16 @@ class LayerNormGated(nn.Module):
torch.nn.init.zeros_(self.bias)
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return layernorm_fn(x,
self.weight,
self.bias,
z=z,
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate)
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return layernorm_fn(
x,
self.weight,
self.bias,
z=z,
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate,
)
class RMSNormGated(nn.Module):
@@ -332,12 +338,13 @@ class RMSNormGated(nn.Module):
torch.nn.init.ones_(self.weight)
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return rmsnorm_fn(x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate)
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return rmsnorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
)

View File

@@ -11,7 +11,6 @@
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
@@ -28,6 +27,7 @@ RESOLUTION = {
torch.complex64: 1.3e-6,
}
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
assert res.dtype == dtype
ref = ref.to(dtype)
@@ -35,6 +35,7 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
rtol = RESOLUTION[dtype]
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
# configs=[
@@ -80,7 +81,6 @@ def recompute_u_fwd_kernel(
p_beta = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
@@ -110,7 +110,6 @@ def recompute_u_fwd_kernel(
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
# configs=[
@@ -195,53 +194,12 @@ def recompute_w_u_fwd(
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor],
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
BT = A.shape[-1]
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64
BV = 64
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
recompute_u_fwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
recompute_w_fwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k, v, beta, g_cumsum, A, cu_seqlens, chunk_indices, chunk_size=BT
)
return w, u
return w, u

View File

@@ -15,51 +15,52 @@
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers import layernorm
from typing import Optional, Union
import kunlun_ops
import torch
from vllm.model_executor.layers import layernorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm
def vllm_kunlun_forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""forward_cuda"""
if x.is_contiguous() == False:
# kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually
x = x.contiguous()
if self.variance_size_override is not None:
return self.forward_native(x, residual)
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""forward_cuda"""
if not x.is_contiguous():
# kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually
x = x.contiguous()
if self.variance_size_override is not None:
return self.forward_native(x, residual)
if residual is not None:
# residual_output = torch.empty_like(residual)
torch.ops._C.add_rmsnorm(
x,
residual,
residual_output=residual,
weight=self.weight.data,
eps=self.variance_epsilon,
output=x
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
if residual is not None:
# residual_output = torch.empty_like(residual)
torch.ops._C.add_rmsnorm(
x,
self.weight.data,
out,
self.variance_epsilon,
residual,
residual_output=residual,
weight=self.weight.data,
eps=self.variance_epsilon,
output=x,
)
return out
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
self.weight.data,
out,
self.variance_epsilon,
)
return out
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
@staticmethod
def forward_xpu(
@@ -68,30 +69,42 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
x: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if x.is_contiguous() == False:
if not x.is_contiguous():
# kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually
x = x.contiguous()
if x.dim() == 3:
x_shape = x.shape
x = x.view(-1, x.size(-1))
if residual is not None:
torch.ops._C.add_rmsnorm(
out = torch.empty_like(x)
out_residual = torch.empty_like(residual)
torch.ops._C.gemma_add_rmsnorm(
x,
residual,
residual_output=residual,
weight=weight+1,
residual_output=out_residual,
weight=weight,
eps=variance_epsilon,
output=x
output=out,
)
else:
out = torch.empty_like(x)
torch.ops._C.gemma_rmsnorm(
x,
weight,
out,
variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
weight+1,
out,
variance_epsilon,
)
return out
if x.dim() == 3:
x = x.view(x_shape)
if out is not None:
out = out.view(x_shape)
if residual is not None:
return out, out_residual
else:
return out
def forward_cuda(
self,
@@ -99,16 +112,17 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if torch.compiler.is_compiling():
self.forward_static = self.forward_xpu # only use in cudagraph
self.forward_static = self.forward_xpu # only use in cudagraph
return self.forward_native(x, residual)
if not getattr(self, "_is_compiled", False):
self.forward_static = torch.compile( # type: ignore
self.forward_static, backend="aot_eager")
self.forward_static, backend="aot_eager"
)
self._is_compiled = True
return self.forward_native(x, residual)
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm

File diff suppressed because it is too large Load Diff