import torch import unittest import torch_mlu_ops as ops from common_utils import * from typing import Union, List, Tuple from torch.nn.parameter import Parameter from torch.nn import functional as F from torch_mlu import mlu class TestFusedMOEOp(BtTestCase): def run_gen_case(self, dic): dump_data = dic.pop('dump_data') if dump_data: self.launch(*dic.values()) else: hidden_states = create_tensor_from_dic(dic['hidden_states']) gating_output = create_tensor_from_dic(dic['gating_output']) w1 = create_tensor_from_dic(dic['w1']) w2 = create_tensor_from_dic(dic['w2']) if dic['w1']['dtype'] is not torch.int8: w1 *= 0.1 w2 *= 0.1 bias1 = None if dic['bias1']['data'] is None else create_tensor_from_dic(dic['bias1']) bias2 = None if dic['bias2']['data'] is None else create_tensor_from_dic(dic['bias2']) residual = None if dic['residual']['data'] is None else create_tensor_from_dic(dic['residual']) input_smooth = None if dic['input_smooth']['data'] is None else create_tensor_from_dic(dic['input_smooth'], is_uniform=True, low=0.01, high=0.05) act_smooth = None if dic['act_smooth']['data'] is None else create_tensor_from_dic(dic['act_smooth'], is_uniform=True, low=0.01, high=0.05) w1_scale = None if dic['w1_scale']['data'] is None else create_tensor_from_dic(dic['w1_scale'], is_uniform=True, low=-0.05, high=0.05) w2_scale = None if dic['w2_scale']['data'] is None else create_tensor_from_dic(dic['w2_scale'], is_uniform=True, low=-0.05, high=0.05) topk = dic['topk']['data'] renormalize = dic['renormalize']['data'] gated = dic['gated']['data'] act_mode = dic['act_mode']['data'] start_expert_id = dic['start_expert_id']['data'] block_n = dic['block_n']['data'] cncl_comm = dic['cncl_comm']['data'] w1_quant_flag = dic['w1_quant_flag']['data'] w2_quant_flag = dic['w2_quant_flag']['data'] self.launch(hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode, start_expert_id, block_n, 0, w1_quant_flag, 
w2_quant_flag) def launch(self, *args): base = self.op_impl_base(*args) tmo_res = tmo.fused_moe(*args) self.assertTensorsEqual(tmo_res.cpu().float(), base.cpu().float(), 0.03, use_MSE=True) def op_impl_base(self, *args): hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth, \ act_smooth, w1_scale, w2_scale, topk, renormalize, \ gated, act_mode, start_expert_id, block_n, cncl_comm, w1_quant_flag, w2_quant_flag = args if w2.dim() == 4: w2 = w2.transpose(1, 2).reshape(-1, w2.size(1), w2.size(-1)) expert_num = gating_output.size(-1) expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1) def router(hidden_states, router_logit): router_logit = torch.softmax(router_logit.view(-1, router_logit.size(-1)), dim=1) topk_logit, expert_id = torch.topk(router_logit, k=topk, dim=1) if renormalize: topk_logit = topk_logit / topk_logit.sum(-1).unsqueeze(1) sorted_expert_id, indices = expert_id.int().flatten().sort() nk = indices.size(0) token_cout = torch.bincount(sorted_expert_id, minlength=expert_num).cpu() expand_idx = indices.int() // topk combine_idx = torch.zeros((nk,), dtype=torch.int, device="mlu") combine_idx.scatter_(0, indices, torch.arange(nk, dtype=torch.int, device="mlu")) input_expand = hidden_states[expand_idx] input_list = input_expand.split(tuple(token_cout)) input_scale_list = [] input_split_result = [] if input_smooth is not None: # do smooth quant on input idx = 0 for i in range(expert_num): if i >= start_expert_id and i < start_expert_id + expert_size: if (input_list[i].size(0) > 0): temp = input_list[i] * input_smooth[idx] input_split_result.append(temp) idx += 1 else: input_split_result.append(input_list[i]) quant_input, input_scale = QuantByRow(torch.cat(input_split_result, dim=0), 8) input_list = quant_input.split(token_cout.tolist()) input_scale_list = input_scale.split(token_cout.tolist()) return input_list, input_scale_list, topk_logit.flatten().view(nk , 1), combine_idx dtype = hidden_states.dtype if w1_quant_flag 
is None: inner_size = w1.size(1) // 2 if gated else w1.size(1) else: inner_size = w1_scale.size(2) // 2 if gated else w1_scale.size(2) hidden_size = w2.size(1) if w2_quant_flag is None else w2_scale.size(2) input = hidden_states.view(-1, hidden_states.size(-1)) gating_output = gating_output.view(-1, gating_output.size(-1)) input_list, input_scale_list, reduce_weight, combine_idx = router(input, gating_output) output_list = [] idx = 0 need_quant = len(input_scale_list) != 0 if need_quant and w1_scale.dim() == 3: w1_scale = w1_scale.transpose(0, 1).contiguous() w2_scale = w2_scale.transpose(0, 1).contiguous() if w1_quant_flag is not None: w1_quant_group = w1_scale.size(1) quant_wise = hidden_size // w1_quant_group w1_quant_flag = torch.tensor(w1_quant_flag).view(-1, w1_quant_group) w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated) w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0) if w2_quant_flag is not None: w2_quant_group = w2_scale.size(1) quant_wise = inner_size // w2_quant_group w2_quant_flag = torch.tensor(w2_quant_flag).view(-1, w2_quant_group) w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0) for i in range(expert_num): # start_expert_id : start_expert_id + expert_size if i >= start_expert_id and i < start_expert_id + expert_size: if (input_list[i].size(0) > 0): if need_quant: if w1_quant_flag is None: gemm1_out = smooth_quant_matmul(input_list[i], input_scale_list[i], w1[idx], w1_scale[idx], dtype) else: gemm1_out = smooth_quant_matmul_w4w8_mixed(input_list[i], input_scale_list[i], w1[w1_offset_cu[idx]:w1_offset_cu[idx+1]], w1_scale[idx], dtype, quant_flag = w1_quant_flag[idx]) else: gemm1_out = torch.matmul(input_list[i], w1[idx].permute(1, 0)) act_in = gemm1_out[:, :inner_size].float() gate = gemm1_out[:, inner_size:] act = act_mode_dict[act_mode] 
gemm1_out = act(act_in).to(dtype=dtype) if gated == False else \ act(act_in).to(dtype=dtype) * gate if need_quant: quant_gemm1_out, gemm1_out_scale = QuantByRow(gemm1_out * act_smooth[idx], 8) if w2_quant_flag is None: gemm2_out = smooth_quant_matmul(quant_gemm1_out, gemm1_out_scale, w2[idx], w2_scale[idx], dtype) else: gemm2_out = smooth_quant_matmul_w4w8_mixed(quant_gemm1_out, gemm1_out_scale, w2[w2_offset_cu[idx]:w2_offset_cu[idx+1]], w2_scale[idx], dtype, quant_flag = w2_quant_flag[idx]) else: gemm2_out = torch.matmul(gemm1_out, w2[idx].permute(1, 0)) output_list.append(gemm2_out) idx += 1 else: output_list.append(torch.zeros_like(input_list[i])) output = torch.cat(output_list, dim=0)[combine_idx].float() * reduce_weight output = output.reshape(-1, topk, hidden_size).sum(dim=1).to(dtype=dtype) if residual is not None: output = output + residual.view(input.shape) return output.view(hidden_states.shape) def test_fused_moe(self): print("test_fused_moe") batch, seq, hidden_size, inner_size = 3, 5, 512, 768 dtype_list = [torch.half] if mlu.is_bf16_supported(): dtype_list.append(torch.bfloat16) for dtype in dtype_list: expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', dtype hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type) weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type) start_expert_id_list = [0, 1, 3, 4, 5, 6] expert_size_list = [8, 4, 3, 2, 3, 2] for i in range(len(start_expert_id_list)): start_expert_id = 
start_expert_id_list[i] expert_size = expert_size_list[i] print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ start_expert_id: {start_expert_id}, expert_size: {expert_size}, topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) torch_out = self.op_impl_base(hidden_states, router_logit, weight1[start_expert_id:start_expert_id+expert_size], weight2[start_expert_id:start_expert_id+expert_size], bias1, bias2, residual, None, None, None, None, topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None) # (N, T, C) tmo_out_1 = ops.fused_moe(hidden_states, router_logit, weight1[start_expert_id:start_expert_id+expert_size], weight2[start_expert_id:start_expert_id+expert_size], bias1, bias2, residual, None, None, None, None, topk, renormalize, gated, act_mode, start_expert_id) # (N*T, C) tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size), router_logit.view(-1, expert_num), weight1[start_expert_id:start_expert_id+expert_size], weight2[start_expert_id:start_expert_id+expert_size], bias1, bias2, residual.view(-1, hidden_size) if residual is not None else None, None, None, None, None, topk, renormalize, gated, act_mode, start_expert_id).reshape(batch, seq, hidden_size) tmo_out_3 = fused_moe(hidden_states, router_logit, weight1[start_expert_id:start_expert_id+expert_size], weight2[start_expert_id:start_expert_id+expert_size], bias1, bias2, residual, None, None, None, None, topk, renormalize, gated, act_mode, start_expert_id) self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_3.cpu().float(), 0.006, use_MSE=True) def test_pertoken_quant_fused_moe_tp(self): print("test_pertoken_quant_fused_moe_tp") batch, seq, hidden_size, 
inner_size = 3, 5, 8192, 1024 expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16 if not mlu.is_bf16_supported(): data_type = torch.float16 print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) self.__run_sq_case(batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode) def test_moe_tp2_mixed_ep4_no_quant(self): print("test_moe_tp2_mixed_ep4_no_quant") tp_num = 2 ep_num = 4 expert_num = 32 expert_size = expert_num // ep_num batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024 topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16 if not mlu.is_bf16_supported(): data_type = torch.float16 hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) residual = None weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type) weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type) w1 = weight1.reshape(ep_num * expert_size, tp_num, (inner_size * (1 + gated)) // tp_num, hidden_size) w2 = weight2.reshape(ep_num * expert_size, hidden_size, tp_num, inner_size // tp_num) print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) tmo_out1 = torch.zeros_like(hidden_states) tmo_out2 = torch.zeros_like(hidden_states) torch_out = 
torch.zeros_like(hidden_states) for tp_idx in range(tp_num): w1_curr_tp = w1[:, tp_idx, ...] w2_curr_tp = w2[:, :, tp_idx, :] for ep_idx in range(ep_num): start_expert_id = ep_idx * expert_size w1_curr_tp_and_ep = w1_curr_tp.reshape((ep_num, expert_size)+w1_curr_tp.shape[1:])[ep_idx].contiguous() w2_curr_tp_and_ep = w2_curr_tp.reshape((ep_num, expert_size)+w2_curr_tp.shape[1:])[ep_idx].contiguous() tmo_out1 += ops.fused_moe(hidden_states, router_logit, w1_curr_tp_and_ep, w2_curr_tp_and_ep, bias1, bias2, residual, None, None, None, None, topk, renormalize, gated, act_mode, start_expert_id) tmo_out2 += fused_moe(hidden_states, router_logit, w1_curr_tp_and_ep, w2_curr_tp_and_ep, bias1, bias2, residual, None, None, None, None, topk, renormalize, gated, act_mode, start_expert_id) torch_out += self.op_impl_base(hidden_states, router_logit, w1_curr_tp_and_ep, w2_curr_tp_and_ep, bias1, bias2, residual, None, None, None, None, topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None) tmo_out2 = tmo_out2.reshape(batch, seq, hidden_size) self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) self.assertTensorsEqual(tmo_out1.cpu().float(), tmo_out2.cpu().float(), 0.006, use_MSE=True) def test_moe_tp2_mixed_ep4_quant(self): print("test_moe_tp2_mixed_ep4_quant") tp_num = 2 ep_num = 4 expert_num = 32 expert_size = expert_num // ep_num batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024 topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16 if not mlu.is_bf16_supported(): data_type = torch.float16 scale_s = 0.1 # avoid the occurrence of inf eps = 0.01 # Avoid the occurrence of nan hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) residual = None weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s bias1 = None # 
torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type) weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type) input_smooth = torch.randn(expert_num, hidden_size, device='mlu', dtype=torch.float32).abs() + eps act_smooth = torch.randn(expert_num, (1+gated)*inner_size, device='mlu', dtype=torch.float32).abs() + eps weight1_shape, weight2_shape = weight1.shape, weight2.shape weight1 = weight1 / input_smooth.unsqueeze(1) weight2 = weight2 / act_smooth.unsqueeze(1) quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1_shape[-1]), 8) quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2_shape[-1]), 8) quant_w1, quant_w2 = quant_w1.view(weight1.shape), quant_w2.view(weight2.shape) w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1) torch_out = torch.zeros_like(hidden_states) for _ in range(tp_num): w1_scale_tp = w1_scale.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num) act_smooth_tp = act_smooth.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num) w1 = quant_w1.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num, hidden_size) w2 = quant_w2.reshape(ep_num, expert_size, hidden_size, tp_num, inner_size*(1+gated)//tp_num) input_smooth = input_smooth.reshape(ep_num, expert_size, hidden_size) act_smooth = act_smooth.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num) w1_scale = w1_scale.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num) w2_scale = w2_scale.reshape(ep_num, expert_size, hidden_size) print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) tmo_out = torch.zeros_like(hidden_states) tmo_out1 = torch.zeros_like(hidden_states) torch_out = 
torch.zeros_like(hidden_states) for tp_idx in range(tp_num): w1_curr_tp = w1[:, :, tp_idx, ...] w2_curr_tp = w2[:, :, :, tp_idx, :] act_smooth_tp = act_smooth[:, :, tp_idx, :] w1_scale_tp = w1_scale[:, :, tp_idx, :] for ep_idx in range(ep_num): start_expert_id = ep_idx * expert_size w1_curr_tp_and_ep = w1_curr_tp[ep_idx].contiguous() w2_curr_tp_and_ep = w2_curr_tp[ep_idx].contiguous() input_smooth_curr_ep = input_smooth[ep_idx].contiguous() act_smooth_curr_ep = act_smooth_tp[ep_idx].contiguous() w1_scale_curr_ep = w1_scale_tp[ep_idx].contiguous() w2_scale_curr_ep = w2_scale[ep_idx].contiguous() tmo_out += ops.fused_moe(hidden_states, # [batch, seq, hidden_size] router_logit, # [batch, seq, expert_num] w1_curr_tp_and_ep, # [expert_size, inner_size*(1+gated)//tp_num, hidden_size] w2_curr_tp_and_ep, # [expert_size, hidden_size, inner_size*(1+gated)//tp_num] bias1, bias2, residual, input_smooth_curr_ep, # [expert_size, hidden_size] act_smooth_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num] w1_scale_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num] w2_scale_curr_ep, # [expert_size, hidden_size] topk, renormalize, gated, act_mode, start_expert_id) tmo_out1 += fused_moe(hidden_states, router_logit, w1_curr_tp_and_ep, w2_curr_tp_and_ep, bias1, bias2, residual, input_smooth_curr_ep, act_smooth_curr_ep, w1_scale_curr_ep, w2_scale_curr_ep, topk, renormalize, gated, act_mode, start_expert_id) torch_out += self.op_impl_base(hidden_states, router_logit, w1_curr_tp_and_ep, w2_curr_tp_and_ep, bias1, bias2, residual, input_smooth_curr_ep, act_smooth_curr_ep, w1_scale_curr_ep, w2_scale_curr_ep, topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None) self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True) self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True) def test_smq_fused_moe_random_tp(self): print("test_smq_fused_moe_random_tp") import random random.seed(0) act_mode = 'gelu' 
case_list = set() for i in range(10): batch = random.randint(1, 10) seq = random.randint(1, 10) hidden_size = random.randrange(512, 2048, 2) inner_size = random.randrange(512, 2048, 2) expert_num = random.randint(1, 40) topk = random.randint(1,expert_num) gated = random.choice([True, False]) renormalize = random.choice([True, False]) data_type = random.choice([torch.bfloat16, torch.float16]) if not mlu.is_bf16_supported(): data_type = torch.float16 case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode) if case in case_list: continue case_list.add(case) print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) self.__run_sq_case(*case) def __run_sq_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode='gelu', start_expert_id=0, expert_size=-1): if expert_size == -1: expert_size = expert_num scale_s = 0.01 # avoid the occurrence of inf eps = 0.1 # Avoid the occurrence of nan hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps bias1, bias2 = None, None weight1_shape, weight2_shape = weight1.shape, weight2.shape weight1 = weight1 / input_smooth.unsqueeze(1) weight2 = weight2 / 
act_smooth.unsqueeze(1) quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8) quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8) quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape) w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1) torch_out = self.op_impl_base(hidden_states, router_logit, quant_w1[start_expert_id:start_expert_id+expert_size], quant_w2[start_expert_id:start_expert_id+expert_size], bias1, bias2, residual, input_smooth[start_expert_id:start_expert_id+expert_size], act_smooth[start_expert_id:start_expert_id+expert_size], w1_scale[start_expert_id:start_expert_id+expert_size], w2_scale[start_expert_id:start_expert_id+expert_size], topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None) # (N, T, C) tmo_out_1 = ops.fused_moe(hidden_states, router_logit, quant_w1[start_expert_id:start_expert_id+expert_size], quant_w2[start_expert_id:start_expert_id+expert_size], bias1, bias2, residual, input_smooth[start_expert_id:start_expert_id+expert_size], act_smooth[start_expert_id:start_expert_id+expert_size], w1_scale[start_expert_id:start_expert_id+expert_size], w2_scale[start_expert_id:start_expert_id+expert_size], topk, renormalize, gated, act_mode, start_expert_id) # # (N*T, C) tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size), router_logit.view(-1, expert_num), quant_w1[start_expert_id:start_expert_id+expert_size], quant_w2[start_expert_id:start_expert_id+expert_size], bias1, bias2, residual.view(-1, hidden_size) if residual is not None else None, input_smooth[start_expert_id:start_expert_id+expert_size], act_smooth[start_expert_id:start_expert_id+expert_size], w1_scale[start_expert_id:start_expert_id+expert_size], w2_scale[start_expert_id:start_expert_id+expert_size], topk, renormalize, gated, act_mode, start_expert_id).view(batch, seq, hidden_size) tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size), router_logit.view(-1, 
expert_num), quant_w1[start_expert_id:start_expert_id+expert_size], quant_w2[start_expert_id:start_expert_id+expert_size], bias1, bias2, residual.view(-1, hidden_size) if residual is not None else None, input_smooth[start_expert_id:start_expert_id+expert_size], act_smooth[start_expert_id:start_expert_id+expert_size], w1_scale[start_expert_id:start_expert_id+expert_size], w2_scale[start_expert_id:start_expert_id+expert_size], topk, renormalize, gated, act_mode, start_expert_id).view(batch, seq, hidden_size) self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True) self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True) self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True) def test_fused_moe_with_4D_w2(self): print("test_fused_moe_with_4D_w2") batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024 dtype_list = [torch.half] if mlu.is_bf16_supported(): dtype_list.append(torch.bfloat16) for dtype in dtype_list: expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, True, True, 'silu', dtype hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type) router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32) residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type) bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type) weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type) bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type) print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: 
{data_type}, act_mode: {act_mode} testing...", flush=True) torch_out = self.op_impl_base(hidden_states, router_logit, weight1, weight2, bias1, bias2, residual, None, None, None, None, topk, renormalize, gated, act_mode, 0, 0, 0, None, None) # (N, T, C) tmo_out_1 = ops.fused_moe(hidden_states, router_logit, weight1, weight2, bias1, bias2, residual, None, None, None, None, topk, renormalize, gated, act_mode, 0) # (N*T, C) tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size), router_logit.view(-1, expert_num), weight1, weight2.view(8, 4, hidden_size, inner_size).transpose(1,2).contiguous(), bias1, bias2, residual.view(-1, hidden_size) if residual is not None else None, None, None, None, None, topk, renormalize, gated, act_mode, 0).view(batch, seq, hidden_size) self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) def test_moe_tp2_mixed_ep4_with_4D_w2(self): print("test_moe_tp2_mixed_ep4_with_4D_w2") tp_num, ep_num = 2, 4 batch, seq, hidden_size, inner_size = 3, 5, 8192, 2048 expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16 assert inner_size % tp_num == 0 assert expert_num % ep_num == 0 expert_size = expert_num // ep_num assert 4096 % inner_size == 0 block_e = 4096 // inner_size assert expert_size % block_e == 0 if not mlu.is_bf16_supported(): data_type = torch.float16 hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type) router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32) weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type) weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type) residual = None bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", 
dtype=data_type) bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type) w1 = weight1.reshape(expert_num, tp_num, -1, hidden_size) w2 = weight2.reshape(expert_num, hidden_size, tp_num, -1) print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) tmo_out = torch.zeros_like(hidden_states) torch_out = torch.zeros_like(hidden_states) for tp_idx in range(tp_num): new_inner_size = w2.shape[-1] w1_curr_tp = w1[:, tp_idx, ...] w2_curr_tp = w2[:, :, tp_idx, :] for ep_idx in range(ep_num): start_expert_id = ep_idx * expert_size w1_curr_tp_and_ep = w1_curr_tp.reshape((ep_num, expert_size)+w1_curr_tp.shape[1:])[ep_idx].contiguous() w2_curr_tp_and_ep = w2_curr_tp.reshape((ep_num, expert_size)+w2_curr_tp.shape[1:])[ep_idx].contiguous() torch_out += self.op_impl_base(hidden_states, router_logit, w1_curr_tp_and_ep, w2_curr_tp_and_ep.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(), bias1, bias2, residual, None, None, None, None, topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None) tmo_out += ops.fused_moe(hidden_states, router_logit, w1_curr_tp_and_ep, w2_curr_tp_and_ep.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(), bias1, bias2, residual, None, None, None, None, topk, renormalize, gated, act_mode, start_expert_id) self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) def test_sq_fused_moe_with_4D_w2(self): print("test_sq_fused_moe_with_4D_w2") batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024 expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16 if not mlu.is_bf16_supported(): data_type = torch.float16 assert 4096 % inner_size == 0 block_e = 4096 // inner_size assert expert_num % block_e == 0 scale_s = 
0.01 # avoid the occurrence of inf eps = 0.1 # Avoid the occurrence of nan hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type) router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32) weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type) weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type) residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) bias1, bias2 = None, None input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps weight1_shape, weight2_shape = weight1.shape, weight2.shape weight1 = weight1 / input_smooth.unsqueeze(1) weight2 = weight2 / act_smooth.unsqueeze(1) quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8) quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8) quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape) w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1) torch_out = self.op_impl_base(hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual, input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0, 0, 0, None, None) # (N, T, C) tmo_out_1 = ops.fused_moe(hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual, input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0) # # (N*T, C) tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size), router_logit.view(-1, expert_num), quant_w1, quant_w2.view(-1,block_e,hidden_size, inner_size).transpose(1,2).contiguous(), bias1, bias2, residual.view(-1, hidden_size) if residual is not None else None, input_smooth, act_smooth, 
w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0).view(batch, seq, hidden_size) self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True) def test_sq_fused_moe_random_tp_quant_grouped(self): print("test_sq_fused_moe_random_tp_quant_grouped") import random random.seed(0) act_mode = 'gelu' case_list = set() while (len(case_list) < 100): batch = random.randint(1, 10) seq = random.randint(1, 10) hidden_size = random.randrange(512, 2048, 128) inner_size = random.randrange(512, 2048, 512) expert_num = random.randint(1, 40) topk = random.randint(1, expert_num) gated = random.choice([True, False]) renormalize = random.choice([True, False]) quant_bit = random.choice([4, 8]) data_type = random.choice([torch.bfloat16, torch.float16]) if not mlu.is_bf16_supported(): data_type = torch.float16 case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode) if case in case_list: continue case_list.add(case) print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True) self.__run_quant_grouped_case(*case) def __run_quant_grouped_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode='gelu', start_expert_id=0, expert_size=-1, quant_group_size=128): def get_quant_group(n, quant_group_size): quant_group = n // quant_group_size return quant_group if quant_group >= 1 else 1 if expert_size == -1: expert_size = expert_num scale_s = 0.01 # avoid the occurrence of inf eps = 0.1 # Avoid the occurrence of nan hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s residual = torch.randn(batch, seq, 
# NOTE(review): this region of the file is whitespace-mangled (methods collapsed
# onto single physical lines); below is the same token stream re-indented
# conventionally, with documentation added.  The chunk opens mid-statement: the
# start of this torch.randn(...) call and the enclosing method's `def` line lie
# earlier in the file and are not visible from here.
hidden_size, device="mlu", dtype=data_type)
        # Random expert weights scaled by `scale_s` (defined earlier in the
        # enclosing method — presumably to keep magnitudes quantization
        # friendly; TODO confirm against the method header above).
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        # Per-expert smoothing vectors; abs() + eps keeps them strictly positive.
        input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps
        act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps
        bias1, bias2 = None, None
        weight1_shape, weight2_shape = weight1.shape, weight2.shape
        # Fold the smoothing factors into the weights (smooth-quant): the op
        # scales activations by the smooth vectors, so dividing the weights
        # here keeps the overall product unchanged.
        weight1 = weight1 / input_smooth.unsqueeze(1)
        weight2 = weight2 / act_smooth.unsqueeze(1)
        w1_quant_group = get_quant_group(hidden_size, quant_group_size)
        w2_quant_group = get_quant_group(inner_size, quant_group_size)
        # Grouped row-wise quantization of the flattened weight matrices.
        quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), quant_bit, w1_quant_group)
        quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), quant_bit, w2_quant_group)
        quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
        if quant_bit == 4:
            # int4 path: pack two int8 values per byte.
            quant_w1, quant_w2 = PairlyPackInt8(quant_w1), PairlyPackInt8(quant_w2)
        # Reorder scales into (quant_group, expert, row) layout.
        w1_scale = w1_scale.view(expert_num, -1, w1_quant_group).permute(2, 0, 1).contiguous()
        w2_scale = w2_scale.view(expert_num, -1, w2_quant_group).permute(2, 0, 1).contiguous()

        # split scale and transpose
        def extract_scale(scale, start_expert, expert_size):
            # Slice this rank's expert range out of the full scale tensor.
            return scale[:, start_expert:start_expert+expert_size, :].contiguous()

        # Reference result from the pure-torch implementation.
        torch_out = self.op_impl_base(hidden_states, router_logit,
                                      quant_w1[start_expert_id:start_expert_id+expert_size],
                                      quant_w2[start_expert_id:start_expert_id+expert_size],
                                      bias1, bias2, residual,
                                      input_smooth[start_expert_id:start_expert_id+expert_size],
                                      act_smooth[start_expert_id:start_expert_id+expert_size],
                                      extract_scale(w1_scale, start_expert_id, expert_size),
                                      extract_scale(w2_scale, start_expert_id, expert_size),
                                      topk, renormalize, gated, act_mode, start_expert_id,
                                      0, 0, None, None)  # (N, T, C)
        # Same configuration through the fused kernel, 3-D (batch, seq, hidden) input...
        tmo_out_1 = ops.fused_moe(hidden_states, router_logit,
                                  quant_w1[start_expert_id:start_expert_id+expert_size],
                                  quant_w2[start_expert_id:start_expert_id+expert_size],
                                  bias1, bias2, residual,
                                  input_smooth[start_expert_id:start_expert_id+expert_size],
                                  act_smooth[start_expert_id:start_expert_id+expert_size],
                                  extract_scale(w1_scale, start_expert_id, expert_size),
                                  extract_scale(w2_scale, start_expert_id, expert_size),
                                  topk, renormalize, gated, act_mode, start_expert_id)
        # ...and the flattened 2-D (tokens, hidden) variant.
        tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size),
                              router_logit.view(-1, expert_num),
                              quant_w1[start_expert_id:start_expert_id+expert_size],
                              quant_w2[start_expert_id:start_expert_id+expert_size],
                              bias1, bias2,
                              residual.view(-1, hidden_size) if residual is not None else None,
                              input_smooth[start_expert_id:start_expert_id+expert_size],
                              act_smooth[start_expert_id:start_expert_id+expert_size],
                              extract_scale(w1_scale, start_expert_id, expert_size),
                              extract_scale(w2_scale, start_expert_id, expert_size),
                              topk, renormalize, gated, act_mode, start_expert_id).view(batch, seq, hidden_size)
        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
        self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)

    def __run_w4w8_mixed_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode='gelu', start_expert_id=0, expert_size=-1):
        """Run one fused-MoE case with mixed per-group w4/w8 weight quantization.

        Each channel-group of size ``quant_wise`` carries a flag of 4 or 8
        (the bit width of that group); the flat int8 weight buffers are sized
        accordingly, and the reference implementation is compared against both
        fused entry points. ``expert_size == -1`` means "all experts".
        """
        if expert_size == -1:
            expert_size = expert_num
        w1_quant_group = hidden_size // quant_wise
        w2_quant_group = inner_size // quant_wise
        # Per (expert, group) bit-width flags: randint(1, 3) -> {1, 2}, times 4
        # gives 4-bit or 8-bit groups.
        w1_quant_flag = torch.randint(1, 3, (expert_num, w1_quant_group), dtype=torch.int32) * 4
        w2_quant_flag = torch.randint(1, 3, (expert_num, w2_quant_group), dtype=torch.int32) * 4
        # Total byte counts: a group contributes flag/4 * quant_wise/2 bytes
        # per output row (flag 4 -> quant_wise/2 packed bytes, flag 8 -> quant_wise bytes).
        w1_count = (w1_quant_flag.sum().item() // 4) * (quant_wise // 2) * inner_size*(1+gated)
        w2_count = (w2_quant_flag.sum().item() // 4) * (quant_wise // 2) * hidden_size
        w1 = torch.randint(-128, 127, (w1_count,), device="mlu", dtype=torch.int32).to(torch.int8)
        w2 = torch.randint(-128, 127, (w2_count,), device="mlu", dtype=torch.int32).to(torch.int8)
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        input_smooth = torch.empty(expert_num, hidden_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
        act_smooth = torch.empty(expert_num, inner_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
        bias1, bias2 = None, None
        w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
        w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
        # Cumulative per-expert byte offsets into the flat weight buffers,
        # front-padded with 0 so expert i spans [cu[i], cu[i+1]).
        w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated)
        w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0)
        w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size
        w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0)

        # split scale and transpose
        def extract_scale(scale, start_expert, expert_size):
            # Slice this rank's expert range out of the full scale tensor.
            return scale[:, start_expert:start_expert+expert_size, :].contiguous()

        params = [hidden_states, router_logit,
                  w1[w1_offset_cu[start_expert_id]:w1_offset_cu[start_expert_id+expert_size]],
                  w2[w2_offset_cu[start_expert_id]:w2_offset_cu[start_expert_id+expert_size]],
                  bias1, bias2, residual,
                  input_smooth[start_expert_id:start_expert_id+expert_size],
                  act_smooth[start_expert_id:start_expert_id+expert_size],
                  extract_scale(w1_scale, start_expert_id, expert_size),
                  extract_scale(w2_scale, start_expert_id, expert_size),
                  topk, renormalize, gated, act_mode, start_expert_id, 0, 0,
                  w1_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist(),
                  w2_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist()]
        torch_out = self.op_impl_base(*params) # (N, T, C)
        tmo_out_1 = ops.fused_moe(*params)
        tmo_out_2 = fused_moe(*params)
        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True)
        self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True)

    def test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed(self):
        """Sweep 100 unique random configurations through the mixed w4/w8 case."""
        print("test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed")
        import random
        random.seed(0)  # deterministic sweep
        act_mode = 'gelu'
        case_list = set()
        while (len(case_list) < 100):
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(1024, 3072, 512)
            inner_size = random.randrange(1024, 3072, 512)
            expert_num = random.randint(1, 40)
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            quant_wise = random.choice([128, 256, 512])
            data_type = random.choice([torch.bfloat16, torch.float16])
            if not mlu.is_bf16_supported():
                # Fall back to fp16 on devices without bf16 support.
                data_type = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode)
            if case in case_list:
                continue  # skip duplicate configurations
            case_list.add(case)
            print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_wise: {quant_wise}, act_mode: {act_mode} testing...", flush=True)
            self.__run_w4w8_mixed_case(*case, 0, -1)

    def test_sq_fused_moe_single_quant_group(self):
        """Run one fixed int4 configuration through the grouped-quant case."""
        print("test_sq_fused_moe_single_quant_group")
        import random
        random.seed(0)
        act_mode = 'gelu'
        batch = 9
        seq = 10
        hidden_size = 1664
        inner_size = 512
        expert_num = 15
        topk = 2
        gated = False
        renormalize = False
        quant_bit = 4
        data_type = torch.float16
        case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True)
        self.__run_quant_grouped_case(*case)

    def test_inductor(self):
        """opcheck the registered torch op, unquantized then smooth-quantized int8."""
        batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
        expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', torch.float16
        start_expert_id, expert_size = 0, 8
        hidden_states = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type)
        router_logit = torch.randn(batch * seq, expert_num, device="mlu", dtype=torch.float32)
        residual = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type)
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
        # Unquantized path: no smooth vectors, no scales.
        args = (hidden_states, router_logit,
                weight1[start_expert_id:start_expert_id+expert_size],
                weight2[start_expert_id:start_expert_id+expert_size],
                None, None, residual, None, None, None, None, None, None,
                topk, renormalize, gated, act_mode, 0, 0, 0)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, args)
        eps = 1e-5
        input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
        act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
        weight1_shape, weight2_shape = weight1.shape, weight2.shape
        # Fold smoothing into the weights, then int8 row-quantize (see the
        # grouped-quant case above for the same pattern).
        weight1 = weight1 / input_smooth.unsqueeze(1)
        weight2 = weight2 / act_smooth.unsqueeze(1)
        quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
        quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
        quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
        w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
        # Smooth-quant int8 path.
        args = (hidden_states, router_logit, quant_w1, quant_w2, None, None,
                residual, input_smooth, act_smooth, w1_scale, w2_scale, None, None,
                topk, renormalize, gated, act_mode, 0, 0, 0)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, args)


if __name__ == '__main__':
    exit(run_unittest(TestFusedMOEOp))