"""kunlun custom op entry""" import torch_xmlir import torch import os from typing import Optional, List, Dict import vllm.envs as envs import os import ctypes from vllm.logger import init_logger logger = init_logger(__name__) try: import xtorch_ops logger.info(f"Load custom ops library success!") except ImportError as e: logger.warning("Import error msg: %s", e.msg) _per_token_smooth_quant = True def is_per_token_smooth_quant(): """ is per token smooth quant """ return _per_token_smooth_quant class KunlunOps: """KunlunOps""" # Attention ops @staticmethod def paged_attention_v1( output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, context_lens, context_lens_cpu, is_context, block_size, max_context_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step, alibi_sqrt=False ): """ PagedAttentionV1 """ # block_size = value_cache.shape[2] xtorch_ops.paged_attention( x=query, k_cache=key_cache, v_cache=value_cache, block_tables=block_tables, context_lens_cpu=context_lens_cpu, context_lens_xpu=context_lens, is_context=is_context, is_causal=True, out=output, vo_head_dim=128 ) @staticmethod def paged_attention_v2( output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, context_lens, context_lens_cpu, is_context, block_size, max_context_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step, alibi_sqrt=False ): """ PagedAttentionV2 """ # block_size = value_cache.shape[2] xtorch_ops.paged_attention( x=query, k_cache=key_cache, v_cache=value_cache, block_tables=block_tables, context_lens_cpu=context_lens_cpu, context_lens_xpu=context_lens, is_context=is_context, is_causal=True, out=output, vo_head_dim=128 ) # Activation ops @staticmethod def silu_and_mul(out: torch.Tensor, x: torch.Tensor): """ silu and mul """ xtorch_ops.silu_and_mul( x, axis=-1, turn=True, out=out, ) # Activation ops @staticmethod def quick_gelu(out: torch.Tensor, x: torch.Tensor): """ quick gelu """ xtorch_ops.quick_gelu( x, out=out, ) # Layernorm @staticmethod def rms_norm( out, x, weight, epsilon, ): """rms_norm""" xtorch_ops.rmsnorm( x, weight.to(torch.float32), epsilon, out=out ) @staticmethod def fused_add_rms_norm( x, residual, weight, epsilon, ): """fused_add_rms_norm""" output = torch.empty_like(x) xtorch_ops.add_rmsnorm( x, residual, weight.to(torch.float32), epsilon, out=output ) fused_input = x + residual residual.copy_(fused_input, non_blocking=True) x.copy_(output) # Rotary embedding @staticmethod def rotary_embedding( positions, query, key, head_size, cos_sin_cache, is_neox_style): """ refactor RotaryEmbedding forward function """ query_x = query.contiguous() key_x = key.contiguous() query_x_dim = query_x.dim() if not is_neox_style: if cos_sin_cache.dtype == torch.float16: cos_sin_cache = cos_sin_cache.to(torch.float32) positions = positions.to(torch.int) if positions.dim() == 1: positions = positions.unsqueeze(0) query_x = query_x.unsqueeze(0) key_x = key_x.unsqueeze(0) xtorch_ops.rotary_embedding_gptj( positions, query_x, key_x, head_size, cos_sin_cache) query.data = query_x key.data = key_x if query_x_dim != query_x.dim(): query_x = query_x.unsqueeze(0) key_x = key_x.unsqueeze(0) return query, key # TODO: need opt if cos_sin_cache.dim() == 4: max_seq_len = cos_sin_cache.shape[2] head_dim = cos_sin_cache.shape[3] cos_sin_cache 

    # Rotary embedding
    @staticmethod
    def rotary_embedding(
            positions, query, key, head_size, cos_sin_cache, is_neox_style):
        """Refactored RotaryEmbedding forward function."""
        query_x = query.contiguous()
        key_x = key.contiguous()
        query_x_dim = query_x.dim()
        if not is_neox_style:
            if cos_sin_cache.dtype == torch.float16:
                cos_sin_cache = cos_sin_cache.to(torch.float32)
            positions = positions.to(torch.int)
            if positions.dim() == 1:
                positions = positions.unsqueeze(0)
                query_x = query_x.unsqueeze(0)
                key_x = key_x.unsqueeze(0)
            xtorch_ops.rotary_embedding_gptj(
                positions, query_x, key_x, head_size, cos_sin_cache)
            # Restore the original rank before writing the results back.
            if query_x_dim != query_x.dim():
                query_x = query_x.squeeze(0)
                key_x = key_x.squeeze(0)
            query.data = query_x
            key.data = key_x
            return query, key

        # TODO: need opt
        if cos_sin_cache.dim() == 4:
            max_seq_len = cos_sin_cache.shape[2]
            head_dim = cos_sin_cache.shape[3]
            # Drop the two leading singleton dims: [1, 1, L, D] -> [L, D],
            # then reshape to [L, 1, D] as expected by the kernel.
            cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0)
            cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)

        # Expected layouts:
        #   query: [num_tokens, num_heads * head_size]
        #   key:   [num_tokens, num_kv_heads * head_size]
        # (the explicit 3-D reshape of query/key is currently disabled)
        num_tokens = query_x.shape[0]
        num_heads = query_x.shape[1] // head_size
        num_kv_heads = key_x.shape[1] // head_size
        torch.ops._C.rotary_embedding(
            positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style)
        query_x = query_x.view(num_tokens, num_heads * head_size)
        key_x = key_x.view(num_tokens, num_kv_heads * head_size)
        return query_x, key_x

    # Rotary embedding (multimodal / MRoPE)
    @staticmethod
    def mrotary_embedding(
            positions, mrope_section, query, key, head_size, cos_sin_cache,
            is_neox_style):
        """Refactored RotaryEmbedding forward function (MRoPE)."""
        query_x = query.contiguous()
        key_x = key.contiguous()
        query_x_dim = query_x.dim()
        assert is_neox_style
        xtorch_ops.mrotary_embedding_neox(
            positions, query_x, key_x, head_size, cos_sin_cache, mrope_section)
        query.data = query_x
        key.data = key_x
        return query, key

    @staticmethod
    def swap_blocks(src, dst, block_mapping):
        """swap_blocks"""
        xtorch_ops.swap_blocks(
            src,
            dst,
            block_mapping,
        )

    @staticmethod
    def copy_blocks(key_caches, value_caches, block_mapping):
        """copy_blocks"""
        for i in range(len(key_caches)):
            key_caches[i] = key_caches[i].contiguous()
            value_caches[i] = value_caches[i].contiguous()
        xtorch_ops.copy_blocks(
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
    ):
        """reshape_and_cache"""
        # slot_mapping_cast = slot_mapping.to(torch.int32)
        xtorch_ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
        )

    @staticmethod
    def multi_query_kv_attention(
        usual_seq_lod_xpu: torch.Tensor,
        usual_seq_lod_cpu: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        query: shape = [num_prompt_tokens, num_heads, head_size]
        """
        if query.dim() == 3:
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
        output = torch.empty_like(query)
        alibi_slopes = kwargs.get("alibi_slopes", None)
        mask = kwargs.get("mask", None)
        is_causal = kwargs.get("is_causal", True)
        is_lvsl = kwargs.get("is_lvsl", True)
        B, T, Qh, Hd = query.shape
        KVh = key.size(2)
        if KVh != Qh:
            repeat = Qh // KVh
            key = key.repeat_interleave(repeat, dim=2)  # [B, T, Qh, Hd]
            value = value.repeat_interleave(repeat, dim=2)
        xtorch_ops.attention(
            q=query,
            k_cache=key,
            v_cache=value,
            out=output,
            is_causal=True,
            is_prefill=True,
            context_seq_lod_cpu=usual_seq_lod_cpu,
            context_seq_lod_xpu=usual_seq_lod_xpu,
        )
        return output
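
    # Note on the GQA expansion in multi_query_kv_attention above: when the KV
    # tensors have fewer heads than the query, they are materialized to the
    # full query head count before calling the attention kernel. A minimal
    # shape sketch (hypothetical sizes, for illustration only):
    #
    #     q = torch.randn(1, 16, 32, 128)          # [B, T, Qh, Hd]
    #     k = torch.randn(1, 16, 8, 128)           # [B, T, KVh, Hd]
    #     k = k.repeat_interleave(32 // 8, dim=2)  # -> [1, 16, 32, 128]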

    @staticmethod
    def quant_fusedresidual_rmsnorm_op(x, residual, weight, bias, scale_to_int,
                                       eps, dyn_scale: bool, type: int = 1):
        """Quantized fused residual RMSNorm."""
        out = torch.empty_like(x, dtype=torch.int8)
        if is_per_token_smooth_quant():
            out_scale = torch.empty(x.shape[:-1], device=x.device,
                                    dtype=torch.float).unsqueeze(-1)
        else:
            out_scale = torch.empty(12, device=x.device, dtype=torch.float)
        xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
                                               out=out, out_scale=out_scale,
                                               residual_tensor=residual)
        if residual is None:
            return out, out_scale
        return out, out_scale, residual

    @staticmethod
    def quant_rmsnorm_op(x, weight, bias, scale_to_int, eps, dyn_scale: bool,
                         type: int = 1):
        """Quantized RMSNorm."""
        out = torch.empty_like(x, dtype=torch.int8)
        if is_per_token_smooth_quant():
            out_scale = torch.empty(x.shape[:-1], device=x.device,
                                    dtype=torch.float).unsqueeze(-1)
        else:
            out_scale = torch.empty(12, device=x.device, dtype=torch.float)
        xtorch_ops.quant_rmsnorm(x, weight, bias, eps, out=out,
                                 out_scale=out_scale)
        return out, out_scale

    @staticmethod
    def smooth_quant_matmul_column_row_kernels(input_tensor, weight, smoother,
                                               input_scale, weight_scale,
                                               perTokenScaling,
                                               perChannelScaling, otype):
        """smooth_quant_matmul_column_row_kernels"""
        input_shape = input_tensor.shape
        weight_shape = weight.shape
        if input_tensor.dim() == 3:
            input_tensor = input_tensor.reshape(-1, input_shape[-1])
            out = torch.empty(
                (input_shape[0] * input_shape[1], weight_shape[0]),
                dtype=torch.float16, device=weight.device)
            output_bs_shape = [input_shape[0], input_shape[1]]
        elif input_tensor.dim() == 2:
            out = torch.empty((input_shape[0], weight_shape[0]),
                              dtype=torch.float16, device=weight.device)
            output_bs_shape = [-1]
        xtorch_ops.smooth_quant_matmul_column_row_kernels(
            input_tensor, weight, smoother, input_scale, weight_scale,
            perTokenScaling, perChannelScaling, out=out)
        out = out.view(*output_bs_shape, weight_shape[0])
        return out
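
    # Output-scale shapes for the quantized ops above: with per-token smooth
    # quant enabled (the default here), out_scale carries one float per token,
    # shaped x.shape[:-1] + (1,); otherwise a small fixed-size placeholder
    # buffer is allocated, which the kernel is assumed to fill with the
    # per-tensor scale. For example, for x of shape [num_tokens, hidden], the
    # per-token path yields out_scale of shape [num_tokens, 1].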

    @staticmethod
    def fused_moe_ep(
        hidden_states: torch.Tensor,
        w13_weight: torch.Tensor,
        w2_weight: torch.Tensor,
        gating_output: torch.Tensor,
        linear_weights: torch.Tensor,
        ep_rank: int,
        top_k: int,
        renormalize: bool,
        inplace: bool = False,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        w1_bias: Optional[torch.Tensor] = None,
        w2_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Expert-parallel fused MoE over the local experts."""
        x = hidden_states
        batch, hidden_size = x.shape
        num_local_experts, up_gate_size, _ = w13_weight.shape
        router_logits = x.to(linear_weights.dtype) @ linear_weights.T
        topk_weights = torch.empty(batch, top_k, dtype=router_logits.dtype,
                                   device=router_logits.device)
        topk_ids = torch.empty(batch, top_k, dtype=torch.int32,
                               device=router_logits.device)
        block_static = torch.empty(0, dtype=torch.int32,
                                   device=router_logits.device)
        torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids,
                                      block_static)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
        topk_weights = topk_weights.to(x.dtype)
        out = torch.zeros(batch * top_k, hidden_size, dtype=x.dtype,
                          device=x.device)
        repeat_x = x.repeat_interleave(top_k, dim=0)
        topk_ids_flat = topk_ids.flatten()
        for i in range(num_local_experts):
            experts_id = ep_rank * num_local_experts + i
            selected_token = topk_ids_flat == experts_id
            if selected_token.sum():
                cur_token = repeat_x[selected_token]
                up_gate = torch.empty(selected_token.sum(), up_gate_size // 2,
                                      dtype=cur_token.dtype,
                                      device=cur_token.device)
                torch.ops._C.swiglu(cur_token @ w13_weight[i].T, up_gate)
                out[selected_token] = up_gate @ w2_weight[i].T
        output = (out.view(batch, top_k, hidden_size) *
                  topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
        return output

    @staticmethod
    def fused_moe(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        gating_output: torch.Tensor,
        linear_weights: torch.Tensor,
        topk: int,
        renormalize: bool,
        inplace: bool = False,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        w1_bias: Optional[torch.Tensor] = None,
        w2_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """fused_moe"""
        output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
                             device=hidden_states.device)
        expert_num = linear_weights.shape[0]
        torch.ops._C.moe_ffn_block(
            x=hidden_states,
            gate_w=linear_weights,
            inter_w=w1,
            output_w=w2,
            expert_num=expert_num,
            moe_top_k=topk,
            topk_group=topk_group,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            expert_group_num=num_expert_group,
            out=output,
        )
        return output

    @staticmethod
    def fused_multi_head_latent_page_attention(
        hidden_states: torch.Tensor,
        q_lora_rank: int,
        kv_lora_rank: int,
        q_a_proj_w: torch.Tensor,
        q_a_layernorm_w: torch.Tensor,
        q_b_proj_w: torch.Tensor,
        q_proj_w: torch.Tensor,
        kv_a_proj_w: torch.Tensor,
        kv_a_layernorm_w: torch.Tensor,
        kv_b_proj_w: torch.Tensor,
        o_proj_w: torch.Tensor,
        head_num: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        max_context_len: int,
        layernorm_eps: float,
        scale: float,
        is_causal: bool,
        is_context: bool,
        mp_size: int,
        local_rank: int,
        rotary_pos_embedding: torch.Tensor,
        pa_block_tables: torch.Tensor,
        position: torch.Tensor,
        context_lens_cpu: torch.Tensor,
        slot_mapping: torch.Tensor,
        prompt_lods_cpu: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
    ) -> torch.Tensor:
        """MLA (multi-head latent attention) paged-attention block."""
        output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
                             device=hidden_states.device)
        xtorch_ops.xft_multi_head_latent_page_attention_block(
            hidden_states,
            q_lora_rank,
            kv_lora_rank,
            q_a_proj_w,
            q_a_layernorm_w,
            q_b_proj_w,
            q_proj_w,
            kv_a_proj_w,
            kv_a_layernorm_w,
            kv_b_proj_w,
            o_proj_w,
            head_num,
            qk_nope_head_dim,
            qk_rope_head_dim,
            v_head_dim,
            max_context_len,
            layernorm_eps,
            scale,
            is_causal,
            is_context,
            mp_size,
            local_rank,
            rotary_pos_embedding,
            pa_block_tables,
            position,
            None,
            context_lens_cpu,
            slot_mapping,
            None,
            prompt_lods_cpu,
            out=output,
            k_cache=k_cache,
            v_cache=v_cache,
        )
        return output

    @staticmethod
    def fused_gdn_gating(
        A_log: torch.Tensor,
        a: torch.Tensor,
        dt_bias: torch.Tensor,
        beta: float = 1.0,
        threshold: float = 20.0,
    ) -> torch.Tensor:
        """fused_gdn_gating"""
        output = xtorch_ops.fused_gdn_gating(
            A_log,
            a,
            dt_bias,
        )
        return output

    @staticmethod
    def fused_recurrent_gated_delta_rule_fwd(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        h0_source: torch.Tensor,
        output_final_state: bool,
        use_qk_l2norm_in_kernel: bool,
        cu_seqlens: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Core operator of Gated DeltaNet in the Qwen3-Next model; it fuses
        sigmoid gating with the delta-rule update.
        1. Sigmoid gating: gates the input, similar to a GLU
           (Gated Linear Unit).
        2. Delta rule update: performs the recurrent update of a parallel
           state-space model (SSM), combined with a local attention mechanism.
        """
        o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
            q, k, v, g, beta, scale, h0_source, output_final_state,
            use_qk_l2norm_in_kernel, cu_seqlens)
        return (o, final_state)
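

# The fused_recurrent_gated_delta_rule_fwd wrapper above delegates entirely to
# the xtorch_ops kernel. The sketch below is NOT that kernel; it is a minimal,
# single-head eager reference of the standard gated delta rule recurrence the
# kernel is assumed to implement (o_t = (scale * q_t) @ S_t, with the state S_t
# decayed by exp(g_t) and corrected by beta_t toward v_t). Shapes, the gating
# convention, and the omission of q/k L2-normalization are assumptions made
# for illustration only.
def _reference_gated_delta_rule_single_head(
    q: torch.Tensor,     # [T, d_k]
    k: torch.Tensor,     # [T, d_k]
    v: torch.Tensor,     # [T, d_v]
    g: torch.Tensor,     # [T], log decay gates
    beta: torch.Tensor,  # [T], per-step update strengths
    scale: float,
) -> torch.Tensor:
    """Illustrative eager reference; not used by the fused kernels."""
    d_k, d_v = k.shape[-1], v.shape[-1]
    state = torch.zeros(d_k, d_v, dtype=torch.float32)
    outputs = []
    for t in range(q.shape[0]):
        state = state * g[t].exp()              # gated decay of the memory
        v_old = k[t].to(torch.float32) @ state  # current readout for key k_t
        delta = beta[t] * (v[t].to(torch.float32) - v_old)
        state = state + torch.outer(k[t].to(torch.float32), delta)
        outputs.append((scale * q[t].to(torch.float32)) @ state)
    return torch.stack(outputs)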