from contextlib import contextmanager
from typing import List, Optional, Tuple, Union

import torch
import math
import triton
import triton.language as tl

from vllm.logger import init_logger

logger = init_logger(__name__)

try:
    import torch_mlu_ops as tmo
    import torch_mlu_ops.triton_ops as triton_ops
except ImportError as e:
    logger.warning("Failed to import from TMO OPS with %r", e)

from vllm.distributed import (
    get_ep_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
    get_data_parallel_group_world_size,
    get_tp_group,
    get_tp_world_group,
    get_dp_group,
    get_data_parallel_group_rank,
    get_tp_world_world_size,
    get_tp_world_rank,
    get_parallel_rank_with_group,
)


@triton.jit
def _triton_advance_step(input_tokens_ptr, sampled_token_ids_ptr,
                         input_positions_ptr, seq_lens_ptr, slot_mapping_ptr,
                         block_tables_ptr, block_tables_stride, num_seqs,
                         num_queries, block_size,
                         TILE_SIZE: tl.constexpr,
                         ):
    """
    The triton implementation of advance step.
    Reference:
    https://github.com/vllm-project/vllm/blob/v0.6.1/csrc/prepare_inputs/advance_step.cu#L14-L55
    """
    # Set meta info.
    pid = tl.program_id(axis=0)
    offsets = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)
    mask = offsets < num_queries

    # Update input_tokens.
    sampled_token_ids = tl.load(sampled_token_ids_ptr + offsets, mask=mask)
    tl.store(input_tokens_ptr + offsets, sampled_token_ids, mask=mask)

    seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
    next_seq_lens = seq_lens + 1
    next_input_pos = next_seq_lens - 1

    # Update seq_lens.
    tl.store(seq_lens_ptr + offsets, next_seq_lens, mask=mask)
    # Update input_positions.
    tl.store(input_positions_ptr + offsets, next_input_pos, mask=mask)

    # Calculate slot num.
    block_index = next_input_pos // block_size
    block_offset = next_input_pos % block_size
    block_tables = tl.load(block_tables_ptr + block_tables_stride * offsets + block_index,
                           mask=mask)
    slot_num = block_tables * block_size + block_offset
    # Update slot_mapping.
tl.store(slot_mapping_ptr + offsets, slot_num, mask=mask) def rotary_embedding( input: torch.Tensor, sin_cache: torch.Tensor, cos_cache: torch.Tensor, position_ids: Optional[torch.Tensor], cu_seqlens: Optional[torch.Tensor], interleaved: bool, discrete: bool, dynamic_ntk: bool, max_seqlen: int, ) -> torch.Tensor: return tmo.apply_rotary( input, sin_cache, cos_cache, position_ids, cu_seqlens, interleaved, discrete, dynamic_ntk, max_seqlen) def fused_rms_norm( x: torch.Tensor, residual: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, bias: torch.Tensor, eps: float, store_output_before_norm: bool, quant_scale: torch.Tensor = None, dynamic_quant: bool = False, out: torch.Tensor = None, ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]: return tmo.fused_rms_norm( x, residual, gamma, beta, bias, eps, store_output_before_norm, quant_scale, out, dynamic_quant) def fused_layer_norm( x: torch.Tensor, residual: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, bias: torch.Tensor, eps: float, store_output_before_norm: bool, quant_scale: torch.Tensor = None, dynamic_quant: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: return tmo.fused_layer_norm( x, residual, gamma, beta, bias, eps, store_output_before_norm, quant_scale, None, dynamic_quant) def flash_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: Optional[torch.Tensor], cu_seq_lens_q: Optional[torch.Tensor], cu_seq_lens_kv: Optional[torch.Tensor], alibi_slope: Optional[torch.Tensor], attn_bias: Optional[torch.Tensor], max_seq_len_q: int, max_seq_len_kv: int, softmax_scale: float, is_causal: bool, window_size_left: int = -1, window_size_right: int = -1, compute_dtype: torch.dtype = torch.float, return_lse: bool = False, block_tables: Optional[torch.Tensor] = None, out_quant_scale: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.half, q_quant_dtype: Optional[torch.dtype] = None, k_quant_dtype: Optional[torch.dtype] = None, v_quant_dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if v is None: v = k return tmo.flash_attention( q, k, v, out, cu_seq_lens_q, cu_seq_lens_kv, alibi_slope, attn_bias, max_seq_len_q, max_seq_len_kv, softmax_scale, is_causal, window_size_left, window_size_right, compute_dtype, return_lse, block_tables, k_quant_scale, v_quant_scale, q_quant_scale, out_quant_scale, out_dtype) def split_head_nums(q_head_num, kv_head_num, max_q_head_num): """ Split q_head_num such that: 1. The maximum value of the split q_head_num does not exceed max_q_head_num. 2. kv_head_num is split into the same number of parts as q_head_num. 3. Each split q_head_num can be evenly divided by the corresponding kv_head_num. 4. If kv_head_num < 1, it is adjusted to 1. Parameters: - q_head_num: int, the q_head_num to be split. - kv_head_num: int, the kv_head_num to be split. - max_q_head_num: int, the maximum supported q_head_num after splitting. Returns: - q_splits: list, the split q_head_num. - kv_splits: list, the split kv_head_num. """ if q_head_num <= 0 or kv_head_num <= 0: return "q_head_num and kv_head_num must be positive integers!" q_splits = [] kv_splits = [] # Residual value remaining_q = q_head_num remaining_kv = kv_head_num while remaining_q > 0: # Attempt to split q_head_num such that the maximum value does not exceed max_q_head_num. for q_part in range(min(max_q_head_num, remaining_q), 0, -1): # Ensure that q_part can be allocated and the corresponding kv_part is greater than or equal to 1. 
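            # Greedy step: pick the largest q_part (<= max_q_head_num) that divides the
            # remaining q heads and pairs with a kv_part that evenly divides it.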
if remaining_q % q_part == 0: # Ensure that kv_part is greater than or equal to 1. kv_part = max(remaining_kv // (remaining_q // q_part), 1) # Ensure that q_part is divisible by kv_part. if q_part % kv_part == 0: # Record the split values. q_splits.append(q_part) kv_splits.append(kv_part) remaining_q -= q_part remaining_kv -= kv_part break else: err_msg = f"Unable to find split method for q_head_num:{q_head_num}, kv_head_num:{kv_head_num}" raise RuntimeError(err_msg) return q_splits, kv_splits def repeat_elements(input_list, n): """ Repeat each element in the list n times consecutively. Parameters: - input_list: list, the input list. - n: int, the number of times each element should be repeated. Returns: - list, a new list containing the repeated elements. """ if not isinstance(input_list, list) or not isinstance(n, int) or n < 0: raise ValueError("The input must be a list, and the repetition count n must be an integer greater than or equal to 0.") # Repeat each element n times using a list comprehension. return [item for item in input_list for _ in range(n)] def single_query_cached_kv_attn( q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, out: torch.Tensor, block_tables: torch.Tensor, context_lens: torch.Tensor, k_cache_quant_scale: Optional[torch.Tensor], v_cache_quant_scale: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor], max_contxt_len: int, windows_size_left: int, windows_size_right: int, softmax_scale: float, return_lse: bool = False, q_head_dim: Optional[int] = 2, kv_head_dim: Optional[int] = 1, seq_q_dim: Optional[int] = 1, max_seq_q_mul_q_divide_kv: Optional[int] = 128, head_size_v: Optional[int] = -1, compute_dtype: Optional[torch.dtype] = torch.float32, q_quant_scale: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, q_quant_dtype: Optional[torch.dtype] = None, q_scale_dtype: Optional[torch.dtype] = None, learnable_sink: Optional[torch.Tensor] = None, ) -> None: windows_size_right = -1 seq_q = q.shape[seq_q_dim] if q_quant_dtype is not None and q.dtype != q_quant_dtype and q_quant_scale is None: q, q_quant_scale = tmo.scaled_quantize(q.contiguous(), quant_type=q_quant_dtype, quant_mode="dynamic_per_token") if k_cache is not None and k_cache.dtype == torch.uint8: k_cache = k_cache.view(torch.float8_e4m3fn) if v_cache is not None and v_cache.dtype == torch.uint8: v_cache = v_cache.view(torch.float8_e4m3fn) if k_cache is not None and k_cache.dtype == torch.bfloat16: max_seq_q_mul_q_divide_kv = 256 # single_query_cached_kv_attn limits seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv now, # and this limitation only applies when using kv8 or floating point computation. # When the limitation is fixed, we should delete the split process. 
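    # Descriptive note: when seq_q * (q_head_num / kv_head_num) exceeds the limit above,
    # the fallback below splits q/out (and, when necessary, the caches and their scales)
    # along the head dimension via split_head_nums() and issues one TMO call per part.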
q_head_num = q.shape[q_head_dim] kv_head_num = k_cache.shape[kv_head_dim] q_divide_kv = q_head_num // kv_head_num if seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv or q_quant_scale is not None: tmo.single_query_cached_kv_attn( q, k_cache, v_cache, out, block_tables, context_lens, k_cache_quant_scale, v_cache_quant_scale, alibi_slopes, max_contxt_len, windows_size_left, windows_size_right, softmax_scale, return_lse, q_quant_scale=q_quant_scale, head_size_v=head_size_v, compute_dtype=compute_dtype, mask=mask, learnable_sink=learnable_sink) else: max_q_head_num = max_seq_q_mul_q_divide_kv * kv_head_num // seq_q q_head_num_sizes, kv_head_num_sizes = split_head_nums(q_head_num, kv_head_num, max_q_head_num) parts_num = len(q_head_num_sizes) q_parts = torch.split(q, q_head_num_sizes, dim=q_head_dim) out_parts = torch.split(out, q_head_num_sizes, dim=q_head_dim) alibi_slopes_parts = [None] * parts_num if alibi_slopes: alibi_slopes_parts = torch.split(alibi_slopes, q_head_num_sizes, dim=0) kv_parts_num = parts_num if parts_num > kv_head_num: assert parts_num % kv_head_num == 0, f"parts_num:{parts_num} need by divided by kv_head_num:{kv_head_num} when parts_num > kv_head_num" kv_parts_num = kv_head_num kv_head_num_sizes = kv_head_num_sizes[:kv_parts_num] if len(kv_head_num_sizes) > 1: k_cache_parts = torch.split(k_cache, kv_head_num_sizes, dim=kv_head_dim) v_cache_parts = torch.split(v_cache, kv_head_num_sizes, dim=kv_head_dim) k_cache_quant_scale_parts = [None] * kv_parts_num v_cache_quant_scale_parts = [None] * kv_parts_num if k_cache_quant_scale: k_cache_quant_scale_dim = 1 if k_cache_quant_scale.dim() == 2 else kv_head_dim k_cache_quant_scale_parts = torch.split(k_cache_quant_scale, kv_head_num_sizes, dim=k_cache_quant_scale_dim) if v_cache_quant_scale: v_cache_quant_scale_dim = 1 if v_cache_quant_scale.dim() == 2 else kv_head_dim v_cache_quant_scale_parts = torch.split(v_cache_quant_scale, kv_head_num_sizes, dim=v_cache_quant_scale_dim) else: k_cache_parts = [k_cache] v_cache_parts = [v_cache] k_cache_quant_scale_parts = [k_cache_quant_scale] v_cache_quant_scale_parts = [v_cache_quant_scale] if parts_num > kv_parts_num: repeate_num = parts_num // kv_parts_num k_cache_parts = repeat_elements(k_cache_parts, repeate_num) v_cache_parts = repeat_elements(v_cache_parts, repeate_num) k_cache_quant_scale_parts = repeat_elements(k_cache_quant_scale_parts, repeate_num) v_cache_quant_scale_parts = repeat_elements(v_cache_quant_scale_parts, repeate_num) for q_value, k_cache_value, v_cache_value, out_value, k_cache_quant_scale_value, v_cache_quant_scale_value, alibi_slopes_value in zip( q_parts, k_cache_parts, v_cache_parts, out_parts, k_cache_quant_scale_parts, v_cache_quant_scale_parts, alibi_slopes_parts): tmo.single_query_cached_kv_attn( q_value, k_cache_value.contiguous(), v_cache_value.contiguous() if v_cache_value is not None else None, out_value, block_tables, context_lens, k_cache_quant_scale_value, v_cache_quant_scale_value, alibi_slopes_value, max_contxt_len, windows_size_left, windows_size_right, softmax_scale, return_lse, head_size_v=head_size_v, compute_dtype=compute_dtype) return(None, None) # TODO(liangxuegang): to fix return (output, lse) def reshape_paged_cache( k: torch.Tensor, v: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor ) -> None: tmo.reshape_paged_cache(k, v, k_cache, v_cache, slot_mapping) def swap_blocks( dst: torch.Tensor, src: torch.Tensor, block_mapping: torch.Tensor ) -> None: # FIXME: Remove this conversion after # tmo.swap_blocks 
support block_mapping tensor. block_mapping = block_mapping.tolist() block_mapping = {src: dst for src, dst in block_mapping} return tmo.swap_blocks(dst, src, block_mapping) def copy_blocks( k_caches: List[torch.Tensor], v_caches: List[torch.Tensor], block_mapping: torch.Tensor ) -> None: # FIXME: Remove this conversion after # tmo.swap_blocks support block_mapping tensor. block_mapping = block_mapping.tolist() result_dict = {} for row in block_mapping: key = row[0] values = row[1:] if key in result_dict: result_dict[key].extend(values) else: result_dict[key] = values return tmo.copy_blocks(k_caches, v_caches, result_dict) def active( input: torch.Tensor, act_mode: str, is_gated: bool ) -> torch.Tensor: return tmo.active(input, act_mode, is_gated) def fused_moe(hidden_states: torch.Tensor, gating_output: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, bias1: Optional[torch.Tensor], bias2: Optional[torch.Tensor], residual: Optional[torch.Tensor], input_smooth: Optional[torch.Tensor], act_smooth: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], topk: int, renormalize: bool, gated: bool, act_mode: str, start_expert_id: int = 0, block_n: int = 0, cncl_comm: int = 0, avg_moe: bool=False, class_reduce_weight: Optional[torch.Tensor] = None, class_expert_id: Optional[torch.Tensor] = None, w1_quant_flag: Optional[List] = None, w2_quant_flag: Optional[List] = None, world_size: int = 0, shared_expert_num: int = 0, parallel_mode: str = 'ep'): dtype = hidden_states.dtype ori_input_shape = hidden_states.shape hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) tokens = hidden_states.size(0) gating_output = gating_output.reshape(-1, gating_output.size(-1)) residual = residual.reshape(-1, residual.size(-1)) if residual is not None else None expert_num = gating_output.size(-1) expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1) per_token_sq = False # check quant check_list = [input_smooth, act_smooth, w1_scale, w2_scale] if all(x is not None for x in check_list): per_token_sq = True if not (all(x is None for x in check_list) or all(x is not None for x in check_list)): raise ValueError( "input_smooth, act_smooth, w1_scale and w2_scale must be " "present and absent at the same time." 
) # softmax_topk reduce_weight, expert_id = tmo.moe_softmax_topk(gating_output, topk, renormalize) # append shared if shared_expert_num > 0: reduce_weight, expert_id = tmo.moe_append_shared_expert(reduce_weight, expert_id, expert_num, shared_expert_num, world_size, parallel_mode) if parallel_mode == "ep": avg_shared_expert_num = (world_size + shared_expert_num - 1) // world_size expert_num += avg_shared_expert_num * world_size else: expert_num += shared_expert_num if avg_moe: n_tokens = hidden_states.shape[0] reduce_weight = class_reduce_weight[:n_tokens] expert_id = class_expert_id[:n_tokens] # gen_idx expand_idx, combine_idx, token_count, cusum_token_count = tmo.moe_gen_idx(expert_id, expert_num) if per_token_sq: quant_input, input_scale = tmo.moe_quantize(hidden_states, input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx, cusum_token_count[start_expert_id].unsqueeze(0)) else: expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx, cusum_token_count, start_expert_id, expert_size) # group gemm if per_token_sq: gemm1_out = tmo.smooth_quant_group_gemm(quant_input, w1, token_count[start_expert_id:start_expert_id+expert_size], None, None, None, None, input_scale, w1_scale, dtype, tokens, quant_flag = w1_quant_flag) else: gemm1_out = tmo.group_gemm(expand_hidden_states, w1, token_count[start_expert_id:start_expert_id+expert_size], None, None, None, None, tokens) if per_token_sq: quant_input = quant_input[:, :gemm1_out.shape[-1] // 2] if gated else quant_input[:, :gemm1_out.shape[-1]] input_scale = input_scale[:gemm1_out.shape[0]] quant_input, input_scale = tmo.moe_quantize(gemm1_out, act_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], output=quant_input, output_scale=input_scale, act_mode=act_mode, is_gated=gated) else: act_out = gemm1_out[:, :gemm1_out.shape[-1] // 2] if gated else gemm1_out act_out = tmo.moe_active(gemm1_out, act_mode, gated, act_out, bias1, cusum_token_count, start_expert_id, expert_size) if cncl_comm > 0: raise ValueError("not support communication and computing fusion currently.") else: if per_token_sq: gemm2_out = tmo.smooth_quant_group_gemm(quant_input, w2, token_count[start_expert_id:start_expert_id+expert_size], None, None, None, None, input_scale, w2_scale, dtype, tokens, quant_flag = w2_quant_flag) else: gemm2_out = tmo.group_gemm(act_out, w2, token_count[start_expert_id:start_expert_id+expert_size], None, None, None, None, tokens) output = tmo.moe_combine_result(gemm2_out, reduce_weight, combine_idx, residual, cusum_token_count, start_expert_id, expert_size, bias2) return output.reshape(ori_input_shape) def matmul( a: torch.Tensor, b: torch.Tensor, bias: Optional[torch.Tensor] = None, c: Optional[torch.Tensor] = None, act_mode: str = 'none', alpha: float = 1.0, beta: float = .0 ) -> torch.Tensor: return tmo.matmul(a, b, bias, c, act_mode, alpha, beta) def weight_only_quant_matmul( a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor = None, bias: torch.Tensor = None, c: torch.Tensor = None, act_mode: str = "none", quant_bit_size: int = 8, alpha: float = 1.0, beta: float = 1.0 ) -> torch.Tensor: assert False, "[weight_only_quant_matmul] is deprecated." 
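# ---------------------------------------------------------------------------
# Illustrative only: a minimal, self-contained sketch of the head-splitting
# helpers defined earlier in this module (split_head_nums / repeat_elements).
# The head counts below are assumptions chosen for the example, not values
# taken from the source; the function is never called at import time.
# ---------------------------------------------------------------------------
def _example_split_head_nums() -> Tuple[List[int], List[int]]:
    # With q_head_num=64, kv_head_num=8 and a per-call cap of 16 q heads, the
    # greedy split yields q_splits == [16, 16, 16, 16], kv_splits == [2, 2, 2, 2].
    q_splits, kv_splits = split_head_nums(q_head_num=64, kv_head_num=8, max_q_head_num=16)
    # If more q parts than kv parts had been produced, the kv parts would be
    # repeated per part via repeat_elements(); with a factor of 1 it is a no-op.
    return q_splits, repeat_elements(kv_splits, 1)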
def smooth_quant_matmul( a: torch.Tensor, a_scale: torch.Tensor, b: torch.Tensor, b_scale: torch.Tensor, dtype: torch.dtype, bias: torch.Tensor = None, c: torch.Tensor = None, act_mode: str = "none", alpha: float = 1.0, beta: float = 1.0, use_hp_active: bool = False, b_quant_bit_size: int = 8, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: return tmo.scaled_matmul(a, b, a_scale, b_scale, dtype, bias, c, act_mode, b_quant_bit_size, alpha, beta, use_hp_active) def per_token_smooth_quantize(x: torch.Tensor, smooth: torch.Tensor, zero: torch.Tensor = None, token_count: torch.Tensor = None, act_mode: str = "none", active_coef: float = 1.0, is_gated: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: if act_mode == "none": is_gated = False if token_count is None: output, output_scale = tmo.scaled_quantize(x, smooth, zero, None, torch.int8, "dynamic_per_token", act_mode, active_coef, is_gated) else: output, output_scale = tmo.moe_quantize(x, smooth, zero, token_count, None, None, None, None, True, act_mode, active_coef, is_gated) return (output, output_scale) def quantize( x: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor = None ) -> torch.Tensor: assert False, "[quantize] is deprecated." def quant_to_paged_cache( k: torch.Tensor, v: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, k_cache_quant_scale: torch.Tensor, v_cache_quant_scale: torch.Tensor, slot_mapping: torch.Tensor, ) -> None: if k_cache is not None and k_cache.dtype == torch.uint8: k_cache = k_cache.view(torch.float8_e4m3fn) if v_cache is not None and v_cache.dtype == torch.uint8: v_cache = v_cache.view(torch.float8_e4m3fn) return tmo.quant_to_paged_cache( k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping ) def advance_step(num_seqs: int, num_queries: int, block_size: int, input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, input_positions: torch.Tensor, seq_lens: torch.Tensor, slot_mapping: torch.Tensor, block_tables: torch.Tensor, TILE_SIZE: int = 64) -> None: """ Advance a step on MLU for existing inputs for a multi-step runner, which will update input_tokens/seq_lens/input_positions/slot_mapping inplace. """ def verify_tensor( name: str, tensor: torch.Tensor, size_0: int, size_1: int, dtype: torch.dtype, ): """ Auxiliary function to check whether input is valid. 
""" size_0_cond = (size_0 == -1 or tensor.size(0) == size_0) size_1_cond = (size_1 == -1 or tensor.size(1) == size_1) if not (size_0_cond and size_1_cond and tensor.is_contiguous and tensor.dtype == dtype): raise ValueError( f"The input to advance_step is invalid with tensor name = {name}, " f"shape = {tensor.shape}, " f"is_cont = {tensor.is_contiguous()}, " f"type = {tensor.dtype}, " f"is not as expected: shape[{size_0}, {size_1}], type = {dtype}" ) verify_tensor("input_tokens", input_tokens, num_seqs, -1, torch.int64) verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, torch.int64) verify_tensor("input_positions", input_positions, num_seqs, -1, torch.int32) verify_tensor("seq_lens", seq_lens, num_seqs, -1, torch.int32) verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, torch.int32) verify_tensor("block_tables", block_tables, num_seqs, -1, torch.int32) grid = (math.ceil(num_queries / TILE_SIZE), ) _triton_advance_step[grid](input_tokens, sampled_token_ids, input_positions, seq_lens, slot_mapping, block_tables, block_tables.stride(0), num_seqs, num_queries, block_size, TILE_SIZE) #Moe inner kernels def moe_softmax_topk(input: torch.Tensor, topk: int, normalize: bool = False, num_expert_group: int = -1, topk_group: int = 0, mask: Optional[torch.Tensor] = None, normed_by : str = "topk_logit", route_scale : float = 1.0, reduce_weight: Optional[torch.Tensor] = None, expert_id: Optional[torch.Tensor] = None, score_bias: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor]: return tmo.moe_softmax_topk(input, topk, normalize, num_expert_group, topk_group, mask, normed_by, route_scale, reduce_weight, expert_id, score_bias) def moe_sigmoid_topk(input: torch.Tensor, topk: int, normalize: bool = False, num_expert_group: int = -1, topk_group: int = 0, route_scale: float = 1.0, score_bias: Optional[torch.Tensor] = None, reduce_weight: Optional[torch.Tensor] = None, expert_id: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor]: return tmo.moe_sigmoid_topk(input, topk, normalize, num_expert_group, topk_group, route_scale = route_scale, score_bias = score_bias, reduce_weight=reduce_weight, expert_id=expert_id) def moe_softplus_topk( input: torch.Tensor, topk: int, input_ids: Optional[torch.Tensor] = None, tid2eid: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, route_scale: float = 1.0, reduce_weight: Optional[torch.Tensor] = None, expert_id: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: return tmo.moe_softplus_topk( input, topk, input_ids, tid2eid, bias, route_scale, reduce_weight, expert_id, ) def moe_gen_idx(expert_id: torch.Tensor, expert_num: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: return tmo.moe_gen_idx(expert_id, expert_num) def moe_expand_input(input: torch.Tensor, gather_idx: torch.Tensor, cusum_token_count: Optional[torch.Tensor] = None, start_expert_id: int = 0, expert_size: int = 0) -> torch.Tensor: return tmo.moe_expand_input(input, gather_idx, cusum_token_count, start_expert_id, expert_size) def moe_active(input: torch.Tensor, act_mode: str, is_gated: bool, output: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, cusum_token_count: Optional[torch.Tensor] = None, start_expert_id: int = 0, expert_size: int = 0) -> torch.Tensor: return tmo.moe_active(input, act_mode, is_gated, output, bias, cusum_token_count, start_expert_id, expert_size) def group_gemm(a: torch.Tensor, b: torch.Tensor, m_list: torch.Tensor, expand_idx: Optional[torch.Tensor], c: Optional[torch.Tensor], 
alpha: Optional[torch.Tensor], beta: Optional[torch.Tensor], max_m: int = 0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: return tmo.group_gemm(a, b, m_list, expand_idx, c, alpha, beta, max_m, d=output) def smooth_quant_group_gemm(a: torch.Tensor, b: torch.Tensor, m_list: torch.Tensor, expand_idx: Optional[torch.Tensor], c: Optional[torch.Tensor], alpha: Optional[torch.Tensor], beta: Optional[torch.Tensor], a_scale: torch.Tensor, b_scale: torch.Tensor, dtype, max_m: int = 0, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: return tmo.smooth_quant_group_gemm(a, b, m_list, expand_idx, c, alpha, beta, a_scale, b_scale, dtype, max_m, d=output) def moe_combine_result(input: torch.Tensor, reduce_weight: torch.Tensor, gather_ids: torch.Tensor, residual: Optional[torch.Tensor], cusum_token_count: Optional[torch.Tensor], start_expert_id: int, expert_size: int, bias: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: return tmo.moe_combine_result(input, reduce_weight, gather_ids, residual, cusum_token_count, start_expert_id, expert_size, bias, output=output) def moe_quantize(x: torch.Tensor, smooth: torch.Tensor, zero: Optional[torch.Tensor] = None, token_count: Optional[torch.Tensor] = None, gather_index: Optional[torch.Tensor] = None, gather_index_start_position: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, dynamic_quant: bool = True, act_mode: str = "none", active_coef: float = 1.0, is_gated: bool = False, quant_type: torch.dtype = torch.int8 ) -> Tuple[torch.Tensor, torch.Tensor]: return tmo.moe_quantize(x, smooth, zero, token_count, gather_index, gather_index_start_position, output, output_scale, dynamic_quant, act_mode, active_coef, is_gated, quant_type) def dequant_from_paged_cache(key: torch.Tensor, value: Optional[torch.Tensor], key_cache: torch.Tensor, value_cache: Optional[torch.Tensor], key_cache_quant_scale: torch.Tensor, value_cache_quant_scale: Optional[torch.Tensor], context_lengths: torch.Tensor, max_context_len: int, context_seq_offset: Optional[torch.Tensor], block_tables: torch.Tensor, quant_mode: int = 0, quant_bit: int = 8) -> None: tmo.dequant_from_paged_cache( key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, context_lengths, max_context_len, context_seq_offset, block_tables, quant_mode, quant_bit) def random_sample( probs: torch.Tensor, is_gumbel_max: bool, generators: dict[int, torch.Generator], ) -> torch.Tensor: return tmo.random_sample(probs, is_gumbel_max, generators) def rejection_sample(draft_token_ids: torch.Tensor, num_draft_tokens: torch.Tensor, cu_num_draft_tokens: torch.Tensor, draft_probs: torch.Tensor, target_probs: torch.Tensor, bonus_token_ids: torch.Tensor, uniform_rand: torch.Tensor, uniform_probs: torch.Tensor, max_spec_len: int, high_acc: bool = True) -> torch.Tensor: return tmo.rejection_sample( draft_token_ids, num_draft_tokens, cu_num_draft_tokens, draft_probs, target_probs, bonus_token_ids, uniform_rand, uniform_probs, max_spec_len, high_acc) def apply_topkp_v2(logits: torch.Tensor, index_in: torch.Tensor, temperature_list: torch.Tensor, minp_list: torch.Tensor, topk_list: torch.Tensor, topp_list: torch.Tensor, logits_out: Optional[torch.Tensor] = None, sorted_logits_out: Optional[torch.Tensor] = None, index_out: Optional[torch.Tensor] = None, true_select_len: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: return tmo.apply_topkp_v2(logits, 
index_in, temperature_list, minp_list, topk_list, topp_list, logits_out, sorted_logits_out, index_out, true_select_len) def scaled_quantize( input: torch.Tensor, scale: Optional[torch.Tensor] = None, zero: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None, quant_type: torch.dtype = torch.int8, quant_mode: str = "dynamic_per_token", act_mode: str = "none", active_coef: float = 1.0, is_gated: bool = False ) -> Tuple[torch.Tensor]: """ Apply activation and quantization to the input tensor x. Args: x (torch.Tensor): The tensor to be quantized, shape is (..., C), must be continuous between 0 and -2 dimensions. scale (Optional[torch.Tensor], optional): The scale multipled to the input tensor. Shape is (C) or (1). zero (Optional[torch.Tensor], optional): Not supported, must pass None. scale_ub (Optional[torch.Tensor], optional): The output_scale upper bound. Take effect only if quant_type == torch.float8_e4m3fn and quant_mode == "dynamic_per_token". quant_type (optional): Output data type, can be torch.int8, torch.float8_e4m3fn. Defaults to torch.int8. quant_mode (str, optional): quantize mode, which can be "dynamic_per_token", "dynamic_per_tensor", "static_per_tensor" and "static_per_channel". Defaults to "dynamic_per_token". act_mode (str): The mode of activation, must be "none", "gelu", "silu", "swish". active_coef(float): The coefficient used in the swish activation. Default is 1.0. is_gated (bool): A boolean parameter that indicates whether a gating mechanism is applied. It only takes effect when act_mode is not "none". Type: input: float, half, bfloat16. scale: float. scale_ub: float. act_mode: str active_coef: float is_gated: bool Returns: Tuple[torch.Tensor]: Returns (output, output_scale) if quant_mode is "dynamic_per_token" or "dynamic_per_tensor", otherwise returns output only. """ return tmo.scaled_quantize(input, scale, zero, scale_ub, quant_type, quant_mode, act_mode, active_coef, is_gated) def scaled_matmul(a: torch.Tensor, b: torch.Tensor, a_scale: Optional[torch.Tensor], b_scale: torch.Tensor, output_dtype: torch.dtype, bias: torch.Tensor = None, c: torch.Tensor = None, act_mode: str = "none", quant_bit_size: int = 8, alpha: float = 1.0, beta: float = 1.0, use_hp_active: bool = False, a_quant_bit_size: int = 8, a_calib: Optional[torch.Tensor] = None, b_calib: Optional[torch.Tensor] = None,): """ Perform quantized matrix multiplication on tensor a and b. Args: a (torch.Tensor): Shape is (M, K). b (torch.Tensor): If quant_bit_size = 8, shape is (N, K). If quant_bit_size = 4, shape is (N, K//2). a_scale (Optional[torch.Tensor]): Shape can be (M). b_scale (torch.Tensor): If use groupwise quantization, shape must be (N, group_num), data type must be the same as a; otherwise shape must be (N), data type must be float. output_dtype (torch.dtype): Specify the data type of output, must be torch.half or torch.bfloat16. bias (torch.Tensor, optional): Shape is (N). c (torch.Tensor, optional): Shape is (M, N). act_mode (str, optional): Choose the activation algorithm, must be 'silu', 'gelu' or 'none'. If use groupwise quantization, act_mode must be 'none'. quant_bit_size (int, optional): The data format of b. Defaults to 8. alpha (float, optional): coefficient of acted. Defaults to 1.0. beta (float, optional): coefficient of c. Defaults to 1.0. use_hp_active (bool, optional): Describing the algorithm that used in the implementation of the activation function. When the value is true, use the high-precision algorithm, otherwise use the fastest algorithm of activation. 
Defaults to False. a_quant_bit_size(int, optional):The data format of a. Defaults to -1. a_calib (Optional[torch.Tensor]): The calibration of a, shape can be (M, 2). b_calib (Optional[torch.Tensor]): The calibration of b, shape can be (M, 2). Type: a: int8, half, bfloat16, float8_e4m3fn, int4X2 a_scale: float b: int8, float8_e4m3fn, int4X2 b_scale: float, half, bfloat16 bias: half, float, bfloat16 c: half, float, bfloat16 output: half, bfloat16 a_calib: float b_calib: float Returns: A tensor with the shape of (M, N). """ return tmo.scaled_matmul(a, b, a_scale, b_scale, output_dtype, bias, c, act_mode, quant_bit_size, alpha, beta, use_hp_active, a_quant_bit_size, a_calib, b_calib,) def fused_mla_kv(kv: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, position_id: torch.Tensor, gamma: torch.Tensor, kv_cache: torch.Tensor, kv_cache_scale: Optional[torch.Tensor], slot_mapping: Optional[torch.Tensor], cache_bs_id: Optional[torch.Tensor] = None, cache_seq_offset: Optional[torch.Tensor] = None, is_paged_cache: bool = True, eps: float = 1e-5, interleaved: bool = True): quant_mode = "static_per_channel" if kv_cache_scale is None else "dynamic_per_token" return tmo.fused_mla_kv( kv, sin, cos, position_id, gamma, kv_cache, kv_cache_scale, slot_mapping, cache_bs_id, cache_seq_offset, quant_mode=quant_mode, is_paged_cache=is_paged_cache, eps=eps, interleaved=interleaved, ) def fused_mla_q(q: torch.Tensor, gamma: torch.Tensor, smooth_quant_scale: torch.Tensor, weight_b: torch.Tensor, weight_b_scale: torch.Tensor, weight_c: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, position_id: torch.Tensor, output: Optional[torch.Tensor] = None, eps: float = 1e-6, interleaved: bool = True, output_quant_mode: str = 'none', output_scale: Optional[torch.Tensor] = None, output_norm: Optional[torch.Tensor] = None) -> torch.Tensor: return tmo.fused_mla_q( q, gamma, smooth_quant_scale, weight_b, weight_b_scale, weight_c, sin, cos, position_id, output, eps, interleaved, output_quant_mode, output_scale, store_norm=(output_norm is not None), output_norm= output_norm, ) def gather_cache( kv_cache: List[torch.Tensor], # [[1, num_blocks, num_kv_heads, block_size, head_size] # [1, num_blocks, num_kv_heads, block_size] if kv_cache_dtype=int8] dst: torch.Tensor, # [tot_tokens, entrys...] block_table: torch.Tensor, # [batch, block_indices] cu_seq_lens: torch.Tensor, # [batch+1] batch_size: int, seq_starts: torch.Tensor = None, # Optional: [batch] kv_cache_dtype: str = 'auto', ) -> None: """ Gathers sequences from src_cache into dst based on block_table and cu_seq_lens. Args: src_cache: Source KV cache tensor of shape [[1, num_blocks, num_kv_heads, block_size, head_size], [1, num_blocks, num_kv_heads, block_size] if cache_dtype=int8]. dst: Destination tensor of shape [tot_tokens, entrys...]. block_table: Tensor of shape [batch, block_indices] mapping sequences to blocks. cu_seq_lens: Tensor of shape [batch+1] with cumulative sequence lengths. batch_size: Number of sequences in the batch. seq_starts: Optional tensor of shape [batch] for block index offsets. 
""" assert len(kv_cache) > 0 and kv_cache[0].numel() > 0, "kv cache can't be empty in gather_cache" src_cache = kv_cache[0][0] # Validate inputs assert src_cache.device == dst.device == block_table.device == cu_seq_lens.device, \ "All tensors must be on the same device" assert block_table.dtype == torch.int32, "block_table must be int32" assert cu_seq_lens.dtype == torch.int32, "cu_seq_lens must be int32" quant_kv_cache = kv_cache_dtype != 'auto' if not quant_kv_cache: assert src_cache.dtype == dst.dtype, "src_cache and dst must have the same dtype when no quantized" if seq_starts is not None: assert seq_starts.dtype == torch.int32, "seq_starts must be int32" assert seq_starts.device == src_cache.device, "seq_starts must be on the same device" # Extract dimensions num_blocks, num_kv_heads, block_size, head_size = src_cache.shape # When using MLA during decode it becomes MQA, the num_kv_heads is fixed to 1, # so src_cache can be view to [num_blocks, block_size, head_size] assert num_kv_heads == 1, "mla force num_kv_heads to 1" src_cache = src_cache.view(num_blocks, block_size, -1) entry_shape = src_cache.shape[2:] # ENTRIES... tot_tokens = cu_seq_lens[-1] assert tot_tokens > 0, "tot_tokens should > 0" assert tot_tokens <= dst.shape[0], "tot_tokens should <= dst.shape[0]" dst_cache = dst[:tot_tokens] # Ensure cu_seq_lens matches batch_size assert cu_seq_lens.size(0) == batch_size + 1, "cu_seq_lens must have batch_size + 1 elements" # Compute sequence lengths seq_lens = cu_seq_lens[1:] - cu_seq_lens[:-1] # [BATCH] tot_blocks_per_seq = (seq_lens + block_size - 1) // block_size # ceil_div # Handle seq_starts offset block_offsets = torch.zeros(batch_size, dtype=torch.int32, device=src_cache.device) if seq_starts is not None: block_offsets = seq_starts // block_size # Flatten src_cache for easier indexing: [NUM_BLOCKS * BLOCK_SIZE, ENTRIES...] src_flat = src_cache.view(num_blocks * block_size, *entry_shape) # Prepare output indices dst_indices = [] for bid in range(batch_size): seq_len = seq_lens[bid] if seq_len <= 0: continue seq_start = cu_seq_lens[bid] tot_blocks = tot_blocks_per_seq[bid] offset = block_offsets[bid] # Compute block indices for this sequence block_ids = block_table[bid, offset:offset + tot_blocks] # Compute token indices within blocks token_indices = torch.arange(seq_len, device=src_cache.device) block_indices = token_indices // block_size within_block = token_indices % block_size # Map to src_flat indices src_indices = block_ids[block_indices] * block_size + within_block dst_indices.append(src_indices) # Concatenate all indices dst_indices = torch.cat(dst_indices) # Gather data dst_flat = src_flat[dst_indices] if quant_kv_cache: src_cache_scale = kv_cache[1][0] src_scale_flat = src_cache_scale.view(num_blocks * block_size) dst_scale_flat = src_scale_flat[dst_indices] dst_flat = dst_flat * dst_scale_flat.unsqueeze(-1) dst_cache.view(-1, *entry_shape).copy_(dst_flat.view(tot_tokens, *entry_shape)) def merge_attn_states( output: torch.Tensor, prefix_output: torch.Tensor, prefix_lse: torch.Tensor, suffix_output: torch.Tensor, suffix_lse: torch.Tensor, output_lse: Optional[torch.Tensor] = None, ) -> None: """ Merges partial attention states (prefix and suffix) into a single output. Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005. Args: output: Output tensor of shape [num_tokens, num_query_heads, head_size]. prefix_output: Prefix attention output, same shape as output. prefix_lse: Prefix log-sum-exp, shape [num_query_heads, num_tokens]. 
suffix_output: Suffix attention output, same shape as output. suffix_lse: Suffix log-sum-exp, same shape as prefix_lse. output_lse: Optional output log-sum-exp, same shape as prefix_lse. """ # Input validation assert output.shape == prefix_output.shape == suffix_output.shape, \ "Output and input tensors must have the same shape" assert prefix_lse.shape == suffix_lse.shape, \ "Prefix and suffix LSE tensors must have the same shape" if output_lse is not None: assert output_lse.shape == prefix_lse.shape, \ "Output LSE must have the same shape as input LSE tensors" # Handle inf values (replace inf with -inf for consistency) p_lse = torch.where( prefix_lse == float('inf'), torch.tensor(float('-inf'), device=prefix_lse.device), prefix_lse ) s_lse = torch.where( suffix_lse == float('inf'), torch.tensor(float('-inf'), device=suffix_lse.device), suffix_lse ) # Compute maximum LSE for numerical stability max_lse = torch.maximum(p_lse, s_lse) # Shape: [num_query_heads, num_tokens] # Normalize LSE terms p_lse = p_lse - max_lse # Shape: [num_query_heads, num_tokens] s_lse = s_lse - max_lse # Shape: [num_query_heads, num_tokens] # Compute sum of exponentials out_se = torch.exp(p_lse) + torch.exp(s_lse) # Shape: [num_query_heads, num_tokens] # Compute output_lse if provided if output_lse is not None: output_lse.copy_(torch.log(out_se) + max_lse) # Compute scaling factors p_scale = torch.exp(p_lse) / out_se # Shape: [num_query_heads, num_tokens] s_scale = torch.exp(s_lse) / out_se # Shape: [num_query_heads, num_tokens] # Reshape scales for broadcasting p_scale = p_scale.unsqueeze(-1) # Shape: [num_query_heads, num_tokens, 1] s_scale = s_scale.unsqueeze(-1) # Shape: [num_query_heads, num_tokens, 1] # Transpose outputs to match scaling dimensions prefix_output = prefix_output.permute(1, 0, 2) # Shape: [num_query_heads, num_tokens, head_size] suffix_output = suffix_output.permute(1, 0, 2) # Shape: [num_query_heads, num_tokens, head_size] # Compute merged output out = prefix_output * p_scale + suffix_output * s_scale # Shape: [num_query_heads, num_tokens, head_size] # Transpose back and store in output output.copy_(out.permute(1, 0, 2)) # Shape: [num_tokens, num_query_heads, head_size] def moe_all2all_create(dispatch_token_byte: int, combine_token_byte: int, max_expert_num: int, max_token_num: int, rank: int, nrank: int) -> Tuple[int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Create the handle of MOE All-to-All communication. API call order: 1.Call torch_mlu_ops.moe_all2all_create(...) to obtain the CNCLEP handle and buffer tensor for All-to-All communication. Only needs to be done once. 2.Gather all_exchange_info by performing an All-Gather operation on exchange_info across nrank processes. Only needs to be done once. 3.Call torch.distributed.barrier() to ensure step 2 finish. Only needs to be done once. 4.Call torch_mlu_ops.moe_all2all_init(...) to configure the all_exchange_info into the handle. Only needs to be done once. 5.Call torch_mlu_ops.moe_all2all_dispatch(...) to route tokens to their designated experts. 6.Call torch_mlu_ops.moe_all2all_combine(...) to restore tokens to their original locations. 7.Call torch_mlu_ops.moe_all2all_destroy(...) to release the CNCLEP handle. Only needs to be done once. Args: dispatch_token_byte (int): Byte size of a single token for dispatch All-to-All operation. combine_token_byte (int): Byte size of a single token for combine All-to-All operation. 
max_expert_num (int): Maximum number of experts participating in the All-to-All operation. max_token_num (int): Maximum number of tokens to be processed. rank (int): Rank ID of the current process [0~nrank-1]. nrank (int): Total number of processes in the distributed group. Return: A tuple of (handle, exchange_info_size, exchange_info, dispatch_send, dispatch_recv, combine_send and combine_recv). handle: The CNCLEP handle with type of integer. exchange_info_size: The size of exchange_info. exchange_info: CPU tensor, shape is [exchange_info_size], and data type is torch.int8. dispatch_send: MLU tensor, shape is [max_token_num * dispatch_token_byte], and data type is torch.int8. dispatch_recv: MLU tensor, shape is [nrank * max_token_num * dispatch_token_byte], and data type is torch.int8. combine_send: MLU tensor, shape is [max_token_num * combine_token_byte], and data type is torch.int8. combine_recv: MLU tensor, shape is [nrank * max_token_num * combine_token_byte], and data type is torch.int8. """ return tmo.moe_all2all_create(dispatch_token_byte, combine_token_byte, max_expert_num, max_token_num, rank, nrank) def moe_all2all_init(handle: int, all_exchange_info: torch.Tensor) -> None: tmo.moe_all2all_init(handle, all_exchange_info) def moe_all2all_destroy(handle: int) -> None: tmo.moe_all2all_destroy(handle) def moe_all2all_dispatch(handle: int, token_byte: int, token_num: int, send_layout: torch.Tensor, send_token_num: torch.Tensor, recv_layout: torch.Tensor, recv_token_num: torch.Tensor, send_token: Optional[torch.Tensor] = None, recv_token: Optional[torch.Tensor] = None, ) -> None: tmo.moe_all2all_dispatch(handle, token_byte, token_num, send_layout, send_token_num, recv_layout, recv_token_num, send_token, recv_token) def moe_all2all_combine(handle: int, token_byte: int, token_num: int, send_src_layout: torch.Tensor, send_dst_layout: torch.Tensor, send_token: Optional[torch.Tensor] = None, recv_token: Optional[torch.Tensor] = None, ) -> None: tmo.moe_all2all_combine(handle, token_byte, token_num, send_src_layout, send_dst_layout, send_token, recv_token) def gather_split(input: torch.Tensor, gather_index: torch.Tensor, valid_token_num: torch.Tensor, output1: torch.Tensor, output2: Optional[torch.Tensor] = None) -> None: tmo.gather_split(input, gather_index, valid_token_num, output1, output2) def moe_all2all_gen_send_layout(token_count: torch.Tensor, nrank: int) -> torch.Tensor: return tmo.moe_all2all_gen_send_layout(token_count, nrank) def moe_all2all_gen_gather_index(token_num: torch.Tensor, pad_num: int, return_cusum_token_count: bool = False): if not return_cusum_token_count: gather_by_expert_index, gather_by_rank_index, token_count, token_sum = \ tmo.moe_all2all_gen_gather_index(token_num, pad_num) return gather_by_expert_index, gather_by_rank_index, token_count, token_sum else: gather_by_expert_index, gather_by_rank_index, token_count, token_sum, cusum_token_count = \ tmo.moe_all2all_gen_gather_index(token_num, pad_num, return_cusum_token_count=True) return gather_by_expert_index, gather_by_rank_index, token_count, token_sum, cusum_token_count def reshape_from_cache( key: torch.Tensor, value: Optional[torch.Tensor], key_cache: torch.Tensor, value_cache: Optional[torch.Tensor], context_lengths: torch.Tensor, max_context_len: int, context_seq_offset: Optional[torch.Tensor] = None, block_tables: Optional[torch.Tensor] = None, cache_seq_offset: Optional[torch.Tensor] = None, ) -> None: tmo.reshape_from_cache( key=key, value=value, key_cache=key_cache, value_cache=value_cache, 
context_lengths=context_lengths, max_context_len=max_context_len, context_seq_offset=context_seq_offset, block_tables=block_tables, cache_seq_offset=cache_seq_offset, ) def masked_indexer_select_paged_kv(query: torch.Tensor, k_cache: torch.Tensor, weights: torch.Tensor, kv_cache_block_table: torch.Tensor, cu_seq_q_lens: Optional[torch.Tensor], cu_seq_k_lens: Optional[torch.Tensor], k_context_lens: Optional[torch.Tensor], k_cache_block_table: Optional[torch.Tensor], is_prefill: bool, index_topk: int, kv_cache_block_size: int, softmax_scale: float, q_scale: Optional[torch.Tensor] = None, k_scale_cache: Optional[torch.Tensor] = None, sparse_block_table: Optional[torch.Tensor] = None, sparse_context_lens: Optional[torch.Tensor] = None): tmo.masked_indexer_select_paged_kv(query=query, k_cache=k_cache, weights=weights, kv_cache_block_table=kv_cache_block_table, cu_seq_q_lens=cu_seq_q_lens, cu_seq_k_lens=cu_seq_k_lens, k_context_lens=k_context_lens, k_cache_block_table=k_cache_block_table, is_prefill=is_prefill, index_topk=index_topk, kv_cache_block_size=kv_cache_block_size, softmax_scale=softmax_scale, q_scale=q_scale, k_scale_cache=k_scale_cache, sparse_block_table=sparse_block_table, sparse_context_lens=sparse_context_lens) def masked_indexer_select_paged_kv_prefill( query: torch.Tensor, key_value: torch.Tensor, weights: torch.Tensor, kv_cache_block_table: torch.Tensor, cu_seq_q_lens: torch.Tensor, cu_seq_k_lens: torch.Tensor, index_topk: int, kv_cache_block_size: int, softmax_scale: float, q_scale: Optional[torch.Tensor] = None, k_scale_cache: Optional[torch.Tensor] = None, sparse_block_table: Optional[torch.Tensor] = None, sparse_context_lens: Optional[torch.Tensor] = None, kv_cache_block_table_offset: Optional[torch.Tensor] = None, compress_ratio: int = 1, ): return tmo.masked_indexer_select_paged_kv( query=query, k_cache=key_value, weights=weights, kv_cache_block_table=kv_cache_block_table, cu_seq_q_lens=cu_seq_q_lens, cu_seq_k_lens=cu_seq_k_lens, k_context_lens=None, k_cache_block_table=None, is_prefill=True, index_topk=index_topk, kv_cache_block_size=kv_cache_block_size, softmax_scale=softmax_scale, q_scale=q_scale, k_scale_cache=k_scale_cache, sparse_block_table=sparse_block_table, sparse_context_lens=sparse_context_lens, kv_cache_block_table_offset=kv_cache_block_table_offset, compress_ratio=compress_ratio, is_score_float=True, ) def masked_indexer_select_paged_kv_decode( query: torch.Tensor, k_cache: torch.Tensor, weights: torch.Tensor, kv_cache_block_table: torch.Tensor, k_context_lens: Optional[torch.Tensor], k_cache_block_table: Optional[torch.Tensor], index_topk: int, kv_cache_block_size: int, softmax_scale: float, q_scale: Optional[torch.Tensor] = None, k_scale_cache: Optional[torch.Tensor] = None, sparse_block_table: Optional[torch.Tensor] = None, sparse_context_lens: Optional[torch.Tensor] = None, kv_cache_block_table_offset: Optional[torch.Tensor] = None, compress_ratio: int = 1, ): query_len = query.shape[1] #k_context_lens = k_context_lens // compress_ratio return tmo.masked_indexer_select_paged_kv( query=query, k_cache=k_cache, weights=weights, kv_cache_block_table=kv_cache_block_table, cu_seq_q_lens=None, cu_seq_k_lens=None, k_context_lens=k_context_lens, k_cache_block_table=k_cache_block_table, is_prefill=False, index_topk=index_topk, kv_cache_block_size=kv_cache_block_size, softmax_scale=softmax_scale, q_scale=q_scale, k_scale_cache=k_scale_cache, sparse_block_table=sparse_block_table, sparse_context_lens=sparse_context_lens, 
kv_cache_block_table_offset=kv_cache_block_table_offset, compress_ratio=compress_ratio, is_score_float=True, ) def concat_block_table( first_block_table: torch.Tensor, first_context_lens: torch.Tensor, second_block_table: torch.Tensor, second_context_lens: torch.Tensor, new_block_table: Optional[torch.Tensor] = None, new_context_lens: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Concatenate two different block tables, return the concatenated result. Math: new_context_lens = first_context_lens + second_context_lens total_seq = first_context_lens.size(0) for i in range(total_seq): new_block_table[i, :first_context_lens[i]] = first_block_table[i, :first_context_lens[i]] new_block_table[i, first_context_lens[i]:first_context_lens[i]+second_context_lens[i]] = second_block_table[i, :second_context_lens[i]] Args: first_block_table (torch.Tensor): The first block table of shape `[total_seq, first_max_blkn]`. first_context_lens (torch.Tensor): The context lens of the first block table of shape `[total_seq,]`. second_block_table (torch.Tensor): The second block table of shape `[total_seq, second_max_blkn]`. second_context_lens (torch.Tensor): The context lens of the second block table of shape `[total_seq,]`. new_block_table (Optional[torch.Tensor]): The new block table of shape `[total_seq, max_new_block_number]`. if not None, the max_new_block_number must be large enough for the concatenated block_table Default: `None`. new_context_lens (Optional[torch.Tensor]): The new context lens of shape `[total_seq,]`. Default: `None`. Returns: new_block_table (torch.Tensor): The concatenated block table of shape `[total_seq, max_new_block_number]`. new_context_lens (torch.Tensor): The new context lens of shape `[total_seq,]`, equals first_context_lens + second_context_lens Type: INT32 """ return tmo.concat_block_table( first_block_table, first_context_lens, second_block_table, second_context_lens, new_block_table, new_context_lens, ) def fused_mhc_post( x: torch.Tensor, # (N, D) float|bf16 residual: torch.Tensor, # (N, HC, D) float|bf16 post: torch.Tensor, # (N, HC) 固定为float comb: torch.Tensor, # (N, HC, HC) 固定为float compute_rms: bool, eps: float, output: torch.Tensor = None, # (N, HC, D) 同输入类型 output_rms = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Math: output = post * x + (comb * residual).sum(dim=1) output_rms = rsqrt(x.square().mean(dim=-1)) Args: x (torch.Tensor): Shape is [N, D]. residual (torch.Tensor): Shape is [N, HC, D]. post (torch.Tensor): Shape is [N, HC]. comb (torch.Tensor): Shape is [N, HC, HC]. compute_rms (bool): Whether to compute output_rms. eps (float): The eps of normalization. output (torch.Tensor, optional): Shape is [N, HC, D]. Defaults to None. output_rms (torch.Tensor, optional): Shape is [N]. Defaults to None. Returns: output, output_rms Limitation: D must be 4096. HC must be 4. 
""" out = tmo.fused_mhc_post( x, residual, post, comb, compute_rms, eps, output, output_rms, ) return out if compute_rms else (out, None) def fused_compress_multi_kv(kv: torch.Tensor, # (BS, D) float|bf16 score: torch.Tensor, # (BS, D) float|bf16 kv_state: torch.Tensor, # (max_B, coff * R, D) float score_state: torch.Tensor, # (max_B, coff * R, D) float batch_ids: torch.Tensor, # (B,) int32 cu_seqlens: torch.Tensor, # (B,) int32 ape: torch.Tensor, # (R, D) float max_seqlen:int, overlap: bool, compressed_kv: torch.Tensor # (BS, head_dim) float|bf16 ): tmo.fused_compress_multi_kv( kv = kv, score = score, kv_state = kv_state, score_state = score_state, cu_seqlens = cu_seqlens, batch_ids = batch_ids, ape = ape, max_seqlen = max_seqlen, overlap = overlap, compressed_kv = compressed_kv, ) def fused_compress_single_kv( kv: torch.Tensor, # (T, D) float|bf16 score: torch.Tensor, # (T, D) float|bf16 position: torch.Tensor, # (B,) int32 ape: torch.Tensor, # (ratio, D) float|bf16 kv_state: torch.Tensor, # (B, R, D) float|bf16 score_state: torch.Tensor, # (B, R, D) float|bf16 gamma: torch.Tensor, # (d) sin: torch.Tensor, # (-1, rope_dim) cos: torch.Tensor, # (-1, rope_dim) hadamard_matrix: Optional[torch.Tensor], # (d, d) slot_mapping: torch.Tensor, # (B,) int32 kv_cache: torch.Tensor, # (-1, BLKS, head_dim) bf16|int8|fp8 kv_cache_scale: Optional[torch.Tensor], # (-1, BLKS) float eps: float, overlap: bool, rotate: bool, state_idx: torch.Tensor, cu_query_len: torch.Tensor | None = None, ): """ Math: Args: kv (torch.Tensor): Shape is [B, S, D]. score (torch.Tensor): Shape is [B, S, D]. position (torch.Tensor): Shape is [B]. ape (torch.Tensor): Shape is [ratio, D]. kv_state (torch.Tensor): Shape is [max_B, R, D]. score_state (torch.Tensor): Shape is [max_B, R, D]. gamma (torch.Tensor): Shape is [head_dim]. sin (torch.Tensor): Shape is [table_len, rope_dim]. cos (torch.Tensor): Shape is [table_len, rope_dim]. hadamard_matrix (torch.Tensor): Shape is [head_dim, head_dim]. slot_mapping (torch.Tensor): Shape is [B]. kv_cache (torch.Tensor): Shape is [cache_len, block_size, hs]. kv_cache_scale (torch.Tensor): Shape is [cache_len, block_size]. eps (flost): The eps of normalization. overlap (bool): Whether to overlap. rotate (bool): Whether to rotate. Type: kv: BF16, FP32 score: same as kv position: INT32 ape: FP32 kv_state: FP32 score_state: FP32 gamma: same as kv sin: same as kv cos: same as kv hadamard_matrix: same as kv slot_mapping: INT32 kv_cache: BF16, FP32 kv_cache_scale: FP32 Returns: Only support inplace outputs, include kv_state, score_state, kv_cache, kv_cache_scale Note: coff = overlap + 1 D = coff * head_dim R = coff * ratio """ token_num, coff_dim = kv.shape # TODO: force user_tmo = 0 after supporting mtp. 
bsz = state_idx.numel() kv = kv.unsqueeze(1) score = score.unsqueeze(1) if kv_cache.dim() == 4: paged_num, head_num, block_size, head_dim = kv_cache.shape assert head_num == 1 kv_cache = kv_cache.view(paged_num, block_size, head_dim) return tmo.fused_compress_single_kv( kv=kv, score=score, position=position, state_ids=state_idx, ape=ape, kv_state=kv_state, score_state=score_state, gamma=gamma, sin=sin, cos=cos, hadamard_matrix=hadamard_matrix if rotate else None, slot_mapping=slot_mapping, kv_cache=kv_cache, kv_cache_scale=kv_cache_scale, eps=eps, overlap=overlap, ) def convertBlockTable(block_table, blks, incseq): if blks == 1: return block_table else: expanded = block_table.unsqueeze(1).repeat(1, blks) result = expanded * blks + incseq return result.flatten() def get_window_block_tables(window_size : int, block_size : int, #blocksize of block_table seq_k_lens: torch.Tensor, query_start_loc: torch.Tensor, block_table: Optional[torch.Tensor]=None, # shape (batch, max_blocks) window_block_tables:Optional[torch.Tensor]=None, # shape (total_seq, max_blocks) window_context_lens:Optional[torch.Tensor]=None): # shape (total_seq) tmo.get_window_block_tables(window_block_tables = window_block_tables, window_context_lens = window_context_lens, seq_k_lens = seq_k_lens, query_start_loc = query_start_loc, block_table = block_table, block_size = block_size, window_size = window_size,) def get_compress_block_tables(ratio: int, block_size: int, seq_k_lens: torch.Tensor, # k lens before compression, shape (batch) query_start_loc: torch.Tensor, # shape (batch+1) offset: torch.Tensor, # shape (batch) block_table: torch.Tensor, # shape (batch, max_blocks) compress_block_tables: torch.Tensor, # shape (total_seq, max_blocks) compress_context_lens: torch.Tensor): # shape (total_seq) tmo.get_compress_block_tables( compress_block_tables = compress_block_tables, compress_context_lens = compress_context_lens, seq_k_lens = seq_k_lens, query_start_loc = query_start_loc, offset = offset, block_table = block_table, block_size = block_size, ratio = ratio, ) def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, pre_scale: Optional[torch.Tensor] = None, hc_mult: int = 4, sinkhorn_iter: int = 20, eps: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return tmo.hc_split_sinkhorn( mixes = mixes, hc_scale = hc_scale, hc_base = hc_base, pre_scale = pre_scale, hc_mult = hc_mult, sinkhorn_iter = sinkhorn_iter, eps = eps, ) def fused_indexer_q(q: torch.Tensor, w_q: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, position_id: torch.Tensor, output: Optional[torch.Tensor] = None, hadamard_matrix: Optional[torch.Tensor] = None, w_q_scale: Optional[torch.Tensor] = None, output_quant_mode: str = 'none', output_scale: Optional[torch.Tensor] = None, interleaved: bool = True, rope_at_front: bool = True): return tmo.fused_indexer_q( q = q, w_q = w_q, sin = sin, cos = cos, position_id = position_id, output = output, hadamard_matrix = hadamard_matrix, w_q_scale = w_q_scale, output_quant_mode = output_quant_mode, output_scale = output_scale, interleaved = interleaved, rope_at_front = rope_at_front) def fused_mla_q_v2( input_q: torch.Tensor, gamma: torch.Tensor, smooth_quant_scale: Optional[torch.Tensor], weight_b: torch.Tensor, weight_b_scale: Optional[torch.Tensor], sin: torch.Tensor, cos: torch.Tensor, position_id: torch.Tensor, output: Optional[torch.Tensor] = None, eps: float = 1e-6, interleaved: bool = True, store_norm: bool = False, output_norm: Optional[torch.Tensor] = None, ) -> 
Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ This function applies MLA (Multi-head Latent Attention) v2 Query (Q) preprocessing. The fusion logic includes: RMSNorm -> Quant(Optional) -> MatMul -> RMSNorm -> RoPE. Math: qr = rmsnorm(input_q, gamma, eps) if quant: qr, q_scale = per_token_quant(norm_out, smooth_quant_scale) q = matmul(qr, q_scale, weight_b, weight_b_scale) q = q.reshape(batch, seq, n_local_heads, head_dim) q = rsqrt(q.square().mean(-1, keepdim=True) + eps) out = apply_rotary_embedding(q, sin, cos, position_id, interleaved) Args: input_q (torch.Tensor): The input latent query tensor. Shape is (batch, seq, q_lora_rank). gamma (torch.Tensor): The scaling parameter for the initial RMSNorm. Shape is (q_lora_rank). smooth_quant_scale (Optional[torch.Tensor]): Scale tensor for SmoothQuant migration. Can be None. Shape is (q_lora_rank). weight_b (torch.Tensor): The Q-projection weight tensor. Shape is (n_local_heads, head_dim, q_lora_rank). weight_b_scale (Optional[torch.Tensor]): The per-channel quantization scales for weight_b. Shape is (n_local_heads, head_dim). sin (torch.Tensor): Rotary embedding sine table. Shape is (max_rotary_seq_len, rotary_head_dim). cos (torch.Tensor): Rotary embedding cosine table. Shape is (max_rotary_seq_len, rotary_head_dim). position_id (torch.Tensor): Indices for the RoPE tables. Shape is (batch,). output (Optional[torch.Tensor]): Optional output tensor for the final processed Q. Shape is (batch, seq, n_local_heads, head_dim). eps (float): Small constant for RMSNorm numerical stability. Default: 1e-6. interleaved (bool): If True, apply interleaved rotary embedding, otherwise folded. Default: True. store_norm (bool): If True, the intermediate RMSNorm result (pre-MatMul) will be returned. Default: False. output_norm (Optional[torch.Tensor]): Optional tensor to store the intermediate RMSNorm result. Shape: (batch, seq, q_lora_rank). Type: input_q, gamma, sin, cos: bfloat16. weight_b: int8, same as input_q. weight_b_scale, smooth_quant_scale: float32. position_id: int32. output: same as input_q. Return: Union[torch.Tensor, Tuple[torch.Tensor, ...]]: - If store_norm=False: output - If store_norm=True: (..., output_norm) is appended to the return. """ return tmo.fused_mla_q_v2( input_q=input_q, gamma=gamma, smooth_quant_scale=smooth_quant_scale, weight_b=weight_b, weight_b_scale=weight_b_scale, sin=sin, cos=cos, position_id=position_id, output=output, eps=eps, interleaved=interleaved, store_norm=store_norm, output_norm=output_norm, ) def update_compressor_states( kv_state, # (max_batch, (overlap+1)*ratio + K, dim) score_state, # (max_batch, (overlap+1)*ratio + K, dim) accept_tokens: torch.Tensor, # (bsz,) batch_to_kv_state: torch.Tensor, # (bsz,) positions: torch.Tensor, # (bsz,) cu_query_len: torch.Tensor, # (bsz+1,) overlap: bool, K: int ): bsz = batch_to_kv_state.numel() ratio = (kv_state.size(1) - K) // (overlap + 1) start_positions = positions[cu_query_len[:bsz]] end_positions = start_positions + accept_tokens for i in range(bsz): start_pos = start_positions[i] end_pos = end_positions[i] # Skip if sequence len does not exceed coff * ratio. if (overlap and end_pos < 2 * ratio) or (not overlap and end_pos < ratio): continue # Skip if compression condition does not meets. 
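        # Compress only when the accepted tokens cross a `ratio`-aligned boundary (or
        # the window starts exactly on one); otherwise there is nothing new to fold.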
        if (start_pos // ratio) == (end_pos // ratio) and start_pos % ratio != 0:
            continue
        state_idx = batch_to_kv_state[i]
        if overlap:
            length = end_pos - start_pos + start_pos % ratio
        else:
            length = end_pos % ratio
        start = ratio
        end = start + length
        if length == 0:
            continue
        kv_state[state_idx, :length] = kv_state[state_idx, start:end].clone()
        score_state[state_idx, :length] = score_state[state_idx, start:end].clone()
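

# ---------------------------------------------------------------------------
# Illustrative only: a self-contained sketch exercising the pure-PyTorch
# merge_attn_states() helper above with made-up shapes. It runs on any device,
# is never called at import time, and is not part of the TMO API surface.
# ---------------------------------------------------------------------------
def _example_merge_attn_states() -> torch.Tensor:
    num_tokens, num_heads, head_size = 4, 2, 8
    prefix_out = torch.randn(num_tokens, num_heads, head_size)
    suffix_out = torch.randn(num_tokens, num_heads, head_size)
    # LSE tensors are laid out as [num_query_heads, num_tokens].
    prefix_lse = torch.randn(num_heads, num_tokens)
    suffix_lse = torch.randn(num_heads, num_tokens)
    merged = torch.empty_like(prefix_out)
    merge_attn_states(merged, prefix_out, prefix_lse, suffix_out, suffix_lse)
    return merged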