Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe.py
2026-02-04 17:39:32 +08:00

947 lines
58 KiB
Python
Executable File

import torch
import unittest
import torch_mlu_ops as ops
from common_utils import *
from typing import Union, List, Tuple
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torch_mlu import mlu
class TestFusedMOEOp(BtTestCase):
def run_gen_case(self, dic):
    """Run one generated fused-MoE case described by a parameter dict.

    ``dic`` maps argument names to descriptor dicts consumed by
    ``create_tensor_from_dic`` (tensor args) or plain ``{'data': value}``
    entries (scalars/flags).  The popped ``dump_data`` flag selects between
    forwarding raw descriptor values and materializing tensors here first.
    """
    dump_data = dic.pop('dump_data')
    if dump_data:
        # Descriptors already carry concrete data: forward in dict order,
        # which must match the launch()/op_impl_base() positional order.
        self.launch(*dic.values())
    else:
        hidden_states = create_tensor_from_dic(dic['hidden_states'])
        gating_output = create_tensor_from_dic(dic['gating_output'])
        w1 = create_tensor_from_dic(dic['w1'])
        w2 = create_tensor_from_dic(dic['w2'])
        if dic['w1']['dtype'] is not torch.int8:
            # Shrink unquantized weights to keep activations numerically tame.
            w1 *= 0.1
            w2 *= 0.1
        # Optional tensors: descriptor 'data' of None means "not supplied".
        bias1 = None if dic['bias1']['data'] is None else create_tensor_from_dic(dic['bias1'])
        bias2 = None if dic['bias2']['data'] is None else create_tensor_from_dic(dic['bias2'])
        residual = None if dic['residual']['data'] is None else create_tensor_from_dic(dic['residual'])
        # Smooth factors drawn from a small positive-ish uniform range.
        input_smooth = None if dic['input_smooth']['data'] is None else create_tensor_from_dic(dic['input_smooth'], is_uniform=True, low=0.01, high=0.05)
        act_smooth = None if dic['act_smooth']['data'] is None else create_tensor_from_dic(dic['act_smooth'], is_uniform=True, low=0.01, high=0.05)
        w1_scale = None if dic['w1_scale']['data'] is None else create_tensor_from_dic(dic['w1_scale'], is_uniform=True, low=-0.05, high=0.05)
        w2_scale = None if dic['w2_scale']['data'] is None else create_tensor_from_dic(dic['w2_scale'], is_uniform=True, low=-0.05, high=0.05)
        topk = dic['topk']['data']
        renormalize = dic['renormalize']['data']
        gated = dic['gated']['data']
        act_mode = dic['act_mode']['data']
        start_expert_id = dic['start_expert_id']['data']
        block_n = dic['block_n']['data']
        cncl_comm = dic['cncl_comm']['data']
        w1_quant_flag = dic['w1_quant_flag']['data']
        w2_quant_flag = dic['w2_quant_flag']['data']
        # NOTE(review): the extracted `cncl_comm` is NOT forwarded — a
        # literal 0 is passed in its slot below.  Confirm this is intentional
        # (e.g. comm handles unsupported in the single-process test path).
        self.launch(hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth,
                    act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode,
                    start_expert_id, block_n, 0, w1_quant_flag, w2_quant_flag)
def launch(self, *args):
    """Run the fused_moe kernel and compare it against the reference.

    Computes the pure-PyTorch baseline via ``op_impl_base``, runs the fused
    kernel with the identical 20 positional arguments, and asserts the two
    results match within an MSE tolerance of 0.03.
    """
    base = self.op_impl_base(*args)
    # Consistency fix: the rest of this file calls the kernel through the
    # `ops` alias (`import torch_mlu_ops as ops`); the original used an
    # undefined `tmo` alias here (presumably re-exported by common_utils).
    tmo_res = ops.fused_moe(*args)
    self.assertTensorsEqual(tmo_res.cpu().float(), base.cpu().float(), 0.03, use_MSE=True)
def op_impl_base(self, *args):
    """Pure-PyTorch reference implementation of fused MoE.

    Takes exactly 20 positional args mirroring the fused kernel plus the
    trailing ``block_n`` / ``cncl_comm`` / quant-flag extras.  Covers three
    weight modes:
      * plain fp16/bf16 experts (no smooth tensors, ``w1_quant_flag`` None),
      * smooth-quant int8 experts (``input_smooth``/``act_smooth``/scales set),
      * mixed w4/w8 packed experts (``w*_quant_flag`` lists set).
    Only experts in ``[start_expert_id, start_expert_id + expert_size)`` are
    computed; tokens routed to other experts contribute zeros, matching
    expert-parallel partial results that are summed across ranks by callers.
    Returns a tensor shaped like ``hidden_states``.
    """
    hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth, \
        act_smooth, w1_scale, w2_scale, topk, renormalize, \
        gated, act_mode, start_expert_id, block_n, cncl_comm, w1_quant_flag, w2_quant_flag = args
    if w2.dim() == 4:
        # Flatten a blocked (E/block_e, hidden, block_e, inner) layout back to
        # (E, hidden, inner).  The w2.size(...) calls on the RHS still see the
        # original 4-D tensor since the whole expression evaluates first.
        w2 = w2.transpose(1, 2).reshape(-1, w2.size(1), w2.size(-1))
    expert_num = gating_output.size(-1)
    # Experts owned locally.  With packed w4/w8 weights the expert dim of w1
    # is collapsed into rows, so read the count from the scale tensor, which
    # at this point is still laid out (quant_group, expert, n).
    expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1)
    def router(hidden_states, router_logit):
        # Softmax over experts, pick top-k per token, optionally renormalize
        # the selected probabilities so they sum to one.
        router_logit = torch.softmax(router_logit.view(-1, router_logit.size(-1)), dim=1)
        topk_logit, expert_id = torch.topk(router_logit, k=topk, dim=1)
        if renormalize:
            topk_logit = topk_logit / topk_logit.sum(-1).unsqueeze(1)
        # Sort the n*k (token, expert) assignments by expert id so each
        # expert's tokens are contiguous.
        sorted_expert_id, indices = expert_id.int().flatten().sort()
        nk = indices.size(0)
        # Tokens per expert (on CPU, used as split sizes below).
        token_cout = torch.bincount(sorted_expert_id, minlength=expert_num).cpu()
        # expand_idx: sorted slot -> source token row.
        expand_idx = indices.int() // topk
        # combine_idx: inverse permutation, used to gather results back into
        # original (token, k) order after the expert GEMMs.
        combine_idx = torch.zeros((nk,), dtype=torch.int, device="mlu")
        combine_idx.scatter_(0, indices, torch.arange(nk, dtype=torch.int, device="mlu"))
        input_expand = hidden_states[expand_idx]
        input_list = input_expand.split(tuple(token_cout))
        input_scale_list = []
        input_split_result = []
        if input_smooth is not None:
            # do smooth quant on input
            # idx walks input_smooth rows — one per *local* expert, advanced
            # even when that expert received no tokens.
            idx = 0
            for i in range(expert_num):
                if i >= start_expert_id and i < start_expert_id + expert_size:
                    if (input_list[i].size(0) > 0):
                        temp = input_list[i] * input_smooth[idx]
                        input_split_result.append(temp)
                    idx += 1
                else:
                    # Non-local experts: pass through un-smoothed.
                    input_split_result.append(input_list[i])
            # Per-row (per-token) int8 quantization over the concatenation.
            quant_input, input_scale = QuantByRow(torch.cat(input_split_result, dim=0), 8)
            input_list = quant_input.split(token_cout.tolist())
            input_scale_list = input_scale.split(token_cout.tolist())
        return input_list, input_scale_list, topk_logit.flatten().view(nk , 1), combine_idx
    dtype = hidden_states.dtype
    if w1_quant_flag is None:
        # Gated MLPs stack [act | gate] along dim 1 of w1.
        inner_size = w1.size(1) // 2 if gated else w1.size(1)
    else:
        inner_size = w1_scale.size(2) // 2 if gated else w1_scale.size(2)
    hidden_size = w2.size(1) if w2_quant_flag is None else w2_scale.size(2)
    input = hidden_states.view(-1, hidden_states.size(-1))
    gating_output = gating_output.view(-1, gating_output.size(-1))
    input_list, input_scale_list, reduce_weight, combine_idx = router(input, gating_output)
    output_list = []
    idx = 0
    need_quant = len(input_scale_list) != 0
    if need_quant and w1_scale.dim() == 3:
        # Reorder grouped scales (quant_group, expert, n) -> (expert, group, n)
        # so w*_scale[idx] selects one expert's scales below.
        w1_scale = w1_scale.transpose(0, 1).contiguous()
        w2_scale = w2_scale.transpose(0, 1).contiguous()
    if w1_quant_flag is not None:
        # Mixed w4/w8: quant_flag marks which quant groups are 4-bit.
        # Build cumulative row offsets locating each expert's packed rows
        # inside the flattened w1.
        # NOTE(review): the //4 * (quant_wise // 2) arithmetic assumes flags
        # come 4-per-pack with two nibbles per byte — confirm against
        # PairlyPackInt8 / the kernel's packing convention.
        w1_quant_group = w1_scale.size(1)
        quant_wise = hidden_size // w1_quant_group
        w1_quant_flag = torch.tensor(w1_quant_flag).view(-1, w1_quant_group)
        w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated)
        w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0)
    if w2_quant_flag is not None:
        w2_quant_group = w2_scale.size(1)
        quant_wise = inner_size // w2_quant_group
        w2_quant_flag = torch.tensor(w2_quant_flag).view(-1, w2_quant_group)
        w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size
        w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0)
    for i in range(expert_num): # start_expert_id : start_expert_id + expert_size
        if i >= start_expert_id and i < start_expert_id + expert_size:
            if (input_list[i].size(0) > 0):
                # ---- GEMM1: tokens x w1 (quantized or plain) ----
                if need_quant:
                    if w1_quant_flag is None:
                        gemm1_out = smooth_quant_matmul(input_list[i], input_scale_list[i],
                                                        w1[idx], w1_scale[idx], dtype)
                    else:
                        gemm1_out = smooth_quant_matmul_w4w8_mixed(input_list[i], input_scale_list[i],
                                                                   w1[w1_offset_cu[idx]:w1_offset_cu[idx+1]], w1_scale[idx], dtype,
                                                                   quant_flag = w1_quant_flag[idx])
                else:
                    gemm1_out = torch.matmul(input_list[i], w1[idx].permute(1, 0))
                # ---- activation (with optional gating) ----
                act_in = gemm1_out[:, :inner_size].float()
                gate = gemm1_out[:, inner_size:]
                act = act_mode_dict[act_mode]
                gemm1_out = act(act_in).to(dtype=dtype) if gated == False else \
                            act(act_in).to(dtype=dtype) * gate
                # ---- GEMM2: activated tokens x w2 ----
                if need_quant:
                    # Re-smooth and re-quantize the activations per token.
                    quant_gemm1_out, gemm1_out_scale = QuantByRow(gemm1_out * act_smooth[idx], 8)
                    if w2_quant_flag is None:
                        gemm2_out = smooth_quant_matmul(quant_gemm1_out, gemm1_out_scale, w2[idx], w2_scale[idx], dtype)
                    else:
                        gemm2_out = smooth_quant_matmul_w4w8_mixed(quant_gemm1_out, gemm1_out_scale,
                                                                   w2[w2_offset_cu[idx]:w2_offset_cu[idx+1]], w2_scale[idx], dtype,
                                                                   quant_flag = w2_quant_flag[idx])
                else:
                    gemm2_out = torch.matmul(gemm1_out, w2[idx].permute(1, 0))
                output_list.append(gemm2_out)
            # Advance the local-expert index even for empty experts so weight
            # and scale indexing stays aligned.
            idx += 1
        else:
            # Tokens of non-local experts contribute zeros (rows match the
            # expert's token count; width matches because input width ==
            # hidden_size).  NOTE(review): with quantization active the
            # placeholder inherits the int8 input dtype and relies on
            # torch.cat type promotion below — confirm.
            output_list.append(torch.zeros_like(input_list[i]))
    # Restore token order, weight by router probabilities, reduce over top-k.
    output = torch.cat(output_list, dim=0)[combine_idx].float() * reduce_weight
    output = output.reshape(-1, topk, hidden_size).sum(dim=1).to(dtype=dtype)
    if residual is not None:
        output = output + residual.view(input.shape)
    return output.view(hidden_states.shape)
def test_fused_moe(self):
    """Unquantized fused MoE vs. the reference, sweeping dtypes and
    (start_expert_id, expert_size) expert-parallel slices.

    Compares three kernel invocations against ``op_impl_base``:
    3-D (N, T, C) input, flattened 2-D (N*T, C) input, and the
    ``fused_moe`` helper (presumably from common_utils — confirm).
    """
    print("test_fused_moe")
    batch, seq, hidden_size, inner_size = 3, 5, 512, 768
    dtype_list = [torch.half]
    if mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', dtype
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        # Gated MLP: w1 holds [act | gate] rows stacked, hence inner*(1+gated).
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
        bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
        bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
        # Paired lists: each case keeps only weights of the expert slice
        # [start_expert_id, start_expert_id + expert_size).
        start_expert_id_list = [0, 1, 3, 4, 5, 6]
        expert_size_list = [8, 4, 3, 2, 3, 2]
        for i in range(len(start_expert_id_list)):
            start_expert_id = start_expert_id_list[i]
            expert_size = expert_size_list[i]
            print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
start_expert_id: {start_expert_id}, expert_size: {expert_size}, topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
            torch_out = self.op_impl_base(hidden_states,
                                          router_logit,
                                          weight1[start_expert_id:start_expert_id+expert_size],
                                          weight2[start_expert_id:start_expert_id+expert_size],
                                          bias1,
                                          bias2,
                                          residual,
                                          None,
                                          None,
                                          None,
                                          None,
                                          topk,
                                          renormalize,
                                          gated,
                                          act_mode,
                                          start_expert_id,
                                          0, 0, None, None)
            # (N, T, C)
            tmo_out_1 = ops.fused_moe(hidden_states,
                                      router_logit,
                                      weight1[start_expert_id:start_expert_id+expert_size],
                                      weight2[start_expert_id:start_expert_id+expert_size],
                                      bias1,
                                      bias2,
                                      residual,
                                      None,
                                      None,
                                      None,
                                      None,
                                      topk,
                                      renormalize,
                                      gated,
                                      act_mode,
                                      start_expert_id)
            # (N*T, C)
            tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                                      router_logit.view(-1, expert_num),
                                      weight1[start_expert_id:start_expert_id+expert_size],
                                      weight2[start_expert_id:start_expert_id+expert_size],
                                      bias1,
                                      bias2,
                                      residual.view(-1, hidden_size) if residual is not None else None,
                                      None,
                                      None,
                                      None,
                                      None,
                                      topk,
                                      renormalize,
                                      gated,
                                      act_mode,
                                      start_expert_id).reshape(batch, seq, hidden_size)
            tmo_out_3 = fused_moe(hidden_states,
                                  router_logit,
                                  weight1[start_expert_id:start_expert_id+expert_size],
                                  weight2[start_expert_id:start_expert_id+expert_size],
                                  bias1,
                                  bias2,
                                  residual,
                                  None,
                                  None,
                                  None,
                                  None,
                                  topk,
                                  renormalize,
                                  gated,
                                  act_mode,
                                  start_expert_id)
            self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
            self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
            self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_3.cpu().float(), 0.006, use_MSE=True)
def test_pertoken_quant_fused_moe_tp(self):
    """Smoke-test the per-token smooth-quant MoE path at one fixed size."""
    print("test_pertoken_quant_fused_moe_tp")
    # Fixed problem configuration; bf16 preferred, fp16 fallback.
    batch = 3
    seq = 5
    hidden_size = 8192
    inner_size = 1024
    expert_num = 32
    topk = 5
    gated = False
    renormalize = True
    act_mode = 'gelu'
    data_type = torch.bfloat16 if mlu.is_bf16_supported() else torch.float16
    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
    self.__run_sq_case(batch, seq, hidden_size, inner_size, expert_num,
                       topk, gated, renormalize, data_type, act_mode)
def test_moe_tp2_mixed_ep4_no_quant(self):
    """Unquantized MoE with simulated tensor-parallel (TP=2) x
    expert-parallel (EP=4) sharding.

    Weights are split along the inner dimension per TP rank and along the
    expert dimension per EP rank; partial outputs from every (tp, ep) shard
    are summed, which must reproduce the full result.
    """
    print("test_moe_tp2_mixed_ep4_no_quant")
    tp_num = 2
    ep_num = 4
    expert_num = 32
    expert_size = expert_num // ep_num
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16
    if not mlu.is_bf16_supported():
        data_type = torch.float16
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    residual = None
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
    # Expose a TP axis: w1 row-split, w2 column-split along inner_size.
    w1 = weight1.reshape(ep_num * expert_size, tp_num, (inner_size * (1 + gated)) // tp_num, hidden_size)
    w2 = weight2.reshape(ep_num * expert_size, hidden_size, tp_num, inner_size // tp_num)
    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
    tmo_out1 = torch.zeros_like(hidden_states)
    tmo_out2 = torch.zeros_like(hidden_states)
    torch_out = torch.zeros_like(hidden_states)
    for tp_idx in range(tp_num):
        w1_curr_tp = w1[:, tp_idx, ...]
        w2_curr_tp = w2[:, :, tp_idx, :]
        for ep_idx in range(ep_num):
            start_expert_id = ep_idx * expert_size
            # Slice this EP rank's experts out of the TP shard.
            w1_curr_tp_and_ep = w1_curr_tp.reshape((ep_num, expert_size)+w1_curr_tp.shape[1:])[ep_idx].contiguous()
            w2_curr_tp_and_ep = w2_curr_tp.reshape((ep_num, expert_size)+w2_curr_tp.shape[1:])[ep_idx].contiguous()
            # Accumulate partial results across all shards.
            tmo_out1 += ops.fused_moe(hidden_states,
                                      router_logit,
                                      w1_curr_tp_and_ep,
                                      w2_curr_tp_and_ep,
                                      bias1,
                                      bias2,
                                      residual,
                                      None,
                                      None,
                                      None,
                                      None,
                                      topk,
                                      renormalize,
                                      gated,
                                      act_mode,
                                      start_expert_id)
            tmo_out2 += fused_moe(hidden_states,
                                  router_logit,
                                  w1_curr_tp_and_ep,
                                  w2_curr_tp_and_ep,
                                  bias1,
                                  bias2,
                                  residual,
                                  None,
                                  None,
                                  None,
                                  None,
                                  topk,
                                  renormalize,
                                  gated,
                                  act_mode,
                                  start_expert_id)
            torch_out += self.op_impl_base(hidden_states,
                                           router_logit,
                                           w1_curr_tp_and_ep,
                                           w2_curr_tp_and_ep,
                                           bias1,
                                           bias2,
                                           residual,
                                           None,
                                           None,
                                           None,
                                           None,
                                           topk,
                                           renormalize,
                                           gated,
                                           act_mode,
                                           start_expert_id,
                                           0, 0, None, None)
    tmo_out2 = tmo_out2.reshape(batch, seq, hidden_size)
    self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
    self.assertTensorsEqual(tmo_out1.cpu().float(), tmo_out2.cpu().float(), 0.006, use_MSE=True)
def test_moe_tp2_mixed_ep4_quant(self):
    """Smooth-quant int8 MoE under simulated TP=2 x EP=4 sharding.

    Weights are smoothed, row-quantized to int8, then partitioned: w1/scales/
    act_smooth along the inner (TP) axis, experts along the EP axis; partial
    shard outputs are accumulated and compared to the reference.
    """
    print("test_moe_tp2_mixed_ep4_quant")
    tp_num = 2
    ep_num = 4
    expert_num = 32
    expert_size = expert_num // ep_num
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16
    if not mlu.is_bf16_supported():
        data_type = torch.float16
    scale_s = 0.1 # avoid the occurrence of inf
    eps = 0.01 # Avoid the occurrence of nan
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    residual = None
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
    # Positive smooth factors (abs + eps keeps them away from zero).
    input_smooth = torch.randn(expert_num, hidden_size, device='mlu', dtype=torch.float32).abs() + eps
    act_smooth = torch.randn(expert_num, (1+gated)*inner_size, device='mlu', dtype=torch.float32).abs() + eps
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    # Fold the smooth factors into the weights before quantization, so that
    # (x * smooth) @ (w / smooth)^T reproduces x @ w^T.
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1_shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2_shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1.shape), quant_w2.view(weight2.shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    torch_out = torch.zeros_like(hidden_states)
    # NOTE(review): this outer loop re-runs the entire sweep tp_num times,
    # re-binding the sharded views each pass (the reshapes are numel-
    # preserving, so the second pass works) — confirm whether the repeat is
    # intentional or a leftover.
    for _ in range(tp_num):
        w1_scale_tp = w1_scale.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num)
        act_smooth_tp = act_smooth.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num)
        # Expose EP (expert) and TP (inner) axes on every sharded tensor.
        w1 = quant_w1.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num, hidden_size)
        w2 = quant_w2.reshape(ep_num, expert_size, hidden_size, tp_num, inner_size*(1+gated)//tp_num)
        input_smooth = input_smooth.reshape(ep_num, expert_size, hidden_size)
        act_smooth = act_smooth.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num)
        w1_scale = w1_scale.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num)
        w2_scale = w2_scale.reshape(ep_num, expert_size, hidden_size)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
        tmo_out = torch.zeros_like(hidden_states)
        tmo_out1 = torch.zeros_like(hidden_states)
        torch_out = torch.zeros_like(hidden_states)
        for tp_idx in range(tp_num):
            w1_curr_tp = w1[:, :, tp_idx, ...]
            w2_curr_tp = w2[:, :, :, tp_idx, :]
            act_smooth_tp = act_smooth[:, :, tp_idx, :]
            w1_scale_tp = w1_scale[:, :, tp_idx, :]
            for ep_idx in range(ep_num):
                start_expert_id = ep_idx * expert_size
                w1_curr_tp_and_ep = w1_curr_tp[ep_idx].contiguous()
                w2_curr_tp_and_ep = w2_curr_tp[ep_idx].contiguous()
                input_smooth_curr_ep = input_smooth[ep_idx].contiguous()
                act_smooth_curr_ep = act_smooth_tp[ep_idx].contiguous()
                w1_scale_curr_ep = w1_scale_tp[ep_idx].contiguous()
                w2_scale_curr_ep = w2_scale[ep_idx].contiguous()
                tmo_out += ops.fused_moe(hidden_states, # [batch, seq, hidden_size]
                                         router_logit, # [batch, seq, expert_num]
                                         w1_curr_tp_and_ep, # [expert_size, inner_size*(1+gated)//tp_num, hidden_size]
                                         w2_curr_tp_and_ep, # [expert_size, hidden_size, inner_size*(1+gated)//tp_num]
                                         bias1,
                                         bias2,
                                         residual,
                                         input_smooth_curr_ep, # [expert_size, hidden_size]
                                         act_smooth_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num]
                                         w1_scale_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num]
                                         w2_scale_curr_ep, # [expert_size, hidden_size]
                                         topk,
                                         renormalize,
                                         gated,
                                         act_mode,
                                         start_expert_id)
                tmo_out1 += fused_moe(hidden_states,
                                      router_logit,
                                      w1_curr_tp_and_ep,
                                      w2_curr_tp_and_ep,
                                      bias1,
                                      bias2,
                                      residual,
                                      input_smooth_curr_ep,
                                      act_smooth_curr_ep,
                                      w1_scale_curr_ep,
                                      w2_scale_curr_ep,
                                      topk,
                                      renormalize,
                                      gated,
                                      act_mode,
                                      start_expert_id)
                torch_out += self.op_impl_base(hidden_states,
                                               router_logit,
                                               w1_curr_tp_and_ep,
                                               w2_curr_tp_and_ep,
                                               bias1,
                                               bias2,
                                               residual,
                                               input_smooth_curr_ep,
                                               act_smooth_curr_ep,
                                               w1_scale_curr_ep,
                                               w2_scale_curr_ep,
                                               topk,
                                               renormalize,
                                               gated,
                                               act_mode,
                                               start_expert_id,
                                               0, 0, None, None)
        self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True)
        self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True)
def test_smq_fused_moe_random_tp(self):
    """Randomized smooth-quant cases (seeded; up to 10 unique shapes).

    Do not reorder the random.* calls below — the seeded draw sequence
    defines the exact cases exercised.
    """
    print("test_smq_fused_moe_random_tp")
    import random
    random.seed(0)
    act_mode = 'gelu'
    case_list = set()  # dedupe repeated random draws
    for i in range(10):
        batch = random.randint(1, 10)
        seq = random.randint(1, 10)
        hidden_size = random.randrange(512, 2048, 2)
        inner_size = random.randrange(512, 2048, 2)
        expert_num = random.randint(1, 40)
        topk = random.randint(1,expert_num)
        gated = random.choice([True, False])
        renormalize = random.choice([True, False])
        data_type = random.choice([torch.bfloat16, torch.float16])
        if not mlu.is_bf16_supported():
            data_type = torch.float16
        case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode)
        if case in case_list:
            continue
        case_list.add(case)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
        self.__run_sq_case(*case)
def __run_sq_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode='gelu', start_expert_id=0, expert_size=-1):
    """Build a smooth-quant int8 MoE case and compare kernel vs. reference.

    Weights are divided by their smooth factors and row-quantized to int8;
    the kernel is run on 3-D and flattened 2-D inputs plus the common_utils
    fused_moe helper (presumably — confirm its origin), all checked against
    op_impl_base.  ``expert_size == -1`` means "all experts local".
    """
    if expert_size == -1:
        expert_size = expert_num
    scale_s = 0.01 # avoid the occurrence of inf
    eps = 0.1 # Avoid the occurrence of nan
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
    residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    # Positive smooth factors (abs + eps avoids division blow-up below).
    input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps
    bias1, bias2 = None, None
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    # Fold smooth factors into the weights so the smoothed activations cancel.
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    torch_out = self.op_impl_base(hidden_states,
                                  router_logit,
                                  quant_w1[start_expert_id:start_expert_id+expert_size],
                                  quant_w2[start_expert_id:start_expert_id+expert_size],
                                  bias1,
                                  bias2,
                                  residual,
                                  input_smooth[start_expert_id:start_expert_id+expert_size],
                                  act_smooth[start_expert_id:start_expert_id+expert_size],
                                  w1_scale[start_expert_id:start_expert_id+expert_size],
                                  w2_scale[start_expert_id:start_expert_id+expert_size],
                                  topk,
                                  renormalize,
                                  gated,
                                  act_mode,
                                  start_expert_id,
                                  0, 0, None, None)
    # (N, T, C)
    tmo_out_1 = ops.fused_moe(hidden_states,
                              router_logit,
                              quant_w1[start_expert_id:start_expert_id+expert_size],
                              quant_w2[start_expert_id:start_expert_id+expert_size],
                              bias1,
                              bias2,
                              residual,
                              input_smooth[start_expert_id:start_expert_id+expert_size],
                              act_smooth[start_expert_id:start_expert_id+expert_size],
                              w1_scale[start_expert_id:start_expert_id+expert_size],
                              w2_scale[start_expert_id:start_expert_id+expert_size],
                              topk,
                              renormalize,
                              gated,
                              act_mode,
                              start_expert_id)
    # # (N*T, C)
    tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                              router_logit.view(-1, expert_num),
                              quant_w1[start_expert_id:start_expert_id+expert_size],
                              quant_w2[start_expert_id:start_expert_id+expert_size],
                              bias1,
                              bias2,
                              residual.view(-1, hidden_size) if residual is not None else None,
                              input_smooth[start_expert_id:start_expert_id+expert_size],
                              act_smooth[start_expert_id:start_expert_id+expert_size],
                              w1_scale[start_expert_id:start_expert_id+expert_size],
                              w2_scale[start_expert_id:start_expert_id+expert_size],
                              topk,
                              renormalize,
                              gated,
                              act_mode,
                              start_expert_id).view(batch, seq, hidden_size)
    tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size),
                          router_logit.view(-1, expert_num),
                          quant_w1[start_expert_id:start_expert_id+expert_size],
                          quant_w2[start_expert_id:start_expert_id+expert_size],
                          bias1,
                          bias2,
                          residual.view(-1, hidden_size) if residual is not None else None,
                          input_smooth[start_expert_id:start_expert_id+expert_size],
                          act_smooth[start_expert_id:start_expert_id+expert_size],
                          w1_scale[start_expert_id:start_expert_id+expert_size],
                          w2_scale[start_expert_id:start_expert_id+expert_size],
                          topk,
                          renormalize,
                          gated,
                          act_mode,
                          start_expert_id).view(batch, seq, hidden_size)
    self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
    # 3-D and flattened 2-D invocations must match bit-for-bit (tol 0).
    self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True)
    self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
def test_fused_moe_with_4D_w2(self):
    """Unquantized MoE where w2 is also passed in a blocked 4-D layout.

    The 2-D-input invocation reshapes w2 to (E/block_e, hidden, block_e,
    inner); its result must match the plain 3-D invocation and the
    reference.
    """
    print("test_fused_moe_with_4D_w2")
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    dtype_list = [torch.half]
    if mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, True, True, 'silu', dtype
        hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
        router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
        residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
        bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
        weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
        bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
        torch_out = self.op_impl_base(hidden_states, router_logit,
                                      weight1, weight2,
                                      bias1, bias2, residual, None, None, None, None,
                                      topk, renormalize, gated, act_mode, 0, 0, 0, None, None)
        # (N, T, C)
        tmo_out_1 = ops.fused_moe(hidden_states, router_logit,
                                  weight1, weight2,
                                  bias1, bias2, residual, None, None, None, None,
                                  topk, renormalize, gated, act_mode, 0)
        # (N*T, C)
        # NOTE(review): view(8, 4, ...) hard-codes expert_num=32 split as
        # (expert_num // block_e, block_e) with block_e = 4096 // inner_size
        # = 4 for these sizes — would silently break if the constants above
        # change; confirm/parameterize.
        tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                                  router_logit.view(-1, expert_num),
                                  weight1, weight2.view(8, 4, hidden_size, inner_size).transpose(1,2).contiguous(),
                                  bias1, bias2,
                                  residual.view(-1, hidden_size) if residual is not None else None,
                                  None, None, None, None,
                                  topk, renormalize, gated, act_mode, 0).view(batch, seq, hidden_size)
        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
        self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
def test_moe_tp2_mixed_ep4_with_4D_w2(self):
    """TP=2 x EP=4 unquantized MoE with each shard's w2 passed in the
    blocked 4-D (E/block_e, hidden, block_e, inner) layout."""
    print("test_moe_tp2_mixed_ep4_with_4D_w2")
    tp_num, ep_num = 2, 4
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 2048
    expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16
    assert inner_size % tp_num == 0
    assert expert_num % ep_num == 0
    expert_size = expert_num // ep_num
    # block_e experts are packed per 4096-wide block along the inner axis.
    assert 4096 % inner_size == 0
    block_e = 4096 // inner_size
    assert expert_size % block_e == 0
    if not mlu.is_bf16_supported():
        data_type = torch.float16
    hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
    router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
    weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
    weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
    residual = None
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
    # Expose the TP axis: w1 row-split, w2 column-split along inner_size.
    w1 = weight1.reshape(expert_num, tp_num, -1, hidden_size)
    w2 = weight2.reshape(expert_num, hidden_size, tp_num, -1)
    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
    tmo_out = torch.zeros_like(hidden_states)
    torch_out = torch.zeros_like(hidden_states)
    for tp_idx in range(tp_num):
        new_inner_size = w2.shape[-1]  # per-TP inner width
        w1_curr_tp = w1[:, tp_idx, ...]
        w2_curr_tp = w2[:, :, tp_idx, :]
        for ep_idx in range(ep_num):
            start_expert_id = ep_idx * expert_size
            w1_curr_tp_and_ep = w1_curr_tp.reshape((ep_num, expert_size)+w1_curr_tp.shape[1:])[ep_idx].contiguous()
            w2_curr_tp_and_ep = w2_curr_tp.reshape((ep_num, expert_size)+w2_curr_tp.shape[1:])[ep_idx].contiguous()
            # Both reference and kernel get w2 in the blocked 4-D layout;
            # partial shard outputs accumulate into the full result.
            torch_out += self.op_impl_base(hidden_states, router_logit, w1_curr_tp_and_ep,
                                           w2_curr_tp_and_ep.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                                           bias1, bias2, residual, None, None, None, None,
                                           topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None)
            tmo_out += ops.fused_moe(hidden_states, router_logit, w1_curr_tp_and_ep,
                                     w2_curr_tp_and_ep.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                                     bias1, bias2, residual, None, None, None, None,
                                     topk, renormalize, gated, act_mode, start_expert_id)
    self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
def test_sq_fused_moe_with_4D_w2(self):
    """Smooth-quant int8 MoE with w2 additionally passed in the blocked
    4-D layout for the flattened-input invocation."""
    print("test_sq_fused_moe_with_4D_w2")
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16
    if not mlu.is_bf16_supported():
        data_type = torch.float16
    assert 4096 % inner_size == 0
    block_e = 4096 // inner_size
    assert expert_num % block_e == 0
    scale_s = 0.01 # avoid the occurrence of inf
    # NOTE(review): scale_s is unused in this test (inputs use normal(0, 0.1)).
    eps = 0.1 # Avoid the occurrence of nan
    hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
    router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
    weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
    weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
    residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    bias1, bias2 = None, None
    input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    # Fold smooth factors into the weights, then row-quantize to int8.
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    torch_out = self.op_impl_base(hidden_states, router_logit, quant_w1, quant_w2,
                                  bias1, bias2, residual, input_smooth, act_smooth,
                                  w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0, 0, 0, None, None)
    # (N, T, C)
    tmo_out_1 = ops.fused_moe(hidden_states, router_logit, quant_w1, quant_w2,
                              bias1, bias2, residual, input_smooth, act_smooth,
                              w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0)
    # # (N*T, C)
    tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                              router_logit.view(-1, expert_num),
                              quant_w1,
                              quant_w2.view(-1,block_e,hidden_size, inner_size).transpose(1,2).contiguous(),
                              bias1, bias2,
                              residual.view(-1, hidden_size) if residual is not None else None,
                              input_smooth, act_smooth,
                              w1_scale, w2_scale, topk, renormalize,
                              gated,
                              act_mode,
                              0).view(batch, seq, hidden_size)
    self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
    # 3-D vs. blocked-4-D invocations must match bit-for-bit (tol 0).
    self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True)
def test_sq_fused_moe_random_tp_quant_grouped(self):
    """Randomized group-quant (w4/w8 mixed) cases — 100 unique seeded draws.

    Do not reorder the random.* calls below — the seeded draw sequence
    defines the exact cases exercised.
    """
    print("test_sq_fused_moe_random_tp_quant_grouped")
    import random
    random.seed(0)
    act_mode = 'gelu'
    case_list = set()  # dedupe; loop continues until 100 distinct cases ran
    while (len(case_list) < 100):
        batch = random.randint(1, 10)
        seq = random.randint(1, 10)
        hidden_size = random.randrange(512, 2048, 128)
        inner_size = random.randrange(512, 2048, 512)
        expert_num = random.randint(1, 40)
        topk = random.randint(1, expert_num)
        gated = random.choice([True, False])
        renormalize = random.choice([True, False])
        quant_bit = random.choice([4, 8])
        data_type = random.choice([torch.bfloat16, torch.float16])
        if not mlu.is_bf16_supported():
            data_type = torch.float16
        case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode)
        if case in case_list:
            continue
        case_list.add(case)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True)
        self.__run_quant_grouped_case(*case)
def __run_quant_grouped_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode='gelu', start_expert_id=0, expert_size=-1, quant_group_size=128):
def get_quant_group(n, quant_group_size):
quant_group = n // quant_group_size
return quant_group if quant_group >= 1 else 1
if expert_size == -1:
expert_size = expert_num
scale_s = 0.01 # avoid the occurrence of inf
eps = 0.1 # Avoid the occurrence of nan
hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps
act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps
bias1, bias2 = None, None
weight1_shape, weight2_shape = weight1.shape, weight2.shape
weight1 = weight1 / input_smooth.unsqueeze(1)
weight2 = weight2 / act_smooth.unsqueeze(1)
w1_quant_group = get_quant_group(hidden_size, quant_group_size)
w2_quant_group = get_quant_group(inner_size, quant_group_size)
quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), quant_bit, w1_quant_group)
quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), quant_bit, w2_quant_group)
quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
if quant_bit == 4:
quant_w1, quant_w2 = PairlyPackInt8(quant_w1), PairlyPackInt8(quant_w2)
w1_scale = w1_scale.view(expert_num, -1, w1_quant_group).permute(2, 0, 1).contiguous()
w2_scale = w2_scale.view(expert_num, -1, w2_quant_group).permute(2, 0, 1).contiguous()
# split scale and transpose
def extract_scale(scale, start_expert, expert_size):
return scale[:, start_expert:start_expert+expert_size, :].contiguous()
torch_out = self.op_impl_base(hidden_states,
router_logit,
quant_w1[start_expert_id:start_expert_id+expert_size],
quant_w2[start_expert_id:start_expert_id+expert_size],
bias1,
bias2,
residual,
input_smooth[start_expert_id:start_expert_id+expert_size],
act_smooth[start_expert_id:start_expert_id+expert_size],
extract_scale(w1_scale, start_expert_id, expert_size),
extract_scale(w2_scale, start_expert_id, expert_size),
topk,
renormalize,
gated,
act_mode,
start_expert_id,
0, 0, None, None)
# (N, T, C)
tmo_out_1 = ops.fused_moe(hidden_states,
router_logit,
quant_w1[start_expert_id:start_expert_id+expert_size],
quant_w2[start_expert_id:start_expert_id+expert_size],
bias1,
bias2,
residual,
input_smooth[start_expert_id:start_expert_id+expert_size],
act_smooth[start_expert_id:start_expert_id+expert_size],
extract_scale(w1_scale, start_expert_id, expert_size),
extract_scale(w2_scale, start_expert_id, expert_size),
topk,
renormalize,
gated,
act_mode,
start_expert_id)
tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size),
router_logit.view(-1, expert_num),
quant_w1[start_expert_id:start_expert_id+expert_size],
quant_w2[start_expert_id:start_expert_id+expert_size],
bias1,
bias2,
residual.view(-1, hidden_size) if residual is not None else None,
input_smooth[start_expert_id:start_expert_id+expert_size],
act_smooth[start_expert_id:start_expert_id+expert_size],
extract_scale(w1_scale, start_expert_id, expert_size),
extract_scale(w2_scale, start_expert_id, expert_size),
topk,
renormalize,
gated,
act_mode,
start_expert_id).view(batch, seq, hidden_size)
self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
def __run_w4w8_mixed_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode='gelu', start_expert_id=0, expert_size=-1):
if expert_size == -1:
expert_size = expert_num
w1_quant_group = hidden_size // quant_wise
w2_quant_group = inner_size // quant_wise
w1_quant_flag = torch.randint(1, 3, (expert_num, w1_quant_group), dtype=torch.int32) * 4
w2_quant_flag = torch.randint(1, 3, (expert_num, w2_quant_group), dtype=torch.int32) * 4
w1_count = (w1_quant_flag.sum().item() // 4) * (quant_wise // 2) * inner_size*(1+gated)
w2_count = (w2_quant_flag.sum().item() // 4) * (quant_wise // 2) * hidden_size
w1 = torch.randint(-128, 127, (w1_count,), device="mlu", dtype=torch.int32).to(torch.int8)
w2 = torch.randint(-128, 127, (w2_count,), device="mlu", dtype=torch.int32).to(torch.int8)
hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
input_smooth = torch.empty(expert_num, hidden_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
act_smooth = torch.empty(expert_num, inner_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
bias1, bias2 = None, None
w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated)
w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0)
w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size
w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0)
# split scale and transpose
def extract_scale(scale, start_expert, expert_size):
return scale[:, start_expert:start_expert+expert_size, :].contiguous()
params = [hidden_states, router_logit,
w1[w1_offset_cu[start_expert_id]:w1_offset_cu[start_expert_id+expert_size]],
w2[w2_offset_cu[start_expert_id]:w2_offset_cu[start_expert_id+expert_size]],
bias1, bias2, residual,
input_smooth[start_expert_id:start_expert_id+expert_size],
act_smooth[start_expert_id:start_expert_id+expert_size],
extract_scale(w1_scale, start_expert_id, expert_size),
extract_scale(w2_scale, start_expert_id, expert_size),
topk, renormalize, gated, act_mode, start_expert_id, 0, 0,
w1_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist(),
w2_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist()]
torch_out = self.op_impl_base(*params)
# (N, T, C)
tmo_out_1 = ops.fused_moe(*params)
tmo_out_2 = fused_moe(*params)
self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True)
self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True)
    def test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed(self):
        """Randomized sweep of the mixed int4/int8 ("w4w8") grouped-quant path.

        Samples 100 unique shape/dtype/flag combinations and forwards each
        to __run_w4w8_mixed_case with start_expert_id=0 and expert_size=-1
        (i.e. all experts handled by this rank).
        """
        print("test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed")
        import random
        # Fixed seed keeps the sampled case set reproducible.
        random.seed(0)
        act_mode = 'gelu'
        case_list = set()  # de-duplicates sampled parameter combinations
        while (len(case_list) < 100):
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            # Multiples of 512, so every quant_wise choice below divides them.
            hidden_size = random.randrange(1024, 3072, 512)
            inner_size = random.randrange(1024, 3072, 512)
            expert_num = random.randint(1, 40)
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            quant_wise = random.choice([128, 256, 512])  # channels per quant group
            data_type = random.choice([torch.bfloat16, torch.float16])
            # Fall back to fp16 on devices without bf16 support.
            if not mlu.is_bf16_supported():
                data_type = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode)
            if case in case_list:
                continue  # already exercised this combination
            case_list.add(case)
            print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
                    topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_wise: {quant_wise}, act_mode: {act_mode} testing...", flush=True)
            self.__run_w4w8_mixed_case(*case, 0, -1)
def test_sq_fused_moe_single_quant_group(self):
print("test_sq_fused_moe_single_quant_group")
import random
random.seed(0)
act_mode = 'gelu'
batch = 9
seq = 10
hidden_size = 1664
inner_size = 512
expert_num = 15
topk = 2
gated = False
renormalize = False
quant_bit = 4
data_type = torch.float16
case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode)
print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True)
self.__run_quant_grouped_case(*case)
def test_inductor(self):
batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', torch.float16
start_expert_id, expert_size = 0, 8
hidden_states = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type)
router_logit = torch.randn(batch * seq, expert_num, device="mlu", dtype=torch.float32)
residual = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type)
weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
args = (hidden_states, router_logit,
weight1[start_expert_id:start_expert_id+expert_size],
weight2[start_expert_id:start_expert_id+expert_size],
None, None, residual, None, None, None, None, None, None,
topk, renormalize, gated, act_mode, 0, 0, 0)
self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, args)
eps = 1e-5
input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
weight1_shape, weight2_shape = weight1.shape, weight2.shape
weight1 = weight1 / input_smooth.unsqueeze(1)
weight2 = weight2 / act_smooth.unsqueeze(1)
quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
args = (hidden_states, router_logit,
quant_w1, quant_w2, None, None, residual,
input_smooth, act_smooth, w1_scale, w2_scale,
None, None, topk, renormalize, gated, act_mode, 0, 0, 0)
self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, args)
if __name__ == '__main__':
    # Use `raise SystemExit(...)` rather than the builtin `exit()`: exit() is
    # the site-module interactive helper and is not guaranteed to exist
    # (e.g. under `python -S`), while SystemExit reliably propagates the
    # run_unittest result as the process exit status.
    raise SystemExit(run_unittest(TestFusedMOEOp))