from __future__ import annotations
from typing import List, Optional, Tuple, Union
from enum import Enum

import torch
from torch import Generator
from torch._C._distributed_c10d import ReduceOp

from torch_vacc._vacc_libs import _torch_vacc


def rms_norm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, output=None
) -> torch.Tensor:
    r"""Root mean square (RMS) normalization of the input over the last dimension.

    Uses the vacc fused kernel for rms_norm; only a 3-D input and a 1-D weight
    are supported.

    Args:
        input: input tensor with shape (batch, seq_len, hidden_size)
        weight: weight tensor with shape (hidden_size,)
        eps (float): small value to avoid division by zero. Default: 1e-6

    Returns:
        Tensor: tensor after applying rms_norm
    """
    # assert input.dim() == 3, "rms_norm only support the input with dim=3"
    if input.device.type == "vacc":
        return torch.ops.vacc.rms_norm_func(input, weight, eps, output)

    # Pure-PyTorch fallback for non-vacc devices.
    input_dtype = input.dtype
    input = input.to(torch.float32)
    variance = input.pow(2).mean(-1, keepdim=True)
    rsigma = torch.rsqrt(variance + eps)
    normalized_states = input * rsigma
    return weight * normalized_states.to(input_dtype)


rms_norm.apply = rms_norm


def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)


def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # Interleaved rotation: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...).
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


def RotaryPosEmbedding(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: Optional[torch.Tensor] = None,
    sin: Optional[torch.Tensor] = None,
    offset: int = 0,
    mode: str = "neox",
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply rotary positional embedding (RoPE) to the input tensors q/k in
    `sbhd` or s(b*h)d format, where

        s: sequence length
        b: batch size
        h: head num
        d: dim of each head

    Note:
        If cos/sin is None, vacc performs RoPE in the classical mode, generating
        cos/sin internally with base=10000.

    Args:
        q (Tensor): input tensor of shape [s, b*h, d]
        k (Tensor): input tensor of shape [s, b*h, d],
            where s is the sequence length, b the batch size, h the number of
            heads and d the head dim
        cos (optional Tensor): cached cosine of the rotary positional embedding,
            with shape [s, 1, d], [s, d], [s, 1, d/2] or [s, d/2]
        sin (optional Tensor): cached sine of the rotary positional embedding,
            with shape [s, 1, d], [s, d], [s, 1, d/2] or [s, d/2]
        offset (int): offset of the position in the cos/sin cache, default 0
        mode (str): 'neox': rotate half, 'gptj': rotate every two

    Returns:
        Tuple[Tensor, Tensor]: the q/k tensors after applying RoPE
    """
    assert (
        q.dim() == 3 and k.dim() == 3
    ), f"the dim of q/k should be 3 but got q:{q.dim()}, k:{k.dim()}"
    assert mode in ["neox", "gptj"], "only support rope mode 'neox' or 'gptj'"
    assert q.dtype == k.dtype, "the dtypes of q/k should be the same"
    assert (cos is None and sin is None) or (
        isinstance(cos, torch.Tensor) and isinstance(sin, torch.Tensor)
    )
    assert q.device.type == "vacc" and k.device.type == "vacc"
    mode_ = 0 if mode == "gptj" else 1
    if cos is not None and cos.size(0) != q.size(0):
        cos = cos[offset: q.shape[0] + offset, ...]
    if sin is not None and sin.size(0) != q.size(0):
        sin = sin[offset: q.shape[0] + offset, ...]
    # repeat last dim as same size with input tensor
    # if cos is not None and cos.numel() != 0:
    #     assert cos.size(-1) == q.size(-1) or cos.size(-1) * 2 == q.size(-1)
    #     assert cos.dim() == 2 or cos.dim() == 3
    #     if cos.size(-1) * 2 == q.size(-1):
    #         if mode_ == 0:
    #             cos = cos.repeat_interleave(2, dim=-1)
    #             sin = sin.repeat_interleave(2, dim=-1)
    #         else:
    #             cos = torch.cat([cos, cos], dim=-1)
    #             sin = torch.cat([sin, sin], dim=-1)
    if cos is not None and cos.numel() != 0:
        assert cos.size(-1) == q.size(-1) or cos.size(-1) * 2 == q.size(-1)
        assert cos.dim() == 2 or cos.dim() == 3
        if cos.dim() == 2:
            # Insert a broadcast dim so cos/sin become [s, 1, d] (or [s, 1, d/2]).
            cos = cos.unsqueeze(-2)
            sin = sin.unsqueeze(-2)
    return torch.ops.vacc.RotaryPosEmbedding_func(q, k, cos, sin, offset, mode_)


RotaryPosEmbedding.apply = RotaryPosEmbedding
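

# A minimal, hedged usage sketch for rms_norm (illustrative only, not part of the
# public API): on a non-"vacc" device the pure-PyTorch fallback path above is taken,
# so this runs on CPU. Shapes follow the docstring: a (batch, seq_len, hidden_size)
# input and a (hidden_size,) weight.
def _example_rms_norm() -> torch.Tensor:
    x = torch.randn(2, 16, 64, dtype=torch.float16)  # (batch, seq_len, hidden_size)
    w = torch.ones(64, dtype=torch.float16)          # (hidden_size,)
    return rms_norm(x, w, eps=1e-6)                  # same shape and dtype as x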


def scaled_dot_product_attention(
    # same as the torch definition
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor = None,
    dropout_p: float = 0.5,
    is_causal: bool = False,
    scale: float = None,
    # extensions
    is_train: bool = True,
    recompute: bool = False,
    flash_attention: bool = False,
    sm_scale: float = -1,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    r"""Apply the attention operation to a q tensor of shape [sq, b*h, d] and
    k/v tensors of shape [sk, b*h, d]; the output has shape [sq, b*h, d].
    Supports float16, bfloat16 and float32, where

        sq: sequence length of query
        sk: sequence length of key, value
        b: batch size
        h: head num
        d: dim of each head

    Args:
        query (Tensor): input tensor of shape [sq, b*h, d]
        key (Tensor): input tensor of shape [sk, b*h, d]
        value (Tensor): input tensor of shape [sk, b*h, d]
        attn_mask (Tensor): boolean mask tensor of shape [1, sq, sk] or [sq, sk]
        dropout_p (float): the dropout probability
        is_causal (bool): accelerates the computation when the mask is causal
        scale (float): not used; the default is 1/sqrt(dim)
        is_train (bool): train mode or eval mode
        recompute (bool): whether to recompute to reduce memory usage; only
            valid when is_train=True
        flash_attention (bool): use flash attention, which supports long sequences

    Returns:
        Tensor: the tensor after self attention
    """
    assert (
        query.dtype == key.dtype and key.dtype == value.dtype
    ), "the dtypes of q/k/v should be the same"
    assert query.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    if attn_mask is None:
        attn_mask = torch.Tensor()
    out = torch.ops.vacc.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_train,
        recompute,
        is_causal,
        flash_attention,
        sm_scale,
    )
    return out[0]


scaled_dot_product_attention.apply = scaled_dot_product_attention


def swiglu(
    x: torch.Tensor,
) -> torch.Tensor:
    r"""Performs:

        x = torch.chunk(x, 2, dim=-1)
        return F.silu(x[0]) * x[1]

    Args:
        x: input tensor with 3 dims; supports float/float16/bfloat16

    Returns:
        Tensor: the output of swiglu
    """
    return torch.ops.vacc.swiglu(x)


swiglu.apply = swiglu
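

# A hedged CPU reference for swiglu above, mirroring the computation its docstring
# spells out (chunk the last dim in two, then silu(x0) * x1). This is a reference
# sketch only; torch.ops.vacc.swiglu is the fused kernel actually used on device.
def _example_swiglu_reference(x: torch.Tensor) -> torch.Tensor:
    x0, x1 = torch.chunk(x, 2, dim=-1)
    return torch.nn.functional.silu(x0) * x1


# A hedged CPU sketch of the [s, b*h, d] layout documented for
# scaled_dot_product_attention above: torch's built-in SDPA expects the batch-like
# dim first, so the vacc layout maps onto it with a single transpose. The shapes
# below are illustrative assumptions.
def _example_sdpa_layout_reference() -> torch.Tensor:
    sq, sk, bh, d = 8, 8, 4, 32
    q = torch.randn(sq, bh, d)  # [sq, b*h, d]
    k = torch.randn(sk, bh, d)  # [sk, b*h, d]
    v = torch.randn(sk, bh, d)  # [sk, b*h, d]
    out = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), is_causal=True
    )  # [b*h, sq, d]
    return out.transpose(0, 1)  # back to [sq, b*h, d]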


def scaled_dot_product_attention_cp_forward(
    # same as the torch definition
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor = None,
    dropout_p: float = 0.5,
    is_causal: bool = False,
    # extensions
    is_train: bool = True,
):
    r"""Apply the attention operation to a q tensor of shape [sq, b*h, d] and
    k/v tensors of shape [sk, b*h, d]; the output has shape [sq, b*h, d], where

        sq: sequence length of query
        sk: sequence length of key, value
        b: batch size
        h: head num
        d: dim of each head

    Args:
        query (Tensor): input fp16/bf16 tensor of shape [sq, b*h, d]
        key (Tensor): input fp16/bf16 tensor of shape [sk, b*h, d]
        value (Tensor): input fp16/bf16 tensor of shape [sk, b*h, d]
        attn_mask (Tensor): boolean mask tensor of shape [1, sq, sk]
        dropout_p (float): the dropout probability
        is_causal (bool): accelerates the computation when the mask is causal
        is_train (bool): train mode or eval mode

    Returns:
        List[Tensor] of size 4: [attention result,
        row max of QK^T with shape (b*h, sq, d),
        row sum of exp(QK^T) with shape (b*h, sq, d),
        seed (used in backward)]
    """
    assert (
        query.dtype == key.dtype and key.dtype == value.dtype
    ), "the dtypes of q/k/v should be the same"
    assert query.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    assert (
        is_causal or attn_mask is not None
    ), "attn_mask should be valid when is_causal is False"
    if attn_mask is None:
        attn_mask = torch.Tensor()
    else:
        assert attn_mask.size(-2) == query.size(0) and attn_mask.size(-1) == key.size(
            0
        ), "attn_mask size should be (..., sq, sk)"
    out = _torch_vacc.scaled_dot_product_attention_cp_forward(
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_train,
        is_causal,
    )
    return out


def scaled_dot_product_attention_cp_backward(
    grad_output: torch.Tensor,
    attn_out: torch.Tensor,
    max_of_row: torch.Tensor,
    sum_of_row: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    seed: torch.Tensor,
    dropout_p: float = 0.5,
    is_causal: bool = False,
    is_train: bool = True,
):
    assert (
        query.dtype == key.dtype and key.dtype == value.dtype
    ), "the dtypes of q/k/v should be the same"
    assert query.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    assert (
        is_causal or attn_mask is not None
    ), "attn_mask should be valid when is_causal is False"
    if attn_mask is None:
        attn_mask = torch.Tensor()
    else:
        assert attn_mask.size(-2) == query.size(0) and attn_mask.size(-1) == key.size(
            0
        ), "attn_mask size should be (..., sq, sk)"
    out = _torch_vacc.scaled_dot_product_attention_cp_backward(
        grad_output,
        attn_out,
        max_of_row,
        sum_of_row,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_train,
        is_causal,
        seed,
    )
    return out
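

# A hedged sketch (not part of this module's API) of how the *_cp_forward /
# *_cp_backward pair could be tied together with torch.autograd.Function. It
# assumes the forward returns [attn_out, max_of_row, sum_of_row, seed] in that
# order (as documented above) and that the backward returns gradients for
# query/key/value first; both assumptions should be checked against the vacc
# kernels before use.
class _ExampleSDPACheckpointed(torch.autograd.Function):
    @staticmethod
    def forward(ctx, query, key, value, attn_mask, dropout_p, is_causal):
        attn_out, max_of_row, sum_of_row, seed = scaled_dot_product_attention_cp_forward(
            query, key, value, attn_mask, dropout_p, is_causal, is_train=True
        )
        ctx.save_for_backward(attn_out, max_of_row, sum_of_row, query, key, value, seed)
        ctx.attn_mask = attn_mask
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        return attn_out

    @staticmethod
    def backward(ctx, grad_output):
        attn_out, max_of_row, sum_of_row, query, key, value, seed = ctx.saved_tensors
        grads = scaled_dot_product_attention_cp_backward(
            grad_output, attn_out, max_of_row, sum_of_row, query, key, value,
            ctx.attn_mask, seed, ctx.dropout_p, ctx.is_causal, is_train=True
        )
        # Assumed order: (grad_query, grad_key, grad_value); the remaining forward
        # inputs (attn_mask, dropout_p, is_causal) receive no gradient.
        return grads[0], grads[1], grads[2], None, None, None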


def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_len: torch.Tensor,
    sm_scale: float = -1,
    out: Optional[torch.Tensor] = None,
):
    r"""Apply the paged attention operation to a q tensor of shape [batch, h, d]
    and key_cache/value_cache tensors of shape [nb, bs, h, d]; the output has
    shape [batch, h, d], where

        batch: token batch
        nb: num_blocks
        bs: block_size
        h: heads num
        d: dim of each head

    Args:
        query (Tensor): input fp16/bf16 tensor of shape [batch, h, d]
        key_cache (Tensor): input fp16/bf16 tensor of shape [nb, bs, h, d]
        value_cache (Tensor): input fp16/bf16 tensor of shape [nb, bs, h, d]
        block_table (Tensor): block mapping of the k/v cache
        seq_len (Tensor): sequence length of each batch
        out (optional Tensor): output tensor; if not given, a new tensor is returned

    Returns:
        Tensor: the tensor after paged attention
    """
    return _torch_vacc.paged_attention(
        query, key_cache, value_cache, block_table, seq_len, out, sm_scale
    )


def reshape_and_cache_attention(
    src: torch.Tensor,
    cached: torch.Tensor,
    block_mapping: torch.Tensor,
):
    torch.ops.vacc.reshape_and_cache_attention(src, cached, block_mapping)


def concat_and_cache_attention(
    src: torch.Tensor,
    src1: torch.Tensor,
    cached: torch.Tensor,
    block_mapping: torch.Tensor,
):
    _torch_vacc.concat_and_cache_attention(src, src1, cached, block_mapping)


def w8a8_block_fp8_matmul(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    weight_scale: Optional[torch.Tensor],
    block_size: List[int],
    output: Optional[torch.Tensor] = None,
    **kwargs,
):
    return _torch_vacc.w8a8_block_fp8_matmul(
        input,
        weight,
        None,
        weight_scale,
        block_size,
        output,
    )


def w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    weight_scale: Optional[torch.Tensor],
    block_size: List[int],
    output: Optional[torch.Tensor] = None,
    **kwargs,
):
    # Transpose the linear-layout weight/scale and swap the block sizes so they
    # match the layout expected by the matmul kernel.
    return _torch_vacc.w8a8_block_fp8_matmul(
        input, weight.T, None, weight_scale.T, [block_size[1], block_size[0]], output
    )


def moe_expert_token_group_reassign(
    topk_idx: torch.Tensor,
    topk_val: torch.Tensor,
    expert_num_: int,
    gp_size_: int = 16,
    gp_num_align_: int = 4,
):
    return _torch_vacc.moe_expert_token_group_reassign(
        topk_idx, topk_val, expert_num_, gp_size_, gp_num_align_
    )


def fused_experts(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = True,
    w13_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a13_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    decode_with_batch: bool = False,
    output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    warning_message = "[fused_experts]: vacc only supports fp8 weights now"
    assert a13_scale is None and a2_scale is None, warning_message
    assert use_fp8_w8a8, f"{warning_message}, but use_fp8_w8a8 is {use_fp8_w8a8}"
    assert (
        w13_scale is not None and w2_scale is not None
    ), f"{warning_message}, but w13_weight_scale is {w13_scale}, w2_weight_scale is {w2_scale}"
    assert (
        block_shape is not None
    ), f"{warning_message}, but block_shape is {block_shape}"
    # assert (
    #     not decode_with_batch or hidden_states.size(0) <= 4
    # ), "[fused_experts]:vacc only support batch <= 4 when decode"
    if hidden_states.device.type == "vacc":
        # topk_weights dtype should match hidden_states
        topk_weights = topk_weights.to(hidden_states.dtype)
        # the vacc device uses int32 expert ids
        topk_ids = topk_ids.to(torch.int32)
        # Derive the quantization block sizes from the weight/scale shapes.
        hidden_dims, inter_dims = w13_weight.shape[1], w13_weight.shape[2]
        hidden_blocks, inter_blocks = w13_scale.shape[1], w13_scale.shape[2]
        block_size0, block_size1 = (
            hidden_dims // hidden_blocks,
            inter_dims // inter_blocks,
        )
        # assert (
        #     block_size0 == block_size1
        # ), "quant block shape now support size0 == size1"
        return _torch_vacc.fused_experts(
hidden_states, w13_weight, w2_weight, topk_weights, topk_ids, use_fp8_w8a8, w13_scale, w2_scale, a13_scale, a2_scale, [block_size0, block_size1], decode_with_batch, output_opt, ) from .custom_ops_cpu import fused_experts as fused_experts_default_method return fused_experts_default_method( hidden_states, w13_weight, w2_weight, topk_weights, topk_ids, use_fp8_w8a8, w13_scale, w2_scale, a13_scale, a2_scale, block_shape, decode_with_batch, ) # NOTE: w13, w2 using linear format def fused_mlp_fp8( hidden_states: torch.Tensor, w13_weight: torch.Tensor, w2_weight: torch.Tensor, use_fp8_w8a8: bool = True, w13_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a13_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape_w13: Optional[List[int]] = None, block_shape_w2: Optional[List[int]] = None, output: Optional[torch.Tensor] = None, ): assert a13_scale is None assert a2_scale is None assert w13_scale is not None assert w2_scale is not None assert block_shape_w13 is not None assert block_shape_w2 is not None if hidden_states.device.type == "vacc": return _torch_vacc.fused_mlp( hidden_states, w13_weight, w2_weight, use_fp8_w8a8, w13_scale, w2_scale, a13_scale, a2_scale, block_shape_w13, block_shape_w2, output, ) from .custom_ops_cpu import fused_mlp_mm_fp8 as fused_mlp_mm_fp8_default_method return fused_mlp_mm_fp8_default_method( hidden_states, w13_weight.T, w2_weight.T, use_fp8_w8a8, w13_scale.T, w2_scale.T, a13_scale, a2_scale, list(block_shape_w13)[::-1], list(block_shape_w2)[::-1], ) def fused_mlp_mm_fp8( hidden_states: torch.Tensor, w13_weight: torch.Tensor, w2_weight: torch.Tensor, use_fp8_w8a8: bool = True, w13_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a13_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape_w13: Optional[List[int]] = None, block_shape_w2: Optional[List[int]] = None, output: Optional[torch.Tensor] = None, ): return fused_mlp_fp8( hidden_states, w13_weight.T, w2_weight.T, use_fp8_w8a8, w13_scale.T, w2_scale.T, a13_scale, a2_scale, list(block_shape_w13)[::-1], list(block_shape_w2)[::-1], output, ) def fused_moe_preprocess( gating_output, bias, num_expert_group=8, num_limited_group=4, ): return _torch_vacc.fused_moe_preprocess( gating_output, bias, num_expert_group, num_limited_group ) def fused_residual_rmsnorm( input: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, epsilon: float = 1e-6, output: Optional[torch.Tensor] = None, residual_out: Optional[torch.Tensor] = None, ): # out = _torch_vacc.fused_residual_rmsnorm(input, weight, residual, epsilon, inplace) # if len(out) == 1: # return out[0] # return out # TODO: VNNL support optional residual input = input.contiguous() weight = weight.contiguous() if residual is None: return rms_norm(input, weight, epsilon) else: return torch.ops.vacc.fused_residual_rmsnorm( input, weight, residual, epsilon, output, residual_out, ) def parallel_embedding( input: torch.Tensor, weight: torch.Tensor, org_vocab_start_index: int, org_vocab_end_index: int, num_org_vocab_padding: int, added_vocab_start_index: int, added_vocab_end_index: int, output: Optional[torch.Tensor] = None, ): return _torch_vacc.parallel_embedding( input, weight, org_vocab_start_index, org_vocab_end_index, num_org_vocab_padding, added_vocab_start_index, added_vocab_end_index, output, ) def fused_mla( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], hidden_states_norm_weight: torch.Tensor, 
q_a_proj_weight: torch.Tensor, q_a_proj_weight_scale_inv: torch.Tensor, q_a_layernorm_weight: torch.Tensor, W_Q: torch.Tensor, W_UK: torch.Tensor, W_QR: torch.Tensor, kv_a_proj_weight_scale_inv: torch.Tensor, kv_a_proj_weight: torch.Tensor, kv_a_layernorm_weight: torch.Tensor, sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: torch.Tensor, block_tables: torch.Tensor, W_UV: torch.Tensor, o_proj_weight_scale_inv: torch.Tensor, o_proj_weight: torch.Tensor, q_a_proj_blocksize: Tuple[int] | List[int], kv_a_proj_blocksize: Tuple[int] | List[int], o_proj_blocksize: Tuple[int] | List[int], seq_lens: Tuple[int] | List[int], sm_scale: float, head_num: int, ): # TODO: CHECK out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fused_mla( hidden_states, residual, hidden_states_norm_weight, q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_layernorm_weight, W_Q, W_UK, W_QR, kv_a_proj_weight_scale_inv, kv_a_proj_weight, kv_a_layernorm_weight, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, W_UV, o_proj_weight_scale_inv, o_proj_weight, q_a_proj_blocksize, kv_a_proj_blocksize, o_proj_blocksize, seq_lens, sm_scale, head_num, ) if out_single: return out[0] return out def fused_mla_v2( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], hidden_states_norm_weight: torch.Tensor, q_a_proj_weight: torch.Tensor, q_a_proj_weight_scale_inv: torch.Tensor, q_a_layernorm_weight: torch.Tensor, w_q: torch.Tensor, w_q_scale: torch.Tensor, w_uk: torch.Tensor, w_uk_scale: torch.Tensor, w_qr: torch.Tensor, w_qr_scale: torch.Tensor, kv_a_layernorm_weight: torch.Tensor, sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: torch.Tensor, block_tables: torch.Tensor, block_group_size: int, w_uv: torch.Tensor, w_uv_scale: torch.Tensor, o_proj_weight: torch.Tensor, o_proj_weight_scale_inv: torch.Tensor, seq_lens: Tuple[int] | List[int], sm_scale: float, head_num: int, flash_attention: bool=False, ): # TODO: CHECk out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fused_mla_v2( hidden_states, residual, hidden_states_norm_weight, q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_layernorm_weight, w_q, w_q_scale, w_uk, w_uk_scale, w_qr, w_qr_scale, kv_a_layernorm_weight, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, w_uv, w_uv_scale, o_proj_weight, o_proj_weight_scale_inv, seq_lens, sm_scale, head_num, flash_attention, ) if out_single: return out[0] return out def fused_mla_prefill_stage0( hidden_states: torch.Tensor, residual: torch.Tensor, hidden_states_norm_weight: torch.Tensor, qkv_a_proj_weight: torch.Tensor, qkv_a_proj_weight_scale_inv: torch.Tensor, ): # TODO: CHECk out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fused_mla_prefill_stage0( hidden_states, residual, hidden_states_norm_weight, qkv_a_proj_weight, qkv_a_proj_weight_scale_inv, ) if out_single: return out[0] return out def fused_mla_prefill_stage1( qkv_a: torch.Tensor, q_a_layernorm_weight: torch.Tensor, q_proj_weight: torch.Tensor, q_proj_weight_scale_inv: torch.Tensor, kv_a_layernorm_weight: torch.Tensor, kv_b_proj_weight: torch.Tensor, kv_b_proj_weight_scale_inv: torch.Tensor, sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: torch.Tensor, o_proj_weight: torch.Tensor, o_proj_weight_scale_inv: torch.Tensor, seq_lens_num: 
List[int], sm_scale: float, num_head: int, mla_out_tensor: Optional[torch.Tensor] = None ): out = _torch_vacc.fused_mla_prefill_stage1( qkv_a, q_a_layernorm_weight, q_proj_weight, q_proj_weight_scale_inv, kv_a_layernorm_weight, kv_b_proj_weight, kv_b_proj_weight_scale_inv, sin_cache, cos_cache, slot_mapping, kv_cache, o_proj_weight, o_proj_weight_scale_inv, seq_lens_num, sm_scale, num_head, mla_out_tensor ) return out def fused_mla_allreduce( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], hidden_states_norm_weight: torch.Tensor, q_a_proj_weight: torch.Tensor, q_a_proj_weight_scale_inv: torch.Tensor, q_a_layernorm_weight: torch.Tensor, W_Q: torch.Tensor, W_UK: torch.Tensor, W_QR: torch.Tensor, kv_a_proj_weight_scale_inv: torch.Tensor, kv_a_proj_weight: torch.Tensor, kv_a_layernorm_weight: torch.Tensor, sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: torch.Tensor, block_tables: torch.Tensor, W_UV: torch.Tensor, o_proj_weight_scale_inv: torch.Tensor, o_proj_weight: torch.Tensor, q_a_proj_blocksize: Tuple[int] | List[int], kv_a_proj_blocksize: Tuple[int] | List[int], o_proj_blocksize: Tuple[int] | List[int], seq_lens: Tuple[int] | List[int], sm_scale: float, head_num: int, red_op_type: int, world_size: int, rank: int, root_rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): # TODO: CHECK assert red_op_type == 0, "all_reduce only support red_op_type=0" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fused_mla_allreduce( hidden_states, residual, hidden_states_norm_weight, q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_layernorm_weight, W_Q, W_UK, W_QR, kv_a_proj_weight_scale_inv, kv_a_proj_weight, kv_a_layernorm_weight, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, W_UV, o_proj_weight_scale_inv, o_proj_weight, q_a_proj_blocksize, kv_a_proj_blocksize, o_proj_blocksize, seq_lens, sm_scale, head_num, red_op_type, world_size, rank, root_rank, group_id, dev_info, ) if out_single: return out[0] return out def fused_mla_allreduce_v2( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], hidden_states_norm_weight: torch.Tensor, q_a_proj_weight: torch.Tensor, q_a_proj_weight_scale_inv: torch.Tensor, q_a_layernorm_weight: torch.Tensor, w_q: torch.Tensor, w_q_scale: torch.Tensor, w_uk: torch.Tensor, w_uk_scale: torch.Tensor, w_qr: torch.Tensor, w_qr_scale: torch.Tensor, kv_a_layernorm_weight: torch.Tensor, sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: torch.Tensor, block_tables: torch.Tensor, block_group_size: int, w_uv: torch.Tensor, w_uv_scale: torch.Tensor, o_proj_weight: torch.Tensor, o_proj_weight_scale_inv: torch.Tensor, seq_lens: Tuple[int] | List[int], sm_scale: float, head_num: int, flash_attention: bool, world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): # TODO: CHECK assert world_size > 0, "fused_mla_allreduce_v2 only support world_size > 0" assert rank >= 0, "fused_mla_allreduce_v2 only support rank >= 0" if not dev_info: dev_info = [i | (i << 16) for i in range(world_size)] out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fused_mla_allreduce_v2( hidden_states, residual, hidden_states_norm_weight, q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_layernorm_weight, w_q, w_q_scale, w_uk, w_uk_scale, w_qr, w_qr_scale, 
kv_a_layernorm_weight, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, w_uv, w_uv_scale, o_proj_weight, o_proj_weight_scale_inv, seq_lens, sm_scale, head_num, flash_attention, world_size, rank, group_id, dev_info, ) if out_single: return out[0] return out def fused_mla_prefill_stage0_allreduce( hidden_states: torch.Tensor, residual: torch.Tensor, hidden_states_norm_weight: torch.Tensor, qkv_a_proj_weight: torch.Tensor, qkv_a_proj_weight_scale_inv: torch.Tensor, world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): # TODO: CHECk out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fused_mla_prefill_stage0_allreduce( hidden_states, residual, hidden_states_norm_weight, qkv_a_proj_weight, qkv_a_proj_weight_scale_inv, world_size, rank, group_id, dev_info, ) if out_single: return out[0] return out def all_reduce( input: torch.Tensor, rank: int, world_size: int, group_id: int, dev_info: List[int], red_op_type: int = 0, ): assert input.device.type == "vacc", "all_reduce only support VACC" assert red_op_type == 0, "all_reduce only support red_op_type=0" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.all_reduce( input, rank, world_size, 0, group_id, dev_info, red_op_type ) def all_gather( input: torch.Tensor, rank: int, world_size: int, group_id: int, dev_info: List[int], ): assert input.device.type == "vacc", "all_gather only support VACC" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.all_gather(input, rank, world_size, 0, group_id, dev_info, 0) def broadcast( input: torch.Tensor, rank: int, world_size: int, root_rank: int, group_id: int, dev_info: List[int], ): assert input.device.type == "vacc", "broadcast only support VACC" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.all_gather( input, rank, world_size, root_rank, group_id, dev_info, 1 ) def fused_mlp_moe_with_rmsnorm( hidden_states: torch.Tensor, rms_residual: torch.Tensor, rms_weight: torch.Tensor, mlp_weight_13: torch.Tensor, mlp_weight_2: torch.Tensor, mlp_weight_scale_13: torch.Tensor, mlp_weight_scale_2: torch.Tensor, moe_weight_13: torch.Tensor, moe_weight_2: torch.Tensor, moe_weight_scale_13: torch.Tensor, moe_weight_scale_2: torch.Tensor, mm_weight: torch.Tensor, moe_bias: torch.Tensor, mlp_block_size_w13: List[int] | Tuple[int], mlp_block_size_w2: List[int] | Tuple[int], moe_block_size_w13: List[int] | Tuple[int], moe_block_size_w2: List[int] | Tuple[int], ): return _torch_vacc.fused_mlp_moe_with_rmsnorm( hidden_states, rms_residual, rms_weight, mlp_weight_13, mlp_weight_2, mlp_weight_scale_13, mlp_weight_scale_2, moe_weight_13, moe_weight_2, moe_weight_scale_13, moe_weight_scale_2, mm_weight, moe_bias, mlp_block_size_w13, mlp_block_size_w2, moe_block_size_w13, moe_block_size_w2, ) def fused_mlp_with_rmsnorm( hidden_states: torch.Tensor, rms_residual: torch.Tensor, rms_weight: torch.Tensor, mlp_weight_13: torch.Tensor, mlp_weight_2: torch.Tensor, mlp_weight_scale_13: torch.Tensor, mlp_weight_scale_2: torch.Tensor, mlp_block_size_w13: List[int] | Tuple[int], mlp_block_size_w2: List[int] | Tuple[int], ): return _torch_vacc.fused_mlp_with_rmsnorm( hidden_states, rms_residual, rms_weight, mlp_weight_13, mlp_weight_2, mlp_weight_scale_13, mlp_weight_scale_2, mlp_block_size_w13, mlp_block_size_w2, ) def fuse_moe_decode_v2_allreduce( hidden_states: torch.Tensor, rms_residual: torch.Tensor, 
rms_weight: torch.Tensor, mlp_weight_13: torch.Tensor, mlp_weight_2: torch.Tensor, mlp_weight_scale_13: torch.Tensor, mlp_weight_scale_2: torch.Tensor, moe_weight_13: torch.Tensor, moe_weight_2: torch.Tensor, moe_weight_scale_13: torch.Tensor, moe_weight_scale_2: torch.Tensor, mm_weight: torch.Tensor, moe_bias: torch.Tensor, mlp_block_size_w13: List[int] | Tuple[int], mlp_block_size_w2: List[int] | Tuple[int], moe_block_size_w13: List[int] | Tuple[int], moe_block_size_w2: List[int] | Tuple[int], red_op_type: int, world_size: int, rank: int, root_rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): assert ( hidden_states.device.type == "vacc" ), "fuse_moe_decode_v2_allreduce only support VACC" assert red_op_type == 0, "fuse_moe_decode_v2_allreduce only support red_op_type=0" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.fuse_moe_decode_v2_allreduce( hidden_states, rms_residual, rms_weight, mlp_weight_13, mlp_weight_2, mlp_weight_scale_13, mlp_weight_scale_2, moe_weight_13, moe_weight_2, moe_weight_scale_13, moe_weight_scale_2, mm_weight, moe_bias, mlp_block_size_w13, mlp_block_size_w2, moe_block_size_w13, moe_block_size_w2, red_op_type, world_size, rank, root_rank, group_id, dev_info, ) def topk_topp( logits: torch.Tensor, p: torch.Tensor, k: torch.Tensor, ): return _torch_vacc.topk_topp(logits, p, k) def fused_mlp_allreduce( hidden_states: torch.Tensor, rms_residual: torch.Tensor, rms_weight: torch.Tensor, mlp_weight_13: torch.Tensor, mlp_weight_2: torch.Tensor, mlp_weight_scale_13: torch.Tensor, mlp_weight_scale_2: torch.Tensor, mlp_block_size_w13: List[int] | Tuple[int], mlp_block_size_w2: List[int] | Tuple[int], red_op_type: int, world_size: int, rank: int, root_rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): assert hidden_states.device.type == "vacc", "fused_mlp_allreduce only support VACC" assert red_op_type == 0, "fused_mlp_allreduce only support red_op_type=0" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.fused_mlp_allreduce( hidden_states, rms_residual, rms_weight, mlp_weight_13, mlp_weight_2, mlp_weight_scale_13, mlp_weight_scale_2, mlp_block_size_w13, mlp_block_size_w2, red_op_type, world_size, rank, root_rank, group_id, dev_info, ) def mla_matmul_scale( input: torch.Tensor, weight: torch.Tensor, scale: float, align_seq_len: int = 1024 ): return _torch_vacc.mla_matmul_scale( input, weight, scale, align_seq_len, ) def mla_matmul( input: torch.Tensor, weight: torch.Tensor, ): return _torch_vacc.mla_matmul( input, weight, ) def ds3_sampler( src, p, k, temperatures, exponential_enable, generator: Optional[Generator] = None )-> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: return _torch_vacc.ds3_sampler( src, p, k, temperatures, exponential_enable, generator ) def sampler_v1( src, p, k, temperatures, all_greedy, all_random, generators: dict[int, Optional[torch.Generator]] = {} )->Tuple[torch.Tensor, torch.Tensor]: if not isinstance(generators, dict): raise TypeError(f"generator must be a dictionary, got {type(generators)}") return _torch_vacc.sampler_v1( src, p, k, temperatures, all_greedy, all_random, generators ) def apply_penalties( src_logits, src_tokens, buf_bin_counts, vocab_size, num_tokens, frequency_penalties, presence_penalties, is_first_calculation )->Tuple[torch.Tensor]: return _torch_vacc.apply_penalties(src_logits, src_tokens, buf_bin_counts, vocab_size, num_tokens, frequency_penalties, presence_penalties, 
is_first_calculation) def rejection_sampler( target_with_bonus_probs, bonus_token_ids, draft_probs, draft_token_ids, gen_seed, generator: Optional[Generator] = None ): return _torch_vacc.rejection_sampler( target_with_bonus_probs, bonus_token_ids, draft_probs, draft_token_ids, gen_seed, generator ) def rejection_sampler_update_hidden_states( hidden_states, accepted_index ): return _torch_vacc.rejection_sampler_update_hidden_states( hidden_states, accepted_index ) def rejection_sampler_v1( target_logits, draft_token_ids, bonus_token_ids, temperature, top_p, top_k, all_greedy, all_random, generators: dict[int, Optional[torch.Generator]] ): return _torch_vacc.rejection_sampler_v1( target_logits, draft_token_ids, bonus_token_ids, temperature, top_p, top_k, all_greedy, all_random, generators ) def fused_matmul_allgather( input: torch.Tensor, mat2: torch.Tensor, world_size: int, rank: int, group_id: int, dev_info: List[int] = None, ) -> torch.Tensor: assert input.device.type == "vacc", "fused_matmul_allgather only support VACC" assert mat2.device.type == "vacc", "fused_matmul_allgather only support VACC" assert 2 == input.ndim, "fused_matmul_allgather: 'input' must be 2D tensor" assert 2 == mat2.ndim, "fused_matmul_allgather: 'mat2' must be 2D tensor" assert ( input.shape[-1] == mat2.shape[0] ), "fused_matmul_allgather: dim1 of 'input' must be equal to dim0 of 'mat2'" assert world_size > 0, "world_size must be greater than 0" if dev_info is None or 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.fused_matmul_allgather( input, mat2, world_size, rank, group_id, dev_info, ) def fuse_moe_prefill_stage0( hidden_states, rms_residual, rms_weight, mlp_weight_13, mlp_weight_2, mlp_weight_scale_13, mlp_weight_scale_2, mm_weight, moe_bias, mlp_block_size_w13, mlp_block_size_w2, rms_hidden_state_opt: Optional[torch.Tensor] = None, mlp_hidden_state_opt: Optional[torch.Tensor] = None, topk_ids_opt: Optional[torch.Tensor] = None, topk_weight_opt: Optional[torch.Tensor] = None, ): return _torch_vacc.fuse_moe_prefill_stage0( hidden_states, rms_residual, rms_weight, mlp_weight_13, mlp_weight_2, mlp_weight_scale_13, mlp_weight_scale_2, mm_weight, moe_bias, mlp_block_size_w13, mlp_block_size_w2, rms_hidden_state_opt, mlp_hidden_state_opt, topk_ids_opt, topk_weight_opt, ) def fuse_mla_mlp_v2_allreduce_decode( # input hidden_states: torch.Tensor, residual: Optional[torch.Tensor], # mla weight hidden_states_norm_weight: torch.Tensor, q_a_proj_weight: torch.Tensor, q_a_proj_weight_scale_inv: torch.Tensor, q_a_layernorm_weight: torch.Tensor, w_q: torch.Tensor, w_q_scale: torch.Tensor, w_uk: torch.Tensor, w_uk_scale: torch.Tensor, w_qr: torch.Tensor, w_qr_scale: torch.Tensor, kv_a_layernorm_weight: torch.Tensor, sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: torch.Tensor, block_tables: torch.Tensor, block_group_size: int, w_uv: torch.Tensor, w_uv_scale: torch.Tensor, o_proj_weight: torch.Tensor, o_proj_weight_scale_inv: torch.Tensor, # mla params seq_lens: Tuple[int] | List[int], sm_scale: float, head_num: int, flash_attention: bool, # mlp weight rms_weight: torch.Tensor, mlp_weight_13: torch.Tensor, mlp_weight_2: torch.Tensor, mlp_weight_scale_13: torch.Tensor, mlp_weight_scale_2: torch.Tensor, # mlp params mlp_block_size_w13: List[int] | Tuple[int], mlp_block_size_w2: List[int] | Tuple[int], # vccl info world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): # TODO: CHECK assert ( 
hidden_states.device.type == "vacc" ), "fuse_mla_mlp_v2_allreduce_decode only support VACC" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fuse_mla_mlp_v2_allreduce_decode( hidden_states, residual, hidden_states_norm_weight, q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_layernorm_weight, w_q, w_q_scale, w_uk, w_uk_scale, w_qr, w_qr_scale, kv_a_layernorm_weight, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, w_uv, w_uv_scale, o_proj_weight, o_proj_weight_scale_inv, seq_lens, sm_scale, head_num, flash_attention, rms_weight, mlp_weight_13, mlp_weight_2, mlp_weight_scale_13, mlp_weight_scale_2, mlp_block_size_w13, mlp_block_size_w2, world_size, rank, group_id, dev_info, ) # if out_single: # return out[0] return out def fuse_mla_moe_v2_allreduce_decode( # input hidden_states: torch.Tensor, residual: Optional[torch.Tensor], # mla weight hidden_states_norm_weight: torch.Tensor, q_a_proj_weight: torch.Tensor, q_a_proj_weight_scale_inv: torch.Tensor, q_a_layernorm_weight: torch.Tensor, w_q: torch.Tensor, w_q_scale: torch.Tensor, w_uk: torch.Tensor, w_uk_scale: torch.Tensor, w_qr: torch.Tensor, w_qr_scale: torch.Tensor, kv_a_layernorm_weight: torch.Tensor, sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: torch.Tensor, block_tables: torch.Tensor, block_group_size: int, w_uv: torch.Tensor, w_uv_scale: torch.Tensor, o_proj_weight: torch.Tensor, o_proj_weight_scale_inv: torch.Tensor, # mla params seq_lens: Tuple[int] | List[int], sm_scale: float, head_num: int, flash_attention: bool, # moe weight rms_weight: torch.Tensor, mlp_weight_13: torch.Tensor, mlp_weight_2: torch.Tensor, mlp_weight_scale_13: torch.Tensor, mlp_weight_scale_2: torch.Tensor, moe_weight_13: torch.Tensor, moe_weight_2: torch.Tensor, moe_weight_scale_13: torch.Tensor, moe_weight_scale_2: torch.Tensor, mm_weight: torch.Tensor, moe_bias: torch.Tensor, # moe params mlp_block_size_w13: Tuple[int] | List[int], mlp_block_size_w2: Tuple[int] | List[int], moe_block_size_w13: Tuple[int] | List[int], moe_block_size_w2: Tuple[int] | List[int], # vccl info world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): # TODO: CHECK assert ( hidden_states.device.type == "vacc" ), "fuse_mla_moe_v2_allreduce_decode only support VACC" # assert red_op_type == 0, "all_reduce only support red_op_type=0" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fuse_mla_moe_v2_allreduce_decode( hidden_states, residual, hidden_states_norm_weight, q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_layernorm_weight, w_q, w_q_scale, w_uk, w_uk_scale, w_qr, w_qr_scale, kv_a_layernorm_weight, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, w_uv, w_uv_scale, o_proj_weight, o_proj_weight_scale_inv, seq_lens, sm_scale, head_num, flash_attention, rms_weight, mlp_weight_13, mlp_weight_2, mlp_weight_scale_13, mlp_weight_scale_2, moe_weight_13, moe_weight_2, moe_weight_scale_13, moe_weight_scale_2, mm_weight, moe_bias, mlp_block_size_w13, mlp_block_size_w2, moe_block_size_w13, moe_block_size_w2, world_size, rank, group_id, dev_info, ) # if out_single: # return out[0] return out # ! 
register fake function for cpp imple custom op @torch.library.register_fake("vacc::RotaryPosEmbedding_func") def _( q: torch.Tensor, k: torch.Tensor, cos: Optional[torch.Tensor] = None, sin: Optional[torch.Tensor] = None, offset: int = 0, mode: str = "neox", ): return [torch.empty_like(q), torch.empty_like(k)] @torch.library.register_fake("vacc::reshape_and_cache_attention") def _( src: torch.Tensor, cached: torch.Tensor, block_mapping: torch.Tensor, ): pass @torch.library.register_fake("vacc::scaled_dot_product_attention") def _( query, key, value, attn_mask, dropout_p, is_train, recompute, is_causal, flash_attention, sm_scale ): return [torch.empty(size=(query.size()[0], query.size()[1], key.size()[2]), device=query.device, dtype=query.dtype)] @torch.library.register_fake("vacc::rms_norm_func") def _( input: torch.Tensor, weight: torch.Tensor, eps: float, output: Optional[torch.Tensor] = None ) -> torch.Tensor: return torch.empty_like(input) @torch.library.register_fake("vacc::fused_residual_rmsnorm") def _( input: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, epsilon: float = 1e-6, output: Optional[torch.Tensor] = None, residual_out: Optional[torch.Tensor] = None, ): return [torch.empty_like(input), torch.empty_like(input)] @torch.library.register_fake("vacc::swiglu") def _( self: torch.Tensor ): shape = list(self.shape) shape[-1] = shape[-1] // 2 return torch.empty(size=shape, dtype=self.dtype, device=self.device) def fuse_mla_mlp_v2_allreduce_decode_layers( # input hidden_states: torch.Tensor, residual: Optional[torch.Tensor], # mla weight hidden_states_norm_weight: List[torch.Tensor], q_a_proj_weight: List[torch.Tensor], q_a_proj_weight_scale_inv: List[torch.Tensor], q_a_layernorm_weight: List[torch.Tensor], w_q: List[torch.Tensor], w_q_scale: List[torch.Tensor], w_uk: List[torch.Tensor], w_uk_scale: List[torch.Tensor], w_qr: List[torch.Tensor], w_qr_scale: List[torch.Tensor], kv_a_layernorm_weight: List[torch.Tensor], sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: List[torch.Tensor], block_tables: torch.Tensor, block_group_size: int, w_uv: List[torch.Tensor], w_uv_scale: List[torch.Tensor], o_proj_weight: List[torch.Tensor], o_proj_weight_scale_inv: List[torch.Tensor], # mla params seq_lens: Tuple[int] | List[int], sm_scale: float, head_num: int, flash_attention: bool, # mlp weight rms_weight: List[torch.Tensor], mlp_weight_13: List[torch.Tensor], mlp_weight_2: List[torch.Tensor], mlp_weight_scale_13: List[torch.Tensor], mlp_weight_scale_2: List[torch.Tensor], # mlp params mlp_block_size_w13: List[int] | Tuple[int], mlp_block_size_w2: List[int] | Tuple[int], # vccl info world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): # TODO: CHECK assert ( hidden_states.device.type == "vacc" ), "fuse_mla_mlp_v2_allreduce_decode_layers only support VACC" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] if residual is None: residual = torch.Tensor() out = _torch_vacc.fuse_mla_mlp_v2_allreduce_decode_layers( hidden_states, residual, hidden_states_norm_weight, q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_layernorm_weight, w_q, w_q_scale, w_uk, w_uk_scale, w_qr, w_qr_scale, kv_a_layernorm_weight, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, w_uv, w_uv_scale, o_proj_weight, o_proj_weight_scale_inv, seq_lens, sm_scale, head_num, flash_attention, rms_weight, mlp_weight_13, mlp_weight_2, mlp_weight_scale_13, mlp_weight_scale_2, 
mlp_block_size_w13, mlp_block_size_w2, world_size, rank, group_id, dev_info, ) return out def fuse_mla_mlp_v2_allreduce_decode_layers_v2( # input hidden_states: torch.Tensor, residual: Optional[torch.Tensor], # mla weight sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: List[torch.Tensor], block_tables: torch.Tensor, block_group_size: int, # mla params seq_lens: Tuple[int] | List[int], sm_scale: float, head_num: int, flash_attention: bool, # mlp weight # mlp params mlp_block_size_w13: List[int] | Tuple[int], mlp_block_size_w2: List[int] | Tuple[int], # vccl info world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): # TODO: CHECK assert ( hidden_states.device.type == "vacc" ), "fuse_mla_mlp_v2_allreduce_decode_layers_v2 only support VACC" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] if residual is None: residual = torch.Tensor() out = _torch_vacc.fuse_mla_mlp_v2_allreduce_decode_layers( hidden_states, residual, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, seq_lens, sm_scale, head_num, flash_attention, mlp_block_size_w13, mlp_block_size_w2, world_size, rank, group_id, dev_info, ) return out def fuse_mla_moe_v2_allreduce_decode_layers( # input hidden_states: torch.Tensor, residual: Optional[torch.Tensor], # mla weight hidden_states_norm_weight: List[torch.Tensor], q_a_proj_weight: List[torch.Tensor], q_a_proj_weight_scale_inv: List[torch.Tensor], q_a_layernorm_weight: List[torch.Tensor], w_q: List[torch.Tensor], w_q_scale: List[torch.Tensor], w_uk: List[torch.Tensor], w_uk_scale: List[torch.Tensor], w_qr: List[torch.Tensor], w_qr_scale: List[torch.Tensor], kv_a_layernorm_weight: List[torch.Tensor], sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: List[torch.Tensor], block_tables: torch.Tensor, block_group_size: int, w_uv: List[torch.Tensor], w_uv_scale: List[torch.Tensor], o_proj_weight: List[torch.Tensor], o_proj_weight_scale_inv: List[torch.Tensor], # mla params seq_lens: Tuple[int] | List[int], sm_scale: float, head_num: int, flash_attention: bool, # moe weight rms_weight: List[torch.Tensor], mlp_weight_13: List[torch.Tensor], mlp_weight_2: List[torch.Tensor], mlp_weight_scale_13: List[torch.Tensor], mlp_weight_scale_2: List[torch.Tensor], moe_weight_13: List[torch.Tensor], moe_weight_2: List[torch.Tensor], moe_weight_scale_13: List[torch.Tensor], moe_weight_scale_2: List[torch.Tensor], mm_weight: List[torch.Tensor], moe_bias: List[torch.Tensor], # moe params mlp_block_size_w13: Tuple[int] | List[int], mlp_block_size_w2: Tuple[int] | List[int], moe_block_size_w13: Tuple[int] | List[int], moe_block_size_w2: Tuple[int] | List[int], # vccl info world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): # TODO: CHECK assert ( hidden_states.device.type == "vacc" ), "fuse_mla_moe_v2_allreduce_decode_layers only support VACC" # assert red_op_type == 0, "all_reduce only support red_op_type=0" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fuse_mla_moe_v2_allreduce_decode_layers( hidden_states, residual, hidden_states_norm_weight, q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_layernorm_weight, w_q, w_q_scale, w_uk, w_uk_scale, w_qr, w_qr_scale, kv_a_layernorm_weight, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, w_uv, 
w_uv_scale, o_proj_weight, o_proj_weight_scale_inv, seq_lens, sm_scale, head_num, flash_attention, rms_weight, mlp_weight_13, mlp_weight_2, mlp_weight_scale_13, mlp_weight_scale_2, moe_weight_13, moe_weight_2, moe_weight_scale_13, moe_weight_scale_2, mm_weight, moe_bias, mlp_block_size_w13, mlp_block_size_w2, moe_block_size_w13, moe_block_size_w2, world_size, rank, group_id, dev_info, ) if out_single: return out[0] return out def fuse_mla_moe_v2_allreduce_decode_layers_v2( # input hidden_states: torch.Tensor, residual: Optional[torch.Tensor], # mla weight sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: List[torch.Tensor], block_tables: torch.Tensor, block_group_size: int, # mla params seq_lens: Tuple[int] | List[int], sm_scale: float, head_num: int, flash_attention: bool, # moe params mlp_block_size_w13: Tuple[int] | List[int], mlp_block_size_w2: Tuple[int] | List[int], moe_block_size_w13: Tuple[int] | List[int], moe_block_size_w2: Tuple[int] | List[int], # vccl info world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): # TODO: CHECK assert ( hidden_states.device.type == "vacc" ), "fuse_mla_moe_v2_allreduce_decode_layers_v2 only support VACC" # assert red_op_type == 0, "all_reduce only support red_op_type=0" if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] out_single = False if residual is None: out_single = True residual = torch.Tensor() # print("come to v2") out = _torch_vacc.fuse_mla_moe_v2_allreduce_decode_layers_v2( hidden_states, residual, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, seq_lens, sm_scale, head_num, flash_attention, mlp_block_size_w13, mlp_block_size_w2, moe_block_size_w13, moe_block_size_w2, world_size, rank, group_id, dev_info, ) if out_single: return out[0] return out def fuse_mlp_qwen_int4( hidden_states: torch.Tensor, weight13: torch.Tensor, weight2: torch.Tensor, scale13: torch.Tensor, scale2: torch.Tensor, zero13: torch.Tensor, zero2: torch.Tensor, block13: List[int] | Tuple[int], block2: List[int] | Tuple[int], engine_mode: int = 0, #0:auto, 1:dlc, 2:dsp output_opt: Optional[torch.Tensor] = None ): assert engine_mode in [0, 1, 2] return _torch_vacc.fuse_mlp_qwen_int4( hidden_states, weight13, weight2, scale13, scale2, zero13, zero2, block13, block2, engine_mode, output_opt ) def fuse_mlp_qwen_int4_reduce( hidden_states: torch.Tensor, weight13: torch.Tensor, weight2: torch.Tensor, scale13: torch.Tensor, scale2: torch.Tensor, zero13: torch.Tensor, zero2: torch.Tensor, block13: List[int] | Tuple[int], block2: List[int] | Tuple[int], world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], output_opt: Optional[torch.Tensor] = None ): if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.fuse_mlp_qwen_int4_reduce( hidden_states, weight13, weight2, scale13, scale2, zero13, zero2, block13, block2, world_size, rank, group_id, dev_info, output_opt ) def fuse_mlp_qwen_fp8( hidden_states: torch.Tensor, weight13: torch.Tensor, weight2: torch.Tensor, scale13: torch.Tensor, scale2: torch.Tensor, zero13: torch.Tensor, zero2: torch.Tensor, block13: List[int] | Tuple[int], block2: List[int] | Tuple[int], engine_mode: int = 0, #0:auto, 1:dlc, 2:dsp output_opt: Optional[torch.Tensor] = None ): assert engine_mode in [0, 1, 2] return _torch_vacc.fuse_mlp_qwen_int4( hidden_states, weight13, weight2, scale13, scale2, zero13, zero2, block13, block2, engine_mode, output_opt ) def 
fuse_mlp_qwen_fp8_reduce( hidden_states: torch.Tensor, weight13: torch.Tensor, weight2: torch.Tensor, scale13: torch.Tensor, scale2: torch.Tensor, zero13: torch.Tensor, zero2: torch.Tensor, block13: List[int] | Tuple[int], block2: List[int] | Tuple[int], world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], output_opt: Optional[torch.Tensor] = None ): if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.fuse_mlp_qwen_int4_reduce( hidden_states, weight13, weight2, scale13, scale2, zero13, zero2, block13, block2, world_size, rank, group_id, dev_info, output_opt ) def fuse_mlp_qwen_fp16_bf16( hidden_states: torch.Tensor, weight13: torch.Tensor, weight2: torch.Tensor, output_opt: Optional[torch.Tensor] = None ): return _torch_vacc.fuse_mlp_qwen_int4( hidden_states, weight13, weight2, torch.Tensor(), torch.Tensor(), torch.Tensor(), torch.Tensor(), (0,0), (0,0), 0, output_opt ) def fuse_mlp_qwen_fp16_bf16_reduce( hidden_states: torch.Tensor, weight13: torch.Tensor, weight2: torch.Tensor, world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], output_opt: Optional[torch.Tensor] = None ): if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.fuse_mlp_qwen_int4_reduce( hidden_states, weight13, weight2, torch.Tensor(), torch.Tensor(), torch.Tensor(), torch.Tensor(), (0,0), (0,0), world_size, rank, group_id, dev_info, output_opt ) def w4a8_block_int4_matmul( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, dequant_block: List[int] | Tuple[int], output_opt: Optional[torch.Tensor] = None, ): # TODO: CHECK out = _torch_vacc.w4a8_block_int4_matmul( input, weight, weight_scale, dequant_block, output_opt, ) return out def fuse_atten_qwen3( hidden_states: torch.Tensor, residual: torch.Tensor, hidden_states_norm_weight: torch.Tensor, qkv_proj_weight: torch.Tensor, qkv_proj_weight_scale: torch.Tensor, qkv_proj_bias: torch.Tensor, qkv_proj_qzeros: torch.Tensor, q_layernorm_weight: torch.Tensor, k_layernorm_weight: torch.Tensor, sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: torch.Tensor, block_tables: torch.Tensor, block_group_size: int, o_proj_weight: torch.Tensor, o_proj_weight_scale: torch.Tensor, o_proj_bias: torch.Tensor, o_proj_qzeros: torch.Tensor, seq_lens: Tuple[int] | List[int], sm_scale: float, num_attention_heads: int, num_key_value_heads: int, flash_attention: bool, is_decode: bool, reduce_result: bool, world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], output_opt: Optional[torch.Tensor] = None, res_opt: Optional[torch.Tensor] = None, ): if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] none2empty = lambda tensor: torch.Tensor() if tensor is None else tensor out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fuse_atten_qwen3( hidden_states, residual, hidden_states_norm_weight, qkv_proj_weight, qkv_proj_weight_scale, none2empty(qkv_proj_bias), none2empty(qkv_proj_qzeros), q_layernorm_weight, k_layernorm_weight, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, o_proj_weight, o_proj_weight_scale, none2empty(o_proj_bias), none2empty(o_proj_qzeros), seq_lens, sm_scale, num_attention_heads, num_key_value_heads, flash_attention, is_decode, reduce_result, world_size, rank, group_id, dev_info, output_opt, res_opt, ) if out_single: return out[0] return out def fuse_atten_vit( 
hidden_states: torch.Tensor, hidden_states_norm_weight: torch.Tensor, hidden_states_norm_bias: torch.Tensor, qkv_proj_weight: torch.Tensor, qkv_proj_bias: torch.Tensor, sin_cache: torch.Tensor, cos_cache: torch.Tensor, o_proj_weight: torch.Tensor, o_proj_bias: torch.Tensor, seq_lens: Tuple[int] | List[int], sm_scale: float, num_attention_heads: int, flash_attention: bool, reduce_result: bool, world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], output_opt: Optional[torch.Tensor] = None, ): if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] none2empty = lambda tensor: torch.Tensor() if tensor is None else tensor out = _torch_vacc.fuse_atten_vit( hidden_states, hidden_states_norm_weight, hidden_states_norm_bias, qkv_proj_weight, none2empty(qkv_proj_bias), sin_cache, cos_cache, o_proj_weight, none2empty(o_proj_bias), seq_lens, sm_scale, num_attention_heads, flash_attention, reduce_result, world_size, rank, group_id, dev_info, output_opt ) return out def mrope_get_sin_cos( cos_cache: torch.Tensor, sin_cache: torch.Tensor, positions: torch.Tensor, mrope_section: List[int] | Tuple[int], mrope_interleaved: bool ): return _torch_vacc.mrope_get_sin_cos( cos_cache, sin_cache, positions, mrope_section, mrope_interleaved ) # NOTE: m <= 8, dsp version # if m > 8, call w4a8_block_int4_matmul def w4a8_block_int4_linear( input: torch.Tensor, weight: torch.Tensor, input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, pack_factor: int = 8, output_opt: Optional[torch.Tensor] = None, ): # TODO: CHECK out = _torch_vacc.w4a8_block_int4_linear( input, weight, input_scale, weight_scale, bias, pack_factor, output_opt, ) return out def fuse_atten_qwen2( history_states: torch.Tensor, residual: torch.Tensor, hidden_states_norm_weight: torch.Tensor, qkv_proj_weight: torch.Tensor, qkv_proj_weight_scale_inv: torch.Tensor, qkv_proj_bias: torch.Tensor, qkv_proj_qzeros: torch.Tensor, sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: torch.Tensor, block_tables: torch.Tensor, block_group_size: int, o_proj_weight: torch.Tensor, o_proj_weight_scale_inv: torch.Tensor, o_proj_bias: torch.Tensor, o_proj_qzeros: torch.Tensor, seq_lens_num: Tuple[int] | List[int], sm_scale: float, num_attention_heads: int, num_key_value_heads: int, flash_attentiton: bool, is_decode: bool, reduce_result: bool, world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] out_single = False if residual is None: out_single = True residual = torch.Tensor() out = _torch_vacc.fuse_atten_qwen2( history_states, residual, hidden_states_norm_weight, qkv_proj_weight, qkv_proj_weight_scale_inv, qkv_proj_bias, qkv_proj_qzeros, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, o_proj_weight, o_proj_weight_scale_inv, o_proj_bias, o_proj_qzeros, seq_lens_num, sm_scale, num_attention_heads, num_key_value_heads, flash_attentiton, is_decode, reduce_result, world_size, rank, group_id, dev_info, ) if out_single: return out[0] return out def qwen3_fuse_attention_moe_decode( # attention hidden_states: torch.Tensor, residual: torch.Tensor, hidden_states_norm_weight: torch.Tensor, qkv_proj_weight: torch.Tensor, qkv_proj_weight_scale_inv: torch.Tensor, qkv_proj_bias: torch.Tensor, qkv_proj_qzeros: torch.Tensor, q_layernorm_weight: torch.Tensor, k_layernorm_weight: torch.Tensor, 
sin_cache: List[torch.Tensor], cos_cache: List[torch.Tensor], slot_mapping: torch.Tensor, kv_cache: torch.Tensor, block_tables: torch.Tensor, block_group_size: int, o_proj_weight: torch.Tensor, o_proj_weight_scale_inv: torch.Tensor, o_proj_bias: torch.Tensor, o_proj_qzeros: torch.Tensor, seq_lens_num: Tuple[int] | List[int], sm_scale: float, num_attention_heads: int, num_key_value_heads: int, flash_attentiton: bool, is_decode: bool, reduce_result: bool, # moe rms_weight: torch.Tensor, moe_weight_13: torch.Tensor, moe_weight_2: torch.Tensor, moe_weight_13_dequat: torch.Tensor, moe_weight_2_dequant: torch.Tensor, gate_weight: torch.Tensor, block_size_13: Tuple[int] | List[int], block_size_2: Tuple[int] | List[int], # dist world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], ): if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] none2empty = lambda tensor: torch.Tensor() if tensor is None else tensor return _torch_vacc.qwen3_fuse_attention_moe_decode( hidden_states, none2empty(residual), hidden_states_norm_weight, qkv_proj_weight, qkv_proj_weight_scale_inv, none2empty(qkv_proj_bias), none2empty(qkv_proj_qzeros), q_layernorm_weight, k_layernorm_weight, sin_cache, cos_cache, slot_mapping, kv_cache, block_tables, block_group_size, o_proj_weight, o_proj_weight_scale_inv, none2empty(o_proj_bias), none2empty(o_proj_qzeros), seq_lens_num, sm_scale, num_attention_heads, num_key_value_heads, flash_attentiton, is_decode, reduce_result, rms_weight, moe_weight_13, moe_weight_2, moe_weight_13_dequat, moe_weight_2_dequant, gate_weight, block_size_13, block_size_2, world_size, rank, group_id, dev_info, ) def fuse_mtp_stage0( inputs_embeds: torch.Tensor, previous_hidden_states: torch.Tensor, positions: torch.Tensor, enorm_wegiht: torch.Tensor, hnorm_wegint: torch.Tensor, epsilon: float, world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], output: Optional[torch.Tensor] = None, ): if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.fuse_mtp_stage0( inputs_embeds, previous_hidden_states, positions, enorm_wegiht, hnorm_wegint, epsilon, world_size, rank, group_id, dev_info, output, ) def fuse_mtp_allreduce( inputs_embeds: torch.Tensor, previous_hidden_states: torch.Tensor, positions: torch.Tensor, enorm_wegiht: torch.Tensor, hnorm_wegint: torch.Tensor, linear_weight: torch.Tensor, epsilon: float, world_size: int, rank: int, group_id: int, dev_info: List[int] | Tuple[int], output: Optional[torch.Tensor] = None, ): if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.fuse_mtp_allreduce( inputs_embeds, previous_hidden_states, positions, enorm_wegiht, hnorm_wegint, linear_weight, epsilon, world_size, rank, group_id, dev_info, output, ) def roll_out( self: torch.Tensor, shifts: List[int] | Tuple[int] | int, dims: List[int] | Tuple[int] | int = [], output: Optional[torch.Tensor] = None, ): if isinstance(dims, int): dims = [dims] if isinstance(shifts, int): shifts = [shifts] return _torch_vacc.roll_out( self, shifts, dims, output, ) def fused_experts_int4_prefill( hidden_states: torch.Tensor, w13_weight: torch.Tensor, w2_weight: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, w13_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a13_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, w13_block_shape: Optional[List[int]] = None, w2_block_shape: Optional[List[int]] = None, output_opt: 
Optional[torch.Tensor] = None, ) -> torch.Tensor: warning_message = "[fused_experts]:vacc only support fp8 weights now" assert a13_scale is None and a2_scale is None, warning_message assert ( w13_scale is not None and w2_scale is not None ), f"{warning_message}, but w13_weight_scale is {w13_scale}, w2_weight_scale is {w2_scale}" assert ( w13_block_shape is not None ), f"{warning_message}, but block_shape is {w13_block_shape}" # assert ( # not decode_with_batch or hidden_states.size(0) <= 4 # ), "[fused_experts]:vacc only support batch <= 4 when decode" # topk weights dtype should be same with hidden_states topk_weights = topk_weights.to(hidden_states.dtype) # vacc device use int32 for experts_id topk_ids = topk_ids.to(torch.int32) # assert ( # block_size0 == block_size1 # ), "quant block shape now support size0 == size1" return _torch_vacc.fused_experts_int4_prefill( hidden_states, w13_weight, w2_weight, topk_weights, topk_ids, w13_scale, w2_scale, a13_scale, a2_scale, w13_block_shape, w2_block_shape, output_opt, ) def fuse_bge_embedding_stage1( input_embeds: torch.Tensor, positions_ids: torch.Tensor, positions_embeddings_weight: torch.Tensor, token_type_ids: torch.Tensor, token_type_embeddings_weight: torch.Tensor, layernorm_weight: torch.Tensor, layernorm_bias: torch.Tensor, epsilon: float, output_opt: Optional[torch.Tensor] = None, ) -> torch.Tensor: return _torch_vacc.fuse_bge_embedding_stage1( input_embeds, positions_ids, positions_embeddings_weight, token_type_ids, token_type_embeddings_weight, layernorm_weight, layernorm_bias, epsilon, output_opt, ) def l2_norm( input: torch.Tensor, epsilon: float, output_opt: Optional[torch.Tensor] = None, ) -> torch.Tensor: return _torch_vacc.l2_norm( input, epsilon, output_opt, ) class BERT_ATTN_STAGE(Enum): # with reduce, for seqLen <= 2k FullStage = 0 # without reduce, call reduce outer AttnOutStage = 1 InterOutStage = 2 def fused_attn_bert_allreduce( hidden_states: torch.Tensor, qkv_weight: Optional[torch.Tensor] = None, qkv_bias: Optional[torch.Tensor] = None, self_weight: Optional[torch.Tensor] = None, self_bias: Optional[torch.Tensor] = None, self_norm_weight: Optional[torch.Tensor] = None, self_norm_bias: Optional[torch.Tensor] = None, intermediate_weight: Optional[torch.Tensor] = None, intermediate_bias: Optional[torch.Tensor] = None, output_weight: Optional[torch.Tensor] = None, output_bias: Optional[torch.Tensor] = None, output_norm_weight: Optional[torch.Tensor] = None, output_norm_bias: Optional[torch.Tensor] = None, dense_out: Optional[torch.Tensor] = None, seqs: List[int] | Tuple[int] = [], vnnlBertKind: BERT_ATTN_STAGE = BERT_ATTN_STAGE.FullStage, sm_scale: float = 1.0, num_q_heads: int = 1, num_kv_heads: int = 1, flash_attention: bool = False, reduce_result: bool = False, world_size: int = 1, rank: int = 0, group_id: int = 0, dev_info: List[int] | Tuple[int] = [], ): if 0 == len(dev_info): dev_info = [i | (i << 16) for i in range(world_size)] return _torch_vacc.fused_attn_bert_allreduce( hidden_states, qkv_weight, qkv_bias, self_weight, self_bias, self_norm_weight, self_norm_bias, intermediate_weight, intermediate_bias, output_weight, output_bias, output_norm_weight, output_norm_bias, dense_out, seqs, vnnlBertKind.value, sm_scale, num_q_heads, num_kv_heads, flash_attention, reduce_result, world_size, rank, group_id, dev_info, ) def fuse_mlp_vision( src: torch.Tensor, weights_13: torch.Tensor, weights_2: torch.Tensor, weights_13_bias: Optional[torch.Tensor] = None, weights_2_bias: Optional[torch.Tensor] = None, act_type: int = 0, 
output_opt: Optional[torch.Tensor] = None, ) -> torch.Tensor: if weights_13_bias is None: weights_13_bias = torch.Tensor() if weights_2_bias is None: weights_2_bias = torch.Tensor() return _torch_vacc.fuse_mlp_vision( src, weights_13, weights_2, weights_13_bias, weights_2_bias, act_type, output_opt, ) def patch_merger_vision( src: torch.Tensor, weights_13: torch.Tensor, weights_2: torch.Tensor, weights_13_bias: Optional[torch.Tensor] = None, weights_2_bias: Optional[torch.Tensor] = None, act_type: int = 0, output_opt: Optional[torch.Tensor] = None, ) -> torch.Tensor: if weights_13_bias is None: weights_13_bias = torch.Tensor() if weights_2_bias is None: weights_2_bias = torch.Tensor() return _torch_vacc.patch_merger_vision( src, weights_13, weights_2, weights_13_bias, weights_2_bias, act_type, output_opt, )
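

# A hedged usage sketch for the collective wrappers above. When dev_info is passed
# as an empty list, the helpers fall back to [i | (i << 16) for i in range(world_size)],
# i.e. each entry packs the device index into both the low and high 16 bits. The
# "vacc:{rank}" device string and the tensor shape are illustrative assumptions;
# this only runs on vacc hardware with the communication group already set up.
def _example_all_reduce(rank: int, world_size: int, group_id: int) -> torch.Tensor:
    x = torch.ones(1024, device=f"vacc:{rank}")
    # An empty dev_info list triggers the default packing shown above; red_op_type=0
    # is the only value the wrapper accepts.
    return all_reduce(x, rank, world_size, group_id, dev_info=[], red_op_type=0)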