"""Abstract ("fake"/meta) implementations for the ``torch_mlu_ops`` custom ops.

Each function below is registered via ``torch._custom_ops.impl_abstract`` and is
used only for tracing/compilation: it must produce outputs with the correct
shape, dtype and device without running the real MLU kernel.  Ops that write
into caller-provided output tensors simply return ``None``.
"""

from typing import Tuple, List, Dict, Optional

import torch
from torch import Tensor
import torch._custom_ops


@torch._custom_ops.impl_abstract("torch_mlu_ops::attention_project")
def attention_project_abstract(
    input: Tensor,
    q_weight: Tensor,
    q_bias: Optional[Tensor],
    k_weight: Optional[Tensor],
    k_bias: Optional[Tensor],
    v_weight: Optional[Tensor],
    v_bias: Optional[Tensor],
    norm_weight: Optional[Tensor],
    norm_bias: Optional[Tensor],
    residual: Optional[Tensor],
    out_layout: str,
    head_size: int,
    eps: float,
    alpha: float,
    beta: float,
    norm_out: bool,
) -> List[Tensor]:
    """Shape inference for the fused (norm +) Q/K/V projection.

    Returns ``[out_q]`` plus ``out_k`` / ``out_v`` when the corresponding
    weight is given, plus the pre-projection normalized input when
    ``norm_out`` is set.  Outputs are ``(n, t, hidden)`` by default or
    ``(n, head_num, t, head_size)`` when ``out_layout == "nhtc"``.
    """
    # A 2-D input is treated as a single batch of shape (1, t, c).
    input_view = input if input.dim() != 2 else input.unsqueeze(0)
    n = input_view.size(0)
    t = input_view.size(1)
    # Projection widths; 0 means "no k/v projection requested".
    hidden_size_q = q_weight.size(0)
    hidden_size_k = k_weight.size(0) if k_weight is not None else 0
    hidden_size_v = v_weight.size(0) if v_weight is not None else 0

    # Allocate each output once, in the layout actually requested
    # (the original code allocated "ntc" buffers and then re-allocated
    # them for "nhtc", wasting the first allocation).
    if out_layout == "nhtc":
        head_num_q = hidden_size_q // head_size
        head_num_k = hidden_size_k // head_size
        head_num_v = hidden_size_v // head_size
        out_q = torch.empty(n, head_num_q, t, head_size,
                            dtype=input_view.dtype, device=input_view.device)
        out_k = torch.empty(n, head_num_k, t, head_size,
                            dtype=input_view.dtype, device=input_view.device) \
            if hidden_size_k > 0 else None
        out_v = torch.empty(n, head_num_v, t, head_size,
                            dtype=input_view.dtype, device=input_view.device) \
            if hidden_size_v > 0 else None
    else:
        out_q = torch.empty(n, t, hidden_size_q,
                            dtype=input_view.dtype, device=input_view.device)
        out_k = torch.empty(n, t, hidden_size_k,
                            dtype=input_view.dtype, device=input_view.device) \
            if hidden_size_k > 0 else None
        out_v = torch.empty(n, t, hidden_size_v,
                            dtype=input_view.dtype, device=input_view.device) \
            if hidden_size_v > 0 else None

    out_ln = torch.empty_like(input_view) if norm_out else None

    res = [out_q]
    if k_weight is not None:
        res.append(out_k)
    if v_weight is not None:
        res.append(out_v)
    if norm_out:
        res.append(out_ln)
    return res


@torch._custom_ops.impl_abstract("torch_mlu_ops::ffn")
def ffn_abstract(
    input: Tensor,
    up_fc_weight: Tensor,
    up_fc_bias: Optional[Tensor],
    down_proj_weight: Tensor,
    down_proj_bias: Optional[Tensor],
    gate_up_proj_weight: Optional[Tensor],
    gate_up_proj_bias: Optional[Tensor],
    layernorm_weight: Optional[Tensor],
    layernorm_bias: Optional[Tensor],
    act_mode: str,
    residual_is: str,
    eps: float,
    alpha: float,
    beta: float,
) -> Tensor:
    """Fused feed-forward network: output has the same shape/dtype as input."""
    return torch.empty_like(input)


@torch._custom_ops.impl_abstract("torch_mlu_ops::flash_attention")
def flash_attention_abstract(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    out: Tensor,
    output_lse: Optional[Tensor],
    cu_seq_lens_q: Optional[Tensor],
    cu_seq_lens_kv: Optional[Tensor],
    alibi_slope: Optional[Tensor],
    attn_bias: Optional[Tensor],
    k_quant_scale: Optional[Tensor],
    v_quant_scale: Optional[Tensor],
    block_tables: Optional[Tensor],
    max_seq_len_q: int,
    max_seq_len_kv: int,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
    compute_dtype: str,
    return_lse: bool,
) -> None:
    """Flash attention writes into the caller-provided ``out`` tensor."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::single_query_cached_kv_attn")
def single_query_cached_kv_attn_abstract(
    q_ori: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    output: Tensor,
    block_tables: Tensor,
    context_lens: Tensor,
    output_lse: Optional[Tensor],
    k_cache_quant_scale: Optional[Tensor],
    v_cache_quant_scale: Optional[Tensor],
    alibi_slopes: Optional[Tensor],
    max_contxt_len: int,
    windows_size_left: int,
    windows_size_right: int,
    softmax_scale: float,
    return_lse: bool,
    kv_cache_quant_bit_size: int
) -> None:
    """Decode-time paged attention; results go into ``output`` in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::apply_rotary")
def apply_rotary_abstract(
    input: Tensor,
    sin_cache: Tensor,
    cos_cache: Tensor,
    position_ids: Optional[Tensor],
    cu_seqlens: Optional[Tensor],
    interleaved: bool,
    discrete: bool,
    dynamic_ntk: bool,
    max_seqlen: int,
) -> None:
    """Rotary position embedding applied to ``input`` in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::reshape_linear_cache")
def reshape_linear_cache_abstract(
    key: Tensor,
    value: Optional[Tensor],
    key_cache: Tensor,
    value_cache: Optional[Tensor],
    context_lengths: Tensor,
    max_context_len: int,
    packed: bool,
    context_seq_offset: Optional[Tensor],
    cache_bs_id: Optional[Tensor],
    cache_seqlen_offset: Optional[Tensor],
) -> None:
    """Scatter key/value into the linear KV cache in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::reshape_paged_cache")
def reshape_paged_cache_abstract(
    k: Tensor,
    v: Optional[Tensor],
    k_cache: Tensor,
    v_cache: Optional[Tensor],
    slot_mapping: Tensor
) -> None:
    """Scatter key/value into the paged KV cache in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_to_paged_cache")
def quant_to_paged_cache_abstract(
    k: Tensor,
    v: Optional[Tensor],
    k_cache: Tensor,
    v_cache: Optional[Tensor],
    k_cache_scale: Tensor,
    v_cache_scale: Optional[Tensor],
    slot_mapping: Tensor,
) -> None:
    """Quantize and scatter key/value into the paged KV cache in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::offline_quant_to_paged_cache")
def offline_quant_to_paged_cache_abstract(
    k: Tensor,
    v: Optional[Tensor],
    k_cache_scale: Tensor,
    v_cache_scale: Optional[Tensor],
    slot_mapping: Tensor,
    k_cache: Tensor,
    v_cache: Optional[Tensor],
) -> None:
    """Offline-quantized scatter into the paged KV cache in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_to_linear_cache")
def quant_to_linear_cache_abstract(
    key: Tensor,
    value: Optional[Tensor],
    key_cache: Tensor,
    value_cache: Optional[Tensor],
    key_cache_scale: Tensor,
    value_cache_scale: Optional[Tensor],
    context_lengths: Tensor,
    max_context_len: int,
    packed: bool,
    context_seq_offset: Optional[Tensor],
    cache_bs_id: Optional[Tensor],
    cache_seqlen_offset: Optional[Tensor],
    quant_bit: int = 8,
) -> None:
    """Quantize and scatter key/value into the linear KV cache in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::offline_quant_to_linear_cache")
def offline_quant_to_linear_cache_abstract(
    key: Tensor,
    value: Optional[Tensor],
    key_cache: Tensor,
    value_cache: Optional[Tensor],
    key_cache_scale: Tensor,
    value_cache_scale: Optional[Tensor],
    context_lengths: Tensor,
    max_context_len: int,
    quant_mode: int,
    packed: bool,
    context_seq_offset: Optional[Tensor],
    cache_bs_id: Optional[Tensor],
    cache_seqlen_offset: Optional[Tensor],
) -> None:
    """Offline-quantized scatter into the linear KV cache in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::swap_blocks")
def swap_blocks_abstract(
    dst: Tensor,
    src: Tensor,
    block_mapping: Dict[int, int]
) -> None:
    """Copy cache blocks from ``src`` to ``dst`` according to the mapping."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::copy_blocks")
def copy_blocks_abstract(
    k_caches: List[Tensor],
    v_caches: List[Tensor],
    block_mapping: Dict[int, List[int]]
) -> None:
    """Duplicate cache blocks in place according to the mapping."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::copy_blocks_out_of_place")
def copy_blocks_out_of_place_abstract(
    k_caches: List[Tensor],
    v_caches: List[Tensor],
    block_mapping: Dict[int, List[int]]
) -> Tuple[List[Tensor], List[Tensor]]:
    """Out-of-place block copy: fresh caches with unchanged shapes/dtypes."""
    # NOTE: the original annotated the return as a tuple literal
    # ``(List[Tensor], List[Tensor])``, which is not a valid typing form.
    new_k = [torch.empty_like(k) for k in k_caches]
    new_v = [torch.empty_like(v) for v in v_caches]
    return (new_k, new_v)


@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_matmul")
def quant_matmul_abstract(
    a_tensor: Tensor,
    a_scale: Optional[Tensor],
    a_zero: Optional[Tensor],
    b_tensor: Tensor,
    b_scale: Optional[Tensor],
    b_zero: Optional[Tensor],
    bias: Optional[Tensor],
    c_tensor: Optional[Tensor],
    c_scale: Optional[Tensor],
    c_zero: Optional[Tensor],
    gemm_output_scale: Optional[Tensor],
    gemm_output_zero: Optional[Tensor],
    data_type: Optional[str],
    d: Optional[Tensor],
    quant_algo: str,
    a_quant_layout: str,
    b_quant_layout: str,
    quant_bit_size: int = 8,
    act_mode: str = "none",
    use_hp_active: bool = False,
    act_coef: float = 1.0,
    alpha: float = 1.0,
    beta: float = 1.0,
    trans_a: bool = False,
    trans_b: bool = True,
) -> Tensor:
    """Quantized matmul: output is ``(a.size(0), b.size(0))``.

    ``data_type`` selects the output dtype; ``None`` keeps ``a_tensor``'s,
    unrecognized strings fall back to float16.
    """
    if data_type is None:
        output_type = a_tensor.dtype
    elif data_type == "float":
        output_type = torch.float32
    elif data_type == "bfloat16":
        output_type = torch.bfloat16
    else:
        output_type = torch.float16
    return torch.empty(a_tensor.size(0), b_tensor.size(0),
                       dtype=output_type, device=a_tensor.device)


@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_matmul_allreduce")
def quant_matmul_allreduce_abstract(
    cncl_comm,
    a_tensor: torch.Tensor,
    a_scale: Optional[torch.Tensor],
    a_zero: Optional[torch.Tensor],
    b_tensor: torch.Tensor,
    b_scale: Optional[torch.Tensor],
    b_zero: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    c_tensor: Optional[torch.Tensor],
    c_scale: Optional[torch.Tensor],
    c_zero: Optional[torch.Tensor],
    gemm_output_scale: Optional[torch.Tensor],
    gemm_output_zero: Optional[torch.Tensor],
    data_type: Optional[str],
    d: Optional[torch.Tensor],
    quant_algo: str,
    a_quant_layout: str,
    b_quant_layout: str,
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0,
    trans_a: bool = False,
    trans_b: bool = True,
    block_m: int = 0
) -> torch.Tensor:
    """Quantized matmul fused with all-reduce; output ``(a.size(0), b.size(0))``.

    Unlike ``quant_matmul``, ``data_type=None`` defaults to float16 here
    (the quantized input dtype is not a usable output dtype).
    """
    output_type = torch.float16
    if data_type == "float":
        output_type = torch.float32
    elif data_type == "bfloat16":
        output_type = torch.bfloat16
    return torch.empty(a_tensor.size(0), b_tensor.size(0),
                       dtype=output_type, device=a_tensor.device)


@torch._custom_ops.impl_abstract("torch_mlu_ops::active")
def active_abstract(
    input: Tensor,
    output: Tensor,
    bias: Optional[Tensor],
    cusum_token_count: Optional[Tensor],
    act_mode: str,
    is_gated: bool,
    start_expert_id: int = 0,
    expert_size: int = 0,
    active_coef: float = 1.0
) -> None:
    """Activation kernel writing into the caller-provided ``output``."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::smooth_quant")
def smooth_quant_abstract(
    input: Tensor,
    input_scale: Tensor,
    output: Tensor,
    output_scale: Tensor,
    input_zero: Optional[Tensor],
    token_count: Optional[Tensor],
    gather_index: Optional[Tensor],
    gather_index_start_position: Optional[Tensor],
    quant_mode: str,
    dynamic_quant: bool
) -> None:
    """SmoothQuant kernel writing into ``output`` / ``output_scale`` in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_layernorm")
def fused_layernorm_abstract(
    input: Tensor,
    output: Tensor,
    residual: Optional[Tensor],
    beta: Optional[Tensor],
    gamma: Optional[Tensor],
    bias: Optional[Tensor],
    quant_scale: Optional[Tensor],
    residual_out: Optional[Tensor],
    smooth_quant_scale: Optional[Tensor],
    norm_mode: str,
    eps: float,
    store_output_before_norm: bool,
    dynamic_quant: bool,
) -> None:
    """Fused (residual-add +) layernorm writing into ``output`` in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_moe")
def fused_moe_abstract(
    hidden_states: Tensor,
    gating_output: Tensor,
    w1: Tensor,
    w2: Tensor,
    bias1: Optional[Tensor],
    bias2: Optional[Tensor],
    residual: Optional[Tensor],
    input_smooth: Optional[Tensor],
    act_smooth: Optional[Tensor],
    w1_scale: Optional[Tensor],
    w2_scale: Optional[Tensor],
    topk: int,
    renormalize: bool,
    gated: bool,
    act_mode: str,
    start_expert_id: int,
    block_n: int,
    cncl_comm: int,
    w1_quant_flag: Optional[List],
    w2_quant_flag: Optional[List]
) -> Tensor:
    """Fused MoE layer: output has the same shape/dtype as ``hidden_states``."""
    return torch.empty_like(hidden_states)


@torch._custom_ops.impl_abstract("torch_mlu_ops::matmul")
def matmul_abstract(
    a: Tensor,
    b: Tensor,
    d: Optional[Tensor],
    bias: Optional[Tensor],
    c: Optional[Tensor],
    data_type: Optional[str],
    act_mode: str,
    alpha: float,
    beta: float,
    fast_act: bool,
    approximate: bool,
    a_scale: float,
    b_scale: float,
    trans_a: bool,
    trans_b: bool
) -> Tensor:
    """GEMM with optional transposes: output is ``(m, n)``.

    ``m``/``n`` come from the non-contracted dimensions of ``a``/``b``
    according to ``trans_a``/``trans_b``.
    """
    m = a.size(1) if trans_a else a.size(0)
    n = b.size(0) if trans_b else b.size(1)
    if data_type is None:
        output_type = a.dtype
    elif data_type == "float":
        output_type = torch.float32
    elif data_type == "bfloat16":
        output_type = torch.bfloat16
    else:
        output_type = torch.half
    return torch.empty(m, n, dtype=output_type, device=a.device)


@torch._custom_ops.impl_abstract("torch_mlu_ops::batch_matmul")
def batch_matmul_abstract(
    a: Tensor,
    b: Tensor,
    c: Tensor,
    alpha: float,
    beta: float,
    a_scale: float,
    b_scale: float,
    trans_a: bool,
    trans_b: bool
) -> None:
    """Batched GEMM accumulating into the caller-provided ``c``."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::matmul_allreduce")
def matmul_allreduce_abstract(
    cncl_comm,
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    d: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = 0.0,
    block_m: int = 0
) -> Tensor:
    """GEMM fused with all-reduce; output ``(a.size(0), b.size(0))``."""
    return torch.empty(a.size(0), b.size(0), dtype=a.dtype, device=a.device)


@torch._custom_ops.impl_abstract("torch_mlu_ops::group_gemm")
def group_gemm_abstract(
    a: Tensor,
    b: Tensor,
    m_list: Tensor,
    expand_idx: Optional[Tensor],
    c: Optional[Tensor],
    alpha: Optional[Tensor],
    beta: Optional[Tensor],
    a_scale: Optional[Tensor],
    b_scale: Optional[Tensor],
    bias: Optional[Tensor],
    data_type: Optional[str],
    quant_flag: Optional[List],
    b_offset: Optional[Tensor],
    max_m: int
) -> Tensor:
    """Grouped GEMM: output is ``(total_m, b.size(1))``.

    ``total_m`` is ``a.size(0)``, or ``expand_idx.size(0)`` when the rows
    of ``a`` are expanded through a gather index.
    """
    if data_type is None:
        output_type = a.dtype
    elif data_type == "float":
        output_type = torch.float32
    elif data_type == "bfloat16":
        output_type = torch.bfloat16
    else:
        output_type = torch.half
    total_m = a.size(0) if expand_idx is None else expand_idx.size(0)
    return torch.empty(total_m, b.size(1), dtype=output_type, device=a.device)


@torch._custom_ops.impl_abstract("torch_mlu_ops::preload")
def preload_abstract(
    weight: Tensor,
    size: int
) -> None:
    """Weight-preloading hint; no traced output."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::flash_attn_sq_mm_allreduce")
def flash_attn_sq_mm_allreduce_abstract(
    cncl_comm, q, k, v, cu_seq_lens_q, cu_seq_lens_k, alibi_slope, attn_bias,
    smooth, weight, weight_scale, bias, max_seq_len_q, max_seq_len_kv,
    softmax_scale, is_causal, window_size_left, window_size_right,
    compute_dtype, block_seq
) -> torch.Tensor:
    """Flash attention + smooth-quant matmul + all-reduce.

    The output row count is the flattened token count of ``q`` (a leading
    batch dim of 1 is assumed for varlen inputs, i.e. when ``cu_seq_lens_q``
    is provided); columns come from ``weight.size(0)``.
    """
    res_q = q.unsqueeze(0) if cu_seq_lens_q is not None else q
    res_q = res_q.flatten(-2, -1).flatten(0, 1)
    return torch.empty(res_q.size(0), weight.size(0), dtype=q.dtype, device=q.device)


@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_softmax_topk")
def moe_softmax_topk_abstract(
    input, topk, num_expert_group, topk_group, normalize,
    mask: Optional[torch.Tensor] = None,
    normed_by: str = "topk_logit"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """MoE routing: top-k weights (float32) and expert ids (int32).

    Both outputs replace the last (expert) dimension of ``input`` with ``topk``.
    """
    out_shape = list(input.size())[:-1] + [topk]
    reduce_weight = torch.empty(out_shape, dtype=torch.float32, device=input.device)
    expert_id = torch.empty(out_shape, dtype=torch.int, device=input.device)
    return (reduce_weight, expert_id)


@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_expand_input")
def moe_expand_input_abstract(
    input: Tensor,
    gather_idx: Tensor,
    cusum_token_count: Optional[Tensor] = None,
    start_expert_id: int = 0,
    expert_size: int = 0
) -> Tensor:
    """Gather-expand tokens for MoE: one output row per gather index."""
    return torch.empty(gather_idx.size(0), input.size(-1),
                       dtype=input.dtype, device=input.device)


@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_gen_idx")
def moe_gen_idx_abstract(
    expert_id: Tensor,
    expert_num: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """MoE index generation.

    Returns ``(expand_idx, combine_idx, token_count, cusum_token_count)``,
    all int32: the first two sized ``token_num * topk``, the counts sized
    ``expert_num`` and ``expert_num + 1`` respectively.
    """
    token_num, topk = expert_id.size(0), expert_id.size(1)
    flat = token_num * topk
    dev = expert_id.device
    expand_idx = torch.empty((flat,), dtype=torch.int32, device=dev)
    combine_idx = torch.empty((flat,), dtype=torch.int32, device=dev)
    token_count = torch.empty((expert_num,), dtype=torch.int32, device=dev)
    cusum_token_count = torch.empty((expert_num + 1,), dtype=torch.int32, device=dev)
    return (expand_idx, combine_idx, token_count, cusum_token_count)


@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_combine_result")
def moe_combine_result_abstract(
    input: torch.Tensor,
    reduce_weight: torch.Tensor,
    gather_ids: torch.Tensor,
    residual: Optional[torch.Tensor],
    cusum_token_count: Optional[torch.Tensor],
    start_expert_id: int,
    expert_size: int,
    bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Combine expert outputs back to per-token rows.

    ``input`` holds ``num_token * topk`` expanded rows; the result collapses
    them to ``num_token`` rows of the same hidden size.
    """
    num_tokens, hidden_size, topk = input.size(0), input.size(1), reduce_weight.size(1)
    num_token = num_tokens // topk
    return torch.empty(num_token, hidden_size, dtype=input.dtype, device=input.device)


@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_rope")
def fused_rope_abstract(
    qkv: torch.Tensor,
    key_cache_hp: torch.Tensor,
    value_cache_hp: torch.Tensor,
    key_cache_lp: Optional[torch.Tensor],
    value_cache_lp: Optional[torch.Tensor],
    sin_table: torch.Tensor,
    cos_table: torch.Tensor,
    position_ids: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    key_scale_hp: Optional[torch.Tensor],
    value_scale_hp: Optional[torch.Tensor],
    key_scale_lp: Optional[torch.Tensor],
    value_scale_lp: Optional[torch.Tensor],
    cache_bs_id_hp: Optional[torch.Tensor],
    cache_seq_offsets_hp: Optional[torch.Tensor],
    cache_bs_id_lp: Optional[torch.Tensor],
    cache_seq_offsets_lp: Optional[torch.Tensor],
    slot_mapping_hp: Optional[torch.Tensor],
    slot_mapping_lp: Optional[torch.Tensor],
    eps: float
) -> None:
    """Fused RoPE + cache update; all writes are in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_cast_gating")
def moe_cast_gating_abstract(
    input: torch.Tensor,
    weight: torch.Tensor
) -> Tensor:
    """Gating projection: last dim becomes ``weight.size(0)``, dtype float32."""
    output_shape = input.shape[:-1] + (weight.shape[0],)
    # Derive the device from the input (the original hardcoded device="mlu",
    # which is wrong for non-default devices during fake-tensor tracing).
    return torch.empty(output_shape, dtype=torch.float, device=input.device)


@torch._custom_ops.impl_abstract("torch_mlu_ops::update_out_and_lse")
def update_out_and_lse_abstract(
    out: torch.Tensor,
    lse: torch.Tensor,
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    seq_offsets: Optional[torch.Tensor] = None,
    cu_seqs: Optional[torch.Tensor] = None,
    block_cu_seqs: Optional[torch.Tensor] = None
) -> None:
    """Online-softmax merge of partial attention outputs, in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::dequant_from_linear_cache")
def dequant_from_linear_cache_abstract(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    key_cache_quant_scale: torch.Tensor,
    value_cache_quant_scale: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    context_seq_offset: Optional[torch.Tensor],
    cache_bs_id: Optional[torch.Tensor],
    cache_seq_offset: Optional[torch.Tensor],
    quant_mode: int = 0,
    quant_bit: int = 8
) -> None:
    """Dequantize linear-cache entries into ``key`` / ``value`` in place."""
    return None


@torch._custom_ops.impl_abstract("torch_mlu_ops::dequant_from_paged_cache")
def dequant_from_paged_cache_abstract(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    key_cache_quant_scale: torch.Tensor,
    value_cache_quant_scale: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    context_seq_offset: Optional[torch.Tensor],
    block_tables: torch.Tensor,
    quant_mode: int = 0,
    quant_bit: int = 8
) -> None:
    """Dequantize paged-cache entries into ``key`` / ``value`` in place."""
    return None