from typing import List, Optional, Tuple

import torch
import math
import triton
import triton.language as tl

from vllm.logger import init_logger

logger = init_logger(__name__)

try:
    import torch_mlu_ops as tmo
except ImportError as e:
    logger.warning("Failed to import torch_mlu_ops with %r", e)


def rotary_embedding(
    input: torch.Tensor,
    sin_cache: torch.Tensor,
    cos_cache: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    cu_seqlens: Optional[torch.Tensor],
    interleaved: bool,
    discrete: bool,
    dynamic_ntk: bool,
    max_seqlen: int,
) -> torch.Tensor:
    return tmo.apply_rotary(input, sin_cache, cos_cache, position_ids,
                            cu_seqlens, interleaved, discrete, dynamic_ntk,
                            max_seqlen)


def fused_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    store_output_before_norm: bool,
    quant_scale: torch.Tensor = None,
    dynamic_quant: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    return tmo.fused_rms_norm(x, residual, gamma, beta, bias, eps,
                              store_output_before_norm, quant_scale, None,
                              dynamic_quant)


def fused_layer_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    store_output_before_norm: bool,
    quant_scale: torch.Tensor = None,
    dynamic_quant: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    return tmo.fused_layer_norm(x, residual, gamma, beta, bias, eps,
                                store_output_before_norm, quant_scale, None,
                                dynamic_quant)


def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    cu_seq_lens_q: torch.Tensor,
    cu_seq_lens_kv: torch.Tensor,
    alibi_slope: torch.Tensor,
    attn_bias: torch.Tensor,
    max_seq_len_q: int,
    max_seq_len_kv: int,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    compute_dtype: torch.dtype = torch.float,
    return_lse: bool = False,
    block_tables: torch.Tensor = None,
    k_cache_quant_scale: torch.Tensor = None,
    v_cache_quant_scale: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    return tmo.flash_attention(q, k, v, out, cu_seq_lens_q, cu_seq_lens_kv,
                               alibi_slope, attn_bias, max_seq_len_q,
                               max_seq_len_kv, softmax_scale, is_causal,
                               window_size_left, window_size_right,
                               compute_dtype, return_lse, block_tables,
                               k_cache_quant_scale, v_cache_quant_scale)


def split_head_nums(q_head_num, kv_head_num, max_q_head_num):
    """
    Split q_head_num into parts such that:
    1. No part of q_head_num exceeds max_q_head_num.
    2. kv_head_num is split into the same number of parts as q_head_num.
    3. Each q_head_num part is divisible by its corresponding kv_head_num part.
    4. Any kv_head_num part smaller than 1 is clamped to 1.

    Args:
    - q_head_num: int, the q_head_num to split.
    - kv_head_num: int, the kv_head_num to split.
    - max_q_head_num: int, the maximum allowed q_head_num per part.

    Returns:
    - q_splits: list, the q_head_num parts.
    - kv_splits: list, the kv_head_num parts.
    """
    if q_head_num <= 0 or kv_head_num <= 0:
        raise ValueError(
            "q_head_num and kv_head_num must be positive integers!")
    q_splits = []
    kv_splits = []
    # Remaining head counts still to be assigned.
    remaining_q = q_head_num
    remaining_kv = kv_head_num
    while remaining_q > 0:
        # Try splitting q_head_num, capping each part at max_q_head_num.
        for q_part in range(min(max_q_head_num, remaining_q), 0, -1):
            # Make sure q_part divides the remainder and its kv_part >= 1.
            if remaining_q % q_part == 0:
                # Clamp kv_part to at least 1.
                kv_part = max(remaining_kv // (remaining_q // q_part), 1)
                # Make sure q_part is divisible by kv_part.
                if q_part % kv_part == 0:
                    # Record this split.
                    q_splits.append(q_part)
                    kv_splits.append(kv_part)
                    remaining_q -= q_part
                    remaining_kv -= kv_part
                    break
        else:
            err_msg = (f"Unable to find split method for "
                       f"q_head_num:{q_head_num}, kv_head_num:{kv_head_num}")
            raise RuntimeError(err_msg)
    return q_splits, kv_splits
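
# A worked example of the splitting rule above (the numbers are illustrative,
# not tied to any particular model):
#
#     >>> split_head_nums(16, 2, 8)
#     ([8, 8], [1, 1])    # 16 query heads capped at 8 per part; the 2 KV
#                         # heads follow with one head per part.
#     >>> split_head_nums(32, 8, 24)
#     ([16, 16], [4, 4])  # each part keeps the 4:1 q/kv grouping.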

def repeat_elements(input_list, n):
    """
    Repeat each element of a list n consecutive times.

    Args:
    - input_list: list, the input list.
    - n: int, the number of times to repeat each element.

    Returns:
    - list, a new list with the repeated elements.
    """
    if not isinstance(input_list, list) or not isinstance(n, int) or n < 0:
        raise ValueError("The input must be a list, and the repeat count n "
                         "must be an integer greater than or equal to 0.")
    # Use a list comprehension to repeat each element n times.
    return [item for item in input_list for _ in range(n)]
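
# For example (illustrative values), each element is repeated consecutively,
# which is how the split KV-cache parts below are matched up with a larger
# number of query parts:
#
#     >>> repeat_elements([1, 2], 3)
#     [1, 1, 1, 2, 2, 2]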

def single_query_cached_kv_attn(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    k_cache_quant_scale: Optional[torch.Tensor],
    v_cache_quant_scale: Optional[torch.Tensor],
    alibi_slopes: Optional[torch.Tensor],
    max_contxt_len: int,
    windows_size_left: int,
    windows_size_right: int,
    softmax_scale: float,
    q_head_dim: Optional[int] = 2,
    kv_head_dim: Optional[int] = 1,
    seq_q_dim: Optional[int] = 1,
    max_seq_q_mul_q_divide_kv: Optional[int] = 48,
) -> None:
    # FIXME(chenxiaobing): TMO only supports windows_size_right = -1 yet.
    windows_size_right = -1
    # singleQwithkvCache currently requires
    # seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv.
    # Once that limitation is lifted, the split logic below should be removed.
    seq_q = q.shape[seq_q_dim]
    q_head_num = q.shape[q_head_dim]
    kv_head_num = k_cache.shape[kv_head_dim]
    q_divide_kv = q_head_num // kv_head_num
    if seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv:
        tmo.single_query_cached_kv_attn(
            q, k_cache, v_cache, out, block_tables, context_lens,
            k_cache_quant_scale, v_cache_quant_scale, alibi_slopes,
            max_contxt_len, windows_size_left, windows_size_right,
            softmax_scale)
    else:
        max_q_head_num = max_seq_q_mul_q_divide_kv * kv_head_num // seq_q
        q_head_num_sizes, kv_head_num_sizes = split_head_nums(
            q_head_num, kv_head_num, max_q_head_num)
        parts_num = len(q_head_num_sizes)
        q_parts = torch.split(q, q_head_num_sizes, dim=q_head_dim)
        out_parts = torch.split(out, q_head_num_sizes, dim=q_head_dim)
        alibi_slopes_parts = [None] * parts_num
        if alibi_slopes is not None:
            alibi_slopes_parts = torch.split(alibi_slopes, q_head_num_sizes,
                                             dim=0)
        kv_parts_num = parts_num
        if parts_num > kv_head_num:
            assert parts_num % kv_head_num == 0, (
                f"parts_num:{parts_num} must be divisible by "
                f"kv_head_num:{kv_head_num} when parts_num > kv_head_num")
            kv_parts_num = kv_head_num
            kv_head_num_sizes = kv_head_num_sizes[:kv_parts_num]
        if len(kv_head_num_sizes) > 1:
            k_cache_parts = torch.split(k_cache, kv_head_num_sizes,
                                        dim=kv_head_dim)
            v_cache_parts = torch.split(v_cache, kv_head_num_sizes,
                                        dim=kv_head_dim)
            k_cache_quant_scale_parts = [None] * kv_parts_num
            v_cache_quant_scale_parts = [None] * kv_parts_num
            if k_cache_quant_scale is not None:
                k_cache_quant_scale_dim = (1 if k_cache_quant_scale.dim() == 2
                                           else kv_head_dim)
                k_cache_quant_scale_parts = torch.split(
                    k_cache_quant_scale, kv_head_num_sizes,
                    dim=k_cache_quant_scale_dim)
            if v_cache_quant_scale is not None:
                v_cache_quant_scale_dim = (1 if v_cache_quant_scale.dim() == 2
                                           else kv_head_dim)
                v_cache_quant_scale_parts = torch.split(
                    v_cache_quant_scale, kv_head_num_sizes,
                    dim=v_cache_quant_scale_dim)
        else:
            k_cache_parts = [k_cache]
            v_cache_parts = [v_cache]
            k_cache_quant_scale_parts = [k_cache_quant_scale]
            v_cache_quant_scale_parts = [v_cache_quant_scale]
        if parts_num > kv_parts_num:
            repeat_num = parts_num // kv_parts_num
            k_cache_parts = repeat_elements(k_cache_parts, repeat_num)
            v_cache_parts = repeat_elements(v_cache_parts, repeat_num)
            k_cache_quant_scale_parts = repeat_elements(
                k_cache_quant_scale_parts, repeat_num)
            v_cache_quant_scale_parts = repeat_elements(
                v_cache_quant_scale_parts, repeat_num)
        for (q_value, k_cache_value, v_cache_value, out_value,
             k_cache_quant_scale_value, v_cache_quant_scale_value,
             alibi_slopes_value) in zip(q_parts, k_cache_parts, v_cache_parts,
                                        out_parts, k_cache_quant_scale_parts,
                                        v_cache_quant_scale_parts,
                                        alibi_slopes_parts):
            tmo.single_query_cached_kv_attn(
                q_value, k_cache_value.contiguous(),
                v_cache_value.contiguous(), out_value, block_tables,
                context_lens, k_cache_quant_scale_value,
                v_cache_quant_scale_value, alibi_slopes_value, max_contxt_len,
                windows_size_left, windows_size_right, softmax_scale)
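
# A worked example of the split path above (illustrative numbers): with
# seq_q=16, q_head_num=32, kv_head_num=8, the product seq_q * q_divide_kv is
# 16 * 4 = 64 > 48, so max_q_head_num = 48 * 8 // 16 = 24 and
# split_head_nums(32, 8, 24) yields ([16, 16], [4, 4]); the TMO kernel is then
# invoked once per (q, kv) part pair.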

def reshape_linear_cache(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    packed: bool,
    context_seq_offset: Optional[torch.Tensor],
    cache_bs_id: Optional[torch.Tensor],
    cache_seqlen_offset: Optional[torch.Tensor],
) -> None:
    tmo.reshape_linear_cache(key, value, key_cache, value_cache,
                             context_lengths, max_context_len, packed,
                             context_seq_offset, cache_bs_id,
                             cache_seqlen_offset)


def reshape_paged_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor
) -> None:
    tmo.reshape_paged_cache(k, v, k_cache, v_cache, slot_mapping)


def swap_blocks(
    dst: torch.Tensor,
    src: torch.Tensor,
    block_mapping: torch.Tensor
) -> None:
    # FIXME: Remove this conversion after
    # tmo.swap_blocks supports a block_mapping tensor.
    block_mapping = block_mapping.tolist()
    block_mapping = {
        src_block: dst_block
        for src_block, dst_block in block_mapping
    }
    return tmo.swap_blocks(dst, src, block_mapping)


def copy_blocks(
    k_caches: List[torch.Tensor],
    v_caches: List[torch.Tensor],
    block_mapping: torch.Tensor
) -> None:
    # FIXME: Remove this conversion after
    # tmo.copy_blocks supports a block_mapping tensor.
    block_mapping = block_mapping.tolist()
    result_dict = {}
    for row in block_mapping:
        key = row[0]
        values = row[1:]
        if key in result_dict:
            result_dict[key].extend(values)
        else:
            result_dict[key] = values
    return tmo.copy_blocks(k_caches, v_caches, result_dict)
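
# The conversion in copy_blocks above groups (src, dst) rows by source block;
# for example (illustrative block ids):
#
#     tensor([[0, 1],
#             [0, 2],
#             [3, 4]])  ->  {0: [1, 2], 3: [4]}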

def ffn(
    input: torch.Tensor,
    up_fc_weight: torch.Tensor,
    up_fc_bias: Optional[torch.Tensor],
    down_proj_weight: torch.Tensor,
    down_proj_bias: Optional[torch.Tensor],
    gate_up_proj_weight: Optional[torch.Tensor] = None,
    gate_up_proj_bias: Optional[torch.Tensor] = None,
    act_mode: str = "none"
) -> torch.Tensor:
    return tmo.ffn(input, up_fc_weight, up_fc_bias, down_proj_weight,
                   down_proj_bias, gate_up_proj_weight, gate_up_proj_bias,
                   act_mode)


def active(
    input: torch.Tensor,
    act_mode: str,
    is_gated: bool
) -> torch.Tensor:
    return tmo.active(input, act_mode, is_gated)


def fused_moe(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: Optional[torch.Tensor],
    bias2: Optional[torch.Tensor],
    residual: Optional[torch.Tensor],
    input_smooth: Optional[torch.Tensor],
    act_smooth: Optional[torch.Tensor],
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    topk: int,
    renormalize: bool,
    gated: bool,
    act_mode: str,
    start_expert_id: int = 0,
    block_n: int = 0,
    cncl_comm: int = 0
) -> torch.Tensor:
    return tmo.fused_moe(hidden_states, gating_output, w1, w2, bias1, bias2,
                         residual, input_smooth, act_smooth, w1_scale,
                         w2_scale, topk, renormalize, gated, act_mode,
                         start_expert_id, block_n, cncl_comm)


def matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    act_mode: str = 'none',
    alpha: float = 1.0,
    beta: float = 0.0
) -> torch.Tensor:
    return tmo.matmul(a, b, bias, c, act_mode, alpha, beta)


def weight_only_quant_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor = None,
    bias: torch.Tensor = None,
    c: torch.Tensor = None,
    act_mode: str = "none",
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0
) -> torch.Tensor:
    return tmo.weight_only_quant_matmul(a, b, scale, zero, bias, c, act_mode,
                                        quant_bit_size, alpha, beta)


def smooth_quant_matmul(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    dtype: torch.dtype,
    bias: torch.Tensor = None,
    c: torch.Tensor = None,
    act_mode: str = "none",
    alpha: float = 1.0,
    beta: float = 1.0
) -> torch.Tensor:
    return tmo.smooth_quant_matmul(a, a_scale, b, b_scale, dtype, bias, c,
                                   act_mode, alpha, beta)


def per_token_smooth_quantize(
    x: torch.Tensor,
    smooth: torch.Tensor,
    zero: torch.Tensor = None,
    token_count: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    return tmo.per_token_smooth_quantize(x, smooth, zero, token_count)


def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor = None
) -> torch.Tensor:
    return tmo.quantize(x, scale, zero)


def quant_to_paged_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k_cache_quant_scale: torch.Tensor,
    v_cache_quant_scale: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    return tmo.quant_to_paged_cache(k, v, k_cache, v_cache,
                                    k_cache_quant_scale, v_cache_quant_scale,
                                    slot_mapping)


def quant_to_linear_cache(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    key_cache_quant_scale: torch.Tensor,
    value_cache_quant_scale: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    packed: bool,
    context_seq_offset: Optional[torch.Tensor],
    cache_bs_id: Optional[torch.Tensor],
    cache_seqlen_offset: Optional[torch.Tensor],
) -> None:
    return tmo.quant_to_linear_cache(key, value, key_cache, value_cache,
                                     key_cache_quant_scale,
                                     value_cache_quant_scale, context_lengths,
                                     max_context_len, packed,
                                     context_seq_offset, cache_bs_id,
                                     cache_seqlen_offset)


def advance_step(num_seqs: int,
                 num_queries: int,
                 block_size: int,
                 input_tokens: torch.Tensor,
                 sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor,
                 seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor,
                 TILE_SIZE: int = 64) -> None:
    """
    Advance a step on MLU for existing inputs for a multi-step runner, which
    updates input_tokens/seq_lens/input_positions/slot_mapping in place.
    """

    def verify_tensor(
        name: str,
        tensor: torch.Tensor,
        size_0: int,
        size_1: int,
        dtype: torch.dtype,
    ):
        """
        Auxiliary function to check whether an input tensor is valid.
        """
        size_0_cond = (size_0 == -1 or tensor.size(0) == size_0)
        size_1_cond = (size_1 == -1 or tensor.size(1) == size_1)
        if not (size_0_cond and size_1_cond and tensor.is_contiguous()
                and tensor.dtype == dtype):
            raise ValueError(
                f"The input to advance_step is invalid with tensor name = {name}, "
                f"shape = {tensor.shape}, "
                f"is_cont = {tensor.is_contiguous()}, "
                f"type = {tensor.dtype}, "
                f"is not as expected: shape[{size_0}, {size_1}], type = {dtype}"
            )

    @triton.jit
    def _triton_advance_step(input_tokens_ptr,
                             sampled_token_ids_ptr,
                             input_positions_ptr,
                             seq_lens_ptr,
                             slot_mapping_ptr,
                             block_tables_ptr,
                             block_tables_stride,
                             num_seqs,
                             num_queries,
                             block_size,
                             TILE_SIZE: tl.constexpr):
        """
        The triton implementation of advance step.
        Reference: https://github.com/vllm-project/vllm/blob/v0.6.1/csrc/prepare_inputs/advance_step.cu#L14-L55
        """
        # Set meta info.
        pid = tl.program_id(axis=0)
        offsets = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)
        mask = offsets < num_queries
        # Update input_tokens.
        sampled_token_ids = tl.load(sampled_token_ids_ptr + offsets,
                                    mask=mask)
        tl.store(input_tokens_ptr + offsets, sampled_token_ids, mask=mask)
        seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
        next_seq_lens = seq_lens + 1
        next_input_pos = next_seq_lens - 1
        # Update seq_lens.
        tl.store(seq_lens_ptr + offsets, next_seq_lens, mask=mask)
        # Update input_positions.
        tl.store(input_positions_ptr + offsets, next_input_pos, mask=mask)
        # Calculate slot num.
        block_index = next_input_pos // block_size
        block_offset = next_input_pos % block_size
        block_tables = tl.load(block_tables_ptr +
                               block_tables_stride * offsets + block_index,
                               mask=mask)
        slot_num = block_tables * block_size + block_offset
        # Update slot_mapping.
        tl.store(slot_mapping_ptr + offsets, slot_num, mask=mask)
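
    # Worked example of the slot computation above (illustrative numbers):
    # with block_size=16 and a sequence at length 35 before this step, the
    # kernel stores next_seq_len=36 and next_input_pos=35, giving
    # block_index=2 and block_offset=3; if block_tables[query][2] == 7, the
    # new token lands in slot 7 * 16 + 3 = 115.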

    # Verify all input tensors.
    verify_tensor("input_tokens", input_tokens, num_seqs, -1, torch.int64)
    verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
                  torch.int64)
    verify_tensor("input_positions", input_positions, num_seqs, -1,
                  torch.int32)
    verify_tensor("seq_lens", seq_lens, num_seqs, -1, torch.int32)
    verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, torch.int32)
    verify_tensor("block_tables", block_tables, num_seqs, -1, torch.int32)

    grid = (math.ceil(num_queries / TILE_SIZE), )
    _triton_advance_step[grid](input_tokens, sampled_token_ids,
                               input_positions, seq_lens, slot_mapping,
                               block_tables, block_tables.stride(0), num_seqs,
                               num_queries, block_size, TILE_SIZE)


def preload(
    weight: torch.Tensor,
    size: int
) -> None:
    """
    Preload the weights of a layer.

    Args:
        weight (torch.Tensor): Weight to preload.
        size (int): Preload size (bytes).

    Returns:
        None
    """
    return tmo.preload(weight, size)


def matmul_allreduce(
    cncl_comm,
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = 0.0,
    block_m: int = 0
) -> torch.Tensor:
    return tmo.matmul_allreduce(cncl_comm=cncl_comm, a=a, b=b, bias=bias, c=c,
                                alpha=alpha, beta=beta, block_m=block_m)


def smooth_quant_matmul_allreduce(
    cncl_comm,
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    dtype: torch.dtype,
    bias: torch.Tensor = None,
    c: torch.Tensor = None,
    alpha: float = 1.0,
    beta: float = 1.0,
    block_m: int = 0):
    return tmo.smooth_quant_matmul_allreduce(cncl_comm=cncl_comm, a=a,
                                             a_scale=a_scale, b=b,
                                             b_scale=b_scale, dtype=dtype,
                                             bias=bias, c=c, alpha=alpha,
                                             beta=beta, block_m=block_m)


def quant_matmul_allreduce(
    cncl_comm,
    a_tensor: torch.Tensor,
    a_scale: Optional[torch.Tensor],
    a_zero: Optional[torch.Tensor],
    b_tensor: torch.Tensor,
    b_scale: Optional[torch.Tensor],
    b_zero: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    c_tensor: Optional[torch.Tensor],
    c_scale: Optional[torch.Tensor],
    c_zero: Optional[torch.Tensor],
    gemm_output_scale: Optional[torch.Tensor],
    gemm_output_zero: Optional[torch.Tensor],
    data_type: Optional[str],
    quant_algo: str,
    a_quant_layout: str,
    b_quant_layout: str,
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0,
    trans_a: bool = False,
    trans_b: bool = True,
    block_m: int = 0
) -> torch.Tensor:
    return tmo.quant_matmul_allreduce(
        cncl_comm=cncl_comm, a_tensor=a_tensor, a_scale=a_scale,
        a_zero=a_zero, b_tensor=b_tensor, b_scale=b_scale, b_zero=b_zero,
        bias=bias, c_tensor=c_tensor, c_scale=c_scale, c_zero=c_zero,
        gemm_output_scale=gemm_output_scale,
        gemm_output_zero=gemm_output_zero, data_type=data_type,
        quant_algo=quant_algo, a_quant_layout=a_quant_layout,
        b_quant_layout=b_quant_layout, quant_bit_size=quant_bit_size,
        alpha=alpha, beta=beta, trans_a=trans_a, trans_b=trans_b,
        block_m=block_m)


def flash_attn_sq_mm_allreduce(
    cncl_comm: int,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seq_lens_q: Optional[torch.Tensor],
    cu_seq_lens_kv: Optional[torch.Tensor],
    alibi_slope: Optional[torch.Tensor],
    attn_bias: Optional[torch.Tensor],
    smooth: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    max_seq_len_q: int,
    max_seq_len_kv: int,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    compute_dtype: torch.dtype = torch.float,
    block_seq: int = 0) -> torch.Tensor:
    return tmo.flash_attn_sq_mm_allreduce(cncl_comm, q, k, v, cu_seq_lens_q,
                                          cu_seq_lens_kv, alibi_slope,
                                          attn_bias, smooth, weight,
                                          weight_scale, bias, max_seq_len_q,
                                          max_seq_len_kv, softmax_scale,
                                          is_causal, window_size_left,
                                          window_size_right, compute_dtype,
                                          block_seq)
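
# The *_allreduce wrappers above fuse the collective into the GEMM. A minimal
# sketch of the intended tensor-parallel usage (`comm` is a hypothetical CNCL
# communicator handle obtained from the distributed setup):
#
#     out = matmul_allreduce(comm, x, w_shard)
#
# replacing the unfused pair
#
#     out = matmul(x, w_shard)
#     torch.distributed.all_reduce(out)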

# MoE inner kernels.
def moe_softmax_topk(input: torch.Tensor,
                     topk: int,
                     normalize: bool = False,
                     num_expert_group: int = -1,
                     topk_group: int = 0) -> Tuple[torch.Tensor, ...]:
    return tmo.moe_softmax_topk(input, topk, normalize, num_expert_group,
                                topk_group)


def moe_gen_idx(
    expert_id: torch.Tensor,
    expert_num: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    return tmo.moe_gen_idx(expert_id, expert_num)


def moe_expand_input(input: torch.Tensor,
                     gather_idx: torch.Tensor,
                     cusum_token_count: Optional[torch.Tensor] = None,
                     start_expert_id: int = 0,
                     expert_size: int = 0) -> torch.Tensor:
    return tmo.moe_expand_input(input, gather_idx, cusum_token_count,
                                start_expert_id, expert_size)


def moe_active(input: torch.Tensor,
               act_mode: str,
               is_gated: bool,
               output: Optional[torch.Tensor] = None,
               bias: Optional[torch.Tensor] = None,
               cusum_token_count: Optional[torch.Tensor] = None,
               start_expert_id: int = 0,
               expert_size: int = 0) -> torch.Tensor:
    return tmo.moe_active(input, act_mode, is_gated, output, bias,
                          cusum_token_count, start_expert_id, expert_size)


def group_gemm(a: torch.Tensor,
               b: torch.Tensor,
               m_list: torch.Tensor,
               expand_idx: Optional[torch.Tensor],
               c: Optional[torch.Tensor],
               alpha: Optional[torch.Tensor],
               beta: Optional[torch.Tensor],
               max_m: int = 0) -> torch.Tensor:
    return tmo.group_gemm(a, b, m_list, expand_idx, c, alpha, beta, max_m)


def smooth_quant_group_gemm(a: torch.Tensor,
                            b: torch.Tensor,
                            m_list: torch.Tensor,
                            expand_idx: Optional[torch.Tensor],
                            c: Optional[torch.Tensor],
                            alpha: Optional[torch.Tensor],
                            beta: Optional[torch.Tensor],
                            a_scale: torch.Tensor,
                            b_scale: torch.Tensor,
                            dtype,
                            max_m: int = 0) -> torch.Tensor:
    return tmo.smooth_quant_group_gemm(a, b, m_list, expand_idx, c, alpha,
                                       beta, a_scale, b_scale, dtype, max_m)


def moe_combine_result(input: torch.Tensor,
                       reduce_weight: torch.Tensor,
                       gather_ids: torch.Tensor,
                       residual: Optional[torch.Tensor],
                       cusum_token_count: Optional[torch.Tensor],
                       start_expert_id: int,
                       expert_size: int,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return tmo.moe_combine_result(input, reduce_weight, gather_ids, residual,
                                  cusum_token_count, start_expert_id,
                                  expert_size, bias)


def moe_quantize(x: torch.Tensor,
                 smooth: torch.Tensor,
                 zero: Optional[torch.Tensor] = None,
                 token_count: Optional[torch.Tensor] = None,
                 gather_index: Optional[torch.Tensor] = None,
                 gather_index_start_position: Optional[torch.Tensor] = None,
                 output: Optional[torch.Tensor] = None,
                 output_scale: Optional[torch.Tensor] = None,
                 dynamic_quant: bool = True
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
    return tmo.moe_quantize(x, smooth, zero, token_count, gather_index,
                            gather_index_start_position, output, output_scale,
                            dynamic_quant)
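
# A hedged sketch of how these inner kernels are typically composed into an
# unfused MoE layer; the exact outputs of moe_gen_idx and the argument wiring
# below are assumptions for illustration, not a confirmed contract of
# torch_mlu_ops (router_logits, hidden, w1, w2, m_list are hypothetical):
#
#     topk_weight, topk_id = moe_softmax_topk(router_logits, topk=2,
#                                             normalize=True)
#     gather_idx, *rest = moe_gen_idx(topk_id, expert_num)
#     expanded = moe_expand_input(hidden, gather_idx)
#     up = group_gemm(expanded, w1, m_list, None, None, None, None)
#     act = moe_active(up, act_mode="silu", is_gated=True)
#     down = group_gemm(act, w2, m_list, None, None, None, None)
#     out = moe_combine_result(down, topk_weight, gather_idx, None, None, 0,
#                              expert_num)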