947 lines
58 KiB
Python
Executable File
947 lines
58 KiB
Python
Executable File
import torch
|
|
import unittest
|
|
import torch_mlu_ops as ops
|
|
from common_utils import *
|
|
from typing import Union, List, Tuple
|
|
from torch.nn.parameter import Parameter
|
|
from torch.nn import functional as F
|
|
from torch_mlu import mlu
|
|
|
|
class TestFusedMOEOp(BtTestCase):
|
|
def run_gen_case(self, dic):
    """Run one generated test case described by a case dict.

    ``dic['dump_data']`` selects between replaying pre-dumped tensors
    (the remaining dict values are already the positional arguments of
    ``launch``, in order) and materializing fresh tensors from the
    per-argument descriptor dicts.
    """
    if dic.pop('dump_data'):
        self.launch(*dic.values())
        return

    def _optional(key, **kwargs):
        # A descriptor whose 'data' field is None stands for an absent argument.
        entry = dic[key]
        return None if entry['data'] is None else create_tensor_from_dic(entry, **kwargs)

    def _scalar(key):
        return dic[key]['data']

    hidden_states = create_tensor_from_dic(dic['hidden_states'])
    gating_output = create_tensor_from_dic(dic['gating_output'])
    w1 = create_tensor_from_dic(dic['w1'])
    w2 = create_tensor_from_dic(dic['w2'])
    if dic['w1']['dtype'] is not torch.int8:
        # Shrink float weights to keep matmul outputs well-conditioned.
        w1 *= 0.1
        w2 *= 0.1
    bias1 = _optional('bias1')
    bias2 = _optional('bias2')
    residual = _optional('residual')
    input_smooth = _optional('input_smooth', is_uniform=True, low=0.01, high=0.05)
    act_smooth = _optional('act_smooth', is_uniform=True, low=0.01, high=0.05)
    w1_scale = _optional('w1_scale', is_uniform=True, low=-0.05, high=0.05)
    w2_scale = _optional('w2_scale', is_uniform=True, low=-0.05, high=0.05)
    # Read but unused: launch receives a literal 0 in the cncl_comm slot,
    # exactly as the original did.
    _ = _scalar('cncl_comm')
    self.launch(hidden_states, gating_output, w1, w2, bias1, bias2, residual,
                input_smooth, act_smooth, w1_scale, w2_scale,
                _scalar('topk'), _scalar('renormalize'), _scalar('gated'),
                _scalar('act_mode'), _scalar('start_expert_id'), _scalar('block_n'),
                0, _scalar('w1_quant_flag'), _scalar('w2_quant_flag'))
|
|
|
|
def launch(self, *args):
    # Run the pure-PyTorch reference and the fused kernel on the same
    # argument tuple, then compare with an MSE-based tolerance of 0.03.
    base = self.op_impl_base(*args)
    # NOTE(review): `tmo` is not defined by this file's visible imports
    # (the ops module is imported as `ops`); presumably it arrives via
    # `from common_utils import *` -- confirm, otherwise this line is a
    # NameError and should read `ops.fused_moe(*args)`.
    tmo_res = tmo.fused_moe(*args)
    self.assertTensorsEqual(tmo_res.cpu().float(), base.cpu().float(), 0.03, use_MSE=True)
|
|
|
|
def op_impl_base(self, *args):
    """Pure-PyTorch reference implementation of the fused MoE operator.

    Mirrors the fused kernel: route each token to its top-k experts,
    optionally smooth-quantize activations/weights, run the two expert
    GEMMs with an activation (and optional gating) in between, scatter
    the expert outputs back to token order, weight them by the router
    probabilities, reduce over the k copies, and add the optional
    residual.

    Only experts in ``[start_expert_id, start_expert_id + expert_size)``
    are computed (an expert/tensor-parallel slice); tokens routed to
    other experts contribute zeros, so callers can sum partial results
    across ranks.

    NOTE(review): bias1/bias2, block_n and cncl_comm are unpacked but
    never used by this reference -- the tests only exercise bias-free,
    single-comm configurations.
    """
    hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth, \
        act_smooth, w1_scale, w2_scale, topk, renormalize, \
        gated, act_mode, start_expert_id, block_n, cncl_comm, w1_quant_flag, w2_quant_flag = args
    if w2.dim() == 4:
        # Blocked (E//b, C, b, I) layout -> flatten back to (E, C, I).
        w2 = w2.transpose(1, 2).reshape(-1, w2.size(1), w2.size(-1))
    expert_num = gating_output.size(-1)
    # Number of experts owned by this slice. In the mixed w4/w8 path the
    # sliced weight tensor is bit-packed, so the count comes from the
    # scale tensor instead (assumes w1_scale dim 1 is the expert axis
    # before the transpose below -- TODO confirm against the kernel).
    expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1)

    def router(hidden_states, router_logit):
        # Softmax over experts, pick top-k, optionally renormalize the
        # selected probabilities so they sum to 1 per token.
        router_logit = torch.softmax(router_logit.view(-1, router_logit.size(-1)), dim=1)
        topk_logit, expert_id = torch.topk(router_logit, k=topk, dim=1)
        if renormalize:
            topk_logit = topk_logit / topk_logit.sum(-1).unsqueeze(1)
        # Sort the flattened (token, k) assignments by expert id so each
        # expert's tokens are contiguous; combine_idx records where each
        # sorted row must go to undo the permutation later.
        sorted_expert_id, indices = expert_id.int().flatten().sort()
        nk = indices.size(0)
        token_cout = torch.bincount(sorted_expert_id, minlength=expert_num).cpu()
        # Flattened index // topk recovers the source token of each slot.
        expand_idx = indices.int() // topk
        combine_idx = torch.zeros((nk,), dtype=torch.int, device="mlu")
        combine_idx.scatter_(0, indices, torch.arange(nk, dtype=torch.int, device="mlu"))
        input_expand = hidden_states[expand_idx]
        input_list = input_expand.split(tuple(token_cout))

        input_scale_list = []
        input_split_result = []
        if input_smooth is not None:
            # do smooth quant on input: scale tokens of locally-owned
            # experts by that expert's smooth vector, then per-row int8
            # quantize the whole batch. idx is the LOCAL expert index
            # into input_smooth and must advance for every owned expert,
            # even one with zero tokens (an empty split contributes no
            # rows to the cat, so skipping its append is harmless).
            idx = 0
            for i in range(expert_num):
                if i >= start_expert_id and i < start_expert_id + expert_size:
                    if (input_list[i].size(0) > 0):
                        temp = input_list[i] * input_smooth[idx]
                        input_split_result.append(temp)
                    idx += 1
                else:
                    # Tokens of remote experts pass through unscaled;
                    # they are zeroed out in the main loop anyway.
                    input_split_result.append(input_list[i])
            quant_input, input_scale = QuantByRow(torch.cat(input_split_result, dim=0), 8)
            input_list = quant_input.split(token_cout.tolist())
            input_scale_list = input_scale.split(token_cout.tolist())
        return input_list, input_scale_list, topk_logit.flatten().view(nk , 1), combine_idx

    dtype = hidden_states.dtype
    # With gating, w1's output dim holds [act_input | gate] halves.
    if w1_quant_flag is None:
        inner_size = w1.size(1) // 2 if gated else w1.size(1)
    else:
        inner_size = w1_scale.size(2) // 2 if gated else w1_scale.size(2)
    hidden_size = w2.size(1) if w2_quant_flag is None else w2_scale.size(2)
    input = hidden_states.view(-1, hidden_states.size(-1))
    gating_output = gating_output.view(-1, gating_output.size(-1))
    input_list, input_scale_list, reduce_weight, combine_idx = router(input, gating_output)
    output_list = []
    idx = 0
    # router only produces per-token scales when input_smooth was given.
    need_quant = len(input_scale_list) != 0
    if need_quant and w1_scale.dim() == 3:
        # Presumably moves the expert axis to the front for per-expert
        # indexing below -- TODO confirm the 3D scale layout.
        w1_scale = w1_scale.transpose(0, 1).contiguous()
        w2_scale = w2_scale.transpose(0, 1).contiguous()
    if w1_quant_flag is not None:
        # Mixed w4/w8 weights are packed: build cumulative row offsets of
        # each expert's packed block from the per-group bit-width flags.
        w1_quant_group = w1_scale.size(1)
        quant_wise = hidden_size // w1_quant_group
        w1_quant_flag = torch.tensor(w1_quant_flag).view(-1, w1_quant_group)
        w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated)
        w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0)
    if w2_quant_flag is not None:
        w2_quant_group = w2_scale.size(1)
        quant_wise = inner_size // w2_quant_group
        w2_quant_flag = torch.tensor(w2_quant_flag).view(-1, w2_quant_group)
        w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size
        w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0)
    for i in range(expert_num): # start_expert_id : start_expert_id + expert_size
        if i >= start_expert_id and i < start_expert_id + expert_size:
            if (input_list[i].size(0) > 0):
                # --- GEMM1: tokens x w1 ---
                if need_quant:
                    if w1_quant_flag is None:
                        gemm1_out = smooth_quant_matmul(input_list[i], input_scale_list[i],
                                                        w1[idx], w1_scale[idx], dtype)
                    else:
                        gemm1_out = smooth_quant_matmul_w4w8_mixed(input_list[i], input_scale_list[i],
                                                                   w1[w1_offset_cu[idx]:w1_offset_cu[idx+1]], w1_scale[idx], dtype,
                                                                   quant_flag = w1_quant_flag[idx])
                else:
                    gemm1_out = torch.matmul(input_list[i], w1[idx].permute(1, 0))
                # Split activation input and gate halves (gate is empty
                # when ungated, so the multiply below never runs then).
                act_in = gemm1_out[:, :inner_size].float()
                gate = gemm1_out[:, inner_size:]
                act = act_mode_dict[act_mode]
                gemm1_out = act(act_in).to(dtype=dtype) if gated == False else \
                            act(act_in).to(dtype=dtype) * gate
                # --- GEMM2: activated tokens x w2 ---
                if need_quant:
                    quant_gemm1_out, gemm1_out_scale = QuantByRow(gemm1_out * act_smooth[idx], 8)
                    if w2_quant_flag is None:
                        gemm2_out = smooth_quant_matmul(quant_gemm1_out, gemm1_out_scale, w2[idx], w2_scale[idx], dtype)
                    else:
                        gemm2_out = smooth_quant_matmul_w4w8_mixed(quant_gemm1_out, gemm1_out_scale,
                                                                   w2[w2_offset_cu[idx]:w2_offset_cu[idx+1]], w2_scale[idx], dtype,
                                                                   quant_flag = w2_quant_flag[idx])
                else:
                    gemm2_out = torch.matmul(gemm1_out, w2[idx].permute(1, 0))
                output_list.append(gemm2_out)
            # Local expert index must advance even for an empty expert so
            # later experts pick the correct weight slice.
            idx += 1
        else:
            # Experts not owned by this slice contribute zeros so the
            # caller can sum partial outputs across parallel ranks.
            output_list.append(torch.zeros_like(input_list[i]))
    # Undo the expert-sorted permutation, weight by router probabilities,
    # then reduce the k expert copies of each token.
    output = torch.cat(output_list, dim=0)[combine_idx].float() * reduce_weight
    output = output.reshape(-1, topk, hidden_size).sum(dim=1).to(dtype=dtype)
    if residual is not None:
        output = output + residual.view(input.shape)
    return output.view(hidden_states.shape)
|
|
|
|
def test_fused_moe(self):
    """Unquantized gated-silu MoE: fused kernel vs. the Python reference.

    Covers fp16 (and bf16 when supported) over several contiguous
    (start_expert_id, expert_size) expert slices, checking that the 3D
    input path, the flattened 2D input path and the functional wrapper
    all agree with op_impl_base.
    """
    print("test_fused_moe")
    batch, seq, hidden_size, inner_size = 3, 5, 512, 768
    dtype_list = [torch.half]
    if mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', dtype
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
        bias1 = None
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
        bias2 = None
        # Each pair selects a contiguous run of experts owned by this "rank".
        for start_expert_id, expert_size in zip([0, 1, 3, 4, 5, 6], [8, 4, 3, 2, 3, 2]):
            print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
start_expert_id: {start_expert_id}, expert_size: {expert_size}, topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
            sl = slice(start_expert_id, start_expert_id + expert_size)
            torch_out = self.op_impl_base(
                hidden_states, router_logit, weight1[sl], weight2[sl],
                bias1, bias2, residual, None, None, None, None,
                topk, renormalize, gated, act_mode, start_expert_id,
                0, 0, None, None)
            # (N, T, C) layout
            tmo_out_1 = ops.fused_moe(
                hidden_states, router_logit, weight1[sl], weight2[sl],
                bias1, bias2, residual, None, None, None, None,
                topk, renormalize, gated, act_mode, start_expert_id)
            # (N*T, C) layout
            tmo_out_2 = ops.fused_moe(
                hidden_states.view(-1, hidden_size),
                router_logit.view(-1, expert_num),
                weight1[sl], weight2[sl], bias1, bias2,
                residual.view(-1, hidden_size) if residual is not None else None,
                None, None, None, None,
                topk, renormalize, gated, act_mode,
                start_expert_id).reshape(batch, seq, hidden_size)
            # Functional wrapper
            tmo_out_3 = fused_moe(
                hidden_states, router_logit, weight1[sl], weight2[sl],
                bias1, bias2, residual, None, None, None, None,
                topk, renormalize, gated, act_mode, start_expert_id)

            self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
            self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
            self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_3.cpu().float(), 0.006, use_MSE=True)
|
|
|
|
def test_pertoken_quant_fused_moe_tp(self):
    """Per-token smooth-quant MoE, one large ungated-gelu configuration."""
    print("test_pertoken_quant_fused_moe_tp")
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    expert_num, topk, gated, renormalize, act_mode = 32, 5, False, True, 'gelu'
    # Fall back to fp16 on devices without bf16 support.
    data_type = torch.bfloat16 if mlu.is_bf16_supported() else torch.float16
    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
    self.__run_sq_case(batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode)
|
|
|
|
def test_moe_tp2_mixed_ep4_no_quant(self):
    """Unquantized MoE under tp=2 tensor parallel x ep=4 expert parallel.

    w1 is split along its output dimension and w2 along its input
    dimension across tp ranks; experts are split across ep ranks.
    The partial fused outputs are accumulated and compared against the
    accumulated reference outputs.
    """
    print("test_moe_tp2_mixed_ep4_no_quant")
    tp_num, ep_num, expert_num = 2, 4, 32
    expert_size = expert_num // ep_num
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    topk, gated, renormalize, act_mode = 5, False, True, 'gelu'
    data_type = torch.bfloat16 if mlu.is_bf16_supported() else torch.float16
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    residual = None
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
    bias1 = None
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
    bias2 = None
    # Expose the tp axis in each weight for slicing below.
    w1 = weight1.reshape(ep_num * expert_size, tp_num, (inner_size * (1 + gated)) // tp_num, hidden_size)
    w2 = weight2.reshape(ep_num * expert_size, hidden_size, tp_num, inner_size // tp_num)

    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
    tmo_out1 = torch.zeros_like(hidden_states)
    tmo_out2 = torch.zeros_like(hidden_states)
    torch_out = torch.zeros_like(hidden_states)
    for tp_idx in range(tp_num):
        w1_tp = w1[:, tp_idx, ...]
        w2_tp = w2[:, :, tp_idx, :]
        for ep_idx in range(ep_num):
            start_expert_id = ep_idx * expert_size
            w1_slice = w1_tp.reshape((ep_num, expert_size) + w1_tp.shape[1:])[ep_idx].contiguous()
            w2_slice = w2_tp.reshape((ep_num, expert_size) + w2_tp.shape[1:])[ep_idx].contiguous()
            # Trailing arguments shared by all three call sites.
            tail = (bias1, bias2, residual, None, None, None, None,
                    topk, renormalize, gated, act_mode, start_expert_id)
            tmo_out1 += ops.fused_moe(hidden_states, router_logit, w1_slice, w2_slice, *tail)
            tmo_out2 += fused_moe(hidden_states, router_logit, w1_slice, w2_slice, *tail)
            torch_out += self.op_impl_base(hidden_states, router_logit, w1_slice, w2_slice,
                                           *tail, 0, 0, None, None)
    tmo_out2 = tmo_out2.reshape(batch, seq, hidden_size)

    self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
    self.assertTensorsEqual(tmo_out1.cpu().float(), tmo_out2.cpu().float(), 0.006, use_MSE=True)
|
|
|
|
|
|
def test_moe_tp2_mixed_ep4_quant(self):
    """Smooth-quant (int8) MoE under tp=2 x ep=4 mixed parallelism.

    Weights are smoothed, per-row quantized, then reshaped so each
    (tp_idx, ep_idx) pair gets its own contiguous weight/scale/smooth
    slice; the accumulated fused outputs must match the accumulated
    reference within an MSE tolerance of 0.02.
    """
    print("test_moe_tp2_mixed_ep4_quant")
    tp_num = 2
    ep_num = 4
    expert_num = 32
    expert_size = expert_num // ep_num
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16
    if not mlu.is_bf16_supported():
        data_type = torch.float16
    scale_s = 0.1 # avoid the occurrence of inf
    eps = 0.01 # Avoid the occurrence of nan
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    residual = None
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)

    # Positive smooth vectors; eps keeps them away from zero.
    input_smooth = torch.randn(expert_num, hidden_size, device='mlu', dtype=torch.float32).abs() + eps
    act_smooth = torch.randn(expert_num, (1+gated)*inner_size, device='mlu', dtype=torch.float32).abs() + eps
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    # Fold the smooth factors into the weights, then per-row int8 quantize.
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1_shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2_shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1.shape), quant_w2.view(weight2.shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    # NOTE(review): this zeros_like is re-done below before the main loop,
    # and the following loop only rebinds w1_scale_tp/act_smooth_tp, which
    # are overwritten again inside the tp loop -- both look like dead
    # leftovers; confirm before cleaning up.
    torch_out = torch.zeros_like(hidden_states)
    for _ in range(tp_num):
        w1_scale_tp = w1_scale.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num)
        act_smooth_tp = act_smooth.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num)

    # Reshape everything so dim 0 is the ep rank and (where present) an
    # explicit tp axis can be sliced per rank.
    w1 = quant_w1.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num, hidden_size)
    w2 = quant_w2.reshape(ep_num, expert_size, hidden_size, tp_num, inner_size*(1+gated)//tp_num)
    input_smooth = input_smooth.reshape(ep_num, expert_size, hidden_size)
    act_smooth = act_smooth.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num)
    w1_scale = w1_scale.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num)
    w2_scale = w2_scale.reshape(ep_num, expert_size, hidden_size)
    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
    tmo_out = torch.zeros_like(hidden_states)
    tmo_out1 = torch.zeros_like(hidden_states)
    torch_out = torch.zeros_like(hidden_states)
    for tp_idx in range(tp_num):
        w1_curr_tp = w1[:, :, tp_idx, ...]
        w2_curr_tp = w2[:, :, :, tp_idx, :]
        act_smooth_tp = act_smooth[:, :, tp_idx, :]
        w1_scale_tp = w1_scale[:, :, tp_idx, :]
        for ep_idx in range(ep_num):
            start_expert_id = ep_idx * expert_size
            w1_curr_tp_and_ep = w1_curr_tp[ep_idx].contiguous()
            w2_curr_tp_and_ep = w2_curr_tp[ep_idx].contiguous()
            input_smooth_curr_ep = input_smooth[ep_idx].contiguous()
            act_smooth_curr_ep = act_smooth_tp[ep_idx].contiguous()
            w1_scale_curr_ep = w1_scale_tp[ep_idx].contiguous()
            w2_scale_curr_ep = w2_scale[ep_idx].contiguous()
            tmo_out += ops.fused_moe(hidden_states, # [batch, seq, hidden_size]
                                     router_logit, # [batch, seq, expert_num]
                                     w1_curr_tp_and_ep, # [expert_size, inner_size*(1+gated)//tp_num, hidden_size]
                                     w2_curr_tp_and_ep, # [expert_size, hidden_size, inner_size*(1+gated)//tp_num]
                                     bias1,
                                     bias2,
                                     residual,
                                     input_smooth_curr_ep, # [expert_size, hidden_size]
                                     act_smooth_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num]
                                     w1_scale_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num]
                                     w2_scale_curr_ep, # [expert_size, hidden_size]
                                     topk,
                                     renormalize,
                                     gated,
                                     act_mode,
                                     start_expert_id)
            tmo_out1 += fused_moe(hidden_states,
                                  router_logit,
                                  w1_curr_tp_and_ep,
                                  w2_curr_tp_and_ep,
                                  bias1,
                                  bias2,
                                  residual,
                                  input_smooth_curr_ep,
                                  act_smooth_curr_ep,
                                  w1_scale_curr_ep,
                                  w2_scale_curr_ep,
                                  topk,
                                  renormalize,
                                  gated,
                                  act_mode,
                                  start_expert_id)
            torch_out += self.op_impl_base(hidden_states,
                                           router_logit,
                                           w1_curr_tp_and_ep,
                                           w2_curr_tp_and_ep,
                                           bias1,
                                           bias2,
                                           residual,
                                           input_smooth_curr_ep,
                                           act_smooth_curr_ep,
                                           w1_scale_curr_ep,
                                           w2_scale_curr_ep,
                                           topk,
                                           renormalize,
                                           gated,
                                           act_mode,
                                           start_expert_id,
                                           0, 0, None, None)
    self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True)
    self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True)
|
|
|
|
def test_smq_fused_moe_random_tp(self):
    """Smooth-quant fused_moe over randomly drawn, de-duplicated configs."""
    print("test_smq_fused_moe_random_tp")
    import random
    random.seed(0)
    act_mode = 'gelu'
    seen = set()
    for _ in range(10):
        # Draw order matters: the seeded RNG must be consumed in exactly
        # this sequence to reproduce the historical cases.
        batch = random.randint(1, 10)
        seq = random.randint(1, 10)
        hidden_size = random.randrange(512, 2048, 2)
        inner_size = random.randrange(512, 2048, 2)
        expert_num = random.randint(1, 40)
        topk = random.randint(1, expert_num)
        gated = random.choice([True, False])
        renormalize = random.choice([True, False])
        data_type = random.choice([torch.bfloat16, torch.float16])
        if not mlu.is_bf16_supported():
            data_type = torch.float16
        case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode)
        if case in seen:
            continue
        seen.add(case)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
        self.__run_sq_case(*case)
|
|
|
|
def __run_sq_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode='gelu', start_expert_id=0, expert_size=-1):
    """Build a smooth-quant case and check all fused entry points.

    Weights are divided by their smooth vectors and per-row int8
    quantized; expert_size == -1 means "all experts". The fused kernel
    is exercised with 3D input, flattened 2D input, and the functional
    wrapper, all compared to op_impl_base.
    """
    if expert_size == -1:
        expert_size = expert_num
    scale_s = 0.01  # keep activations small enough that quant never hits inf
    eps = 0.1  # keep smooth factors away from zero to avoid nan
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
    residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps
    bias1, bias2 = None, None
    w1_shape, w2_shape = weight1.shape, weight2.shape
    # Fold the smooth factors into the weights, then per-row int8 quantize.
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(w1_shape), quant_w2.view(w2_shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    # Contiguous slice of experts owned by this rank.
    sl = slice(start_expert_id, start_expert_id + expert_size)
    torch_out = self.op_impl_base(
        hidden_states, router_logit, quant_w1[sl], quant_w2[sl],
        bias1, bias2, residual,
        input_smooth[sl], act_smooth[sl], w1_scale[sl], w2_scale[sl],
        topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None)
    # (N, T, C) layout
    tmo_out_1 = ops.fused_moe(
        hidden_states, router_logit, quant_w1[sl], quant_w2[sl],
        bias1, bias2, residual,
        input_smooth[sl], act_smooth[sl], w1_scale[sl], w2_scale[sl],
        topk, renormalize, gated, act_mode, start_expert_id)
    # (N*T, C) layout
    flat_residual = residual.view(-1, hidden_size) if residual is not None else None
    tmo_out_2 = ops.fused_moe(
        hidden_states.view(-1, hidden_size), router_logit.view(-1, expert_num),
        quant_w1[sl], quant_w2[sl], bias1, bias2, flat_residual,
        input_smooth[sl], act_smooth[sl], w1_scale[sl], w2_scale[sl],
        topk, renormalize, gated, act_mode, start_expert_id).view(batch, seq, hidden_size)
    # Functional wrapper, flattened input
    tmo_out_3 = fused_moe(
        hidden_states.view(-1, hidden_size), router_logit.view(-1, expert_num),
        quant_w1[sl], quant_w2[sl], bias1, bias2, flat_residual,
        input_smooth[sl], act_smooth[sl], w1_scale[sl], w2_scale[sl],
        topk, renormalize, gated, act_mode, start_expert_id).view(batch, seq, hidden_size)

    self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
    # The two fused layouts must agree bit-for-bit.
    self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True)
    self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
|
|
|
|
def test_fused_moe_with_4D_w2(self):
    """fused_moe accepting w2 in a blocked 4D layout on the 2D-input path."""
    print("test_fused_moe_with_4D_w2")
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    dtype_list = [torch.half]
    if mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, True, True, 'silu', dtype
        hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
        router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
        residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
        bias1 = None
        weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
        bias2 = None

        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
        torch_out = self.op_impl_base(hidden_states, router_logit,
                                      weight1, weight2,
                                      bias1, bias2, residual, None, None, None, None,
                                      topk, renormalize, gated, act_mode, 0, 0, 0, None, None)
        # (N, T, C) layout, plain 3D w2
        tmo_out_1 = ops.fused_moe(hidden_states, router_logit,
                                  weight1, weight2,
                                  bias1, bias2, residual, None, None, None, None,
                                  topk, renormalize, gated, act_mode, 0)
        # (N*T, C) layout with w2 re-packed into 8 blocks of 4 experts.
        w2_blocked = weight2.view(8, 4, hidden_size, inner_size).transpose(1, 2).contiguous()
        tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                                  router_logit.view(-1, expert_num),
                                  weight1, w2_blocked,
                                  bias1, bias2,
                                  residual.view(-1, hidden_size) if residual is not None else None,
                                  None, None, None, None,
                                  topk, renormalize, gated, act_mode, 0).view(batch, seq, hidden_size)

        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
        self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
|
|
|
|
def test_moe_tp2_mixed_ep4_with_4D_w2(self):
    """tp=2 x ep=4 decomposition where w2 is passed in blocked 4D layout."""
    print("test_moe_tp2_mixed_ep4_with_4D_w2")
    tp_num, ep_num = 2, 4
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 2048
    expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16
    assert inner_size % tp_num == 0
    assert expert_num % ep_num == 0
    expert_size = expert_num // ep_num
    assert 4096 % inner_size == 0
    # Experts packed together in each 4D w2 block.
    block_e = 4096 // inner_size
    assert expert_size % block_e == 0
    if not mlu.is_bf16_supported():
        data_type = torch.float16
    hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
    router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
    weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
    weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
    residual = None
    bias1 = None
    bias2 = None
    # Expose the tp axis for slicing.
    w1 = weight1.reshape(expert_num, tp_num, -1, hidden_size)
    w2 = weight2.reshape(expert_num, hidden_size, tp_num, -1)

    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
    tmo_out = torch.zeros_like(hidden_states)
    torch_out = torch.zeros_like(hidden_states)
    for tp_idx in range(tp_num):
        new_inner_size = w2.shape[-1]
        w1_tp = w1[:, tp_idx, ...]
        w2_tp = w2[:, :, tp_idx, :]
        for ep_idx in range(ep_num):
            start_expert_id = ep_idx * expert_size
            w1_slice = w1_tp.reshape((ep_num, expert_size) + w1_tp.shape[1:])[ep_idx].contiguous()
            w2_slice = w2_tp.reshape((ep_num, expert_size) + w2_tp.shape[1:])[ep_idx].contiguous()
            # Same blocked layout handed to both implementations.
            w2_blocked = w2_slice.view(-1, block_e, hidden_size, new_inner_size).transpose(1, 2).contiguous()
            torch_out += self.op_impl_base(hidden_states, router_logit, w1_slice,
                                           w2_blocked,
                                           bias1, bias2, residual, None, None, None, None,
                                           topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None)
            tmo_out += ops.fused_moe(hidden_states, router_logit, w1_slice,
                                     w2_blocked,
                                     bias1, bias2, residual, None, None, None, None,
                                     topk, renormalize, gated, act_mode, start_expert_id)
    self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
|
|
|
|
def test_sq_fused_moe_with_4D_w2(self):
    """Smooth-quant MoE where the flattened-input call receives w2 in 4D."""
    print("test_sq_fused_moe_with_4D_w2")
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    expert_num, topk, gated, renormalize, act_mode = 32, 5, False, True, 'gelu'
    data_type = torch.bfloat16 if mlu.is_bf16_supported() else torch.float16
    assert 4096 % inner_size == 0
    block_e = 4096 // inner_size  # experts packed per 4D w2 block
    assert expert_num % block_e == 0
    scale_s = 0.01 # avoid the occurrence of inf
    eps = 0.1 # Avoid the occurrence of nan
    hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
    router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
    weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
    weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
    residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    bias1, bias2 = None, None
    input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
    w1_shape, w2_shape = weight1.shape, weight2.shape
    # Fold smooth factors into the weights, then per-row int8 quantize.
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(w1_shape), quant_w2.view(w2_shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    torch_out = self.op_impl_base(hidden_states, router_logit, quant_w1, quant_w2,
                                  bias1, bias2, residual, input_smooth, act_smooth,
                                  w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0, 0, 0, None, None)
    # (N, T, C) layout, plain 3D w2
    tmo_out_1 = ops.fused_moe(hidden_states, router_logit, quant_w1, quant_w2,
                              bias1, bias2, residual, input_smooth, act_smooth,
                              w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0)
    # (N*T, C) layout with blocked 4D w2
    tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                              router_logit.view(-1, expert_num),
                              quant_w1,
                              quant_w2.view(-1, block_e, hidden_size, inner_size).transpose(1, 2).contiguous(),
                              bias1, bias2,
                              residual.view(-1, hidden_size) if residual is not None else None,
                              input_smooth, act_smooth,
                              w1_scale, w2_scale, topk, renormalize,
                              gated, act_mode,
                              0).view(batch, seq, hidden_size)

    self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
    # The two fused layouts must agree bit-for-bit.
    self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True)
|
|
|
|
def test_sq_fused_moe_random_tp_quant_grouped(self):
|
|
print("test_sq_fused_moe_random_tp_quant_grouped")
|
|
import random
|
|
random.seed(0)
|
|
act_mode = 'gelu'
|
|
case_list = set()
|
|
while (len(case_list) < 100):
|
|
batch = random.randint(1, 10)
|
|
seq = random.randint(1, 10)
|
|
hidden_size = random.randrange(512, 2048, 128)
|
|
inner_size = random.randrange(512, 2048, 512)
|
|
expert_num = random.randint(1, 40)
|
|
topk = random.randint(1, expert_num)
|
|
gated = random.choice([True, False])
|
|
renormalize = random.choice([True, False])
|
|
quant_bit = random.choice([4, 8])
|
|
data_type = random.choice([torch.bfloat16, torch.float16])
|
|
if not mlu.is_bf16_supported():
|
|
data_type = torch.float16
|
|
case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode)
|
|
if case in case_list:
|
|
continue
|
|
case_list.add(case)
|
|
print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
|
|
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True)
|
|
self.__run_quant_grouped_case(*case)
|
|
|
|
    def __run_quant_grouped_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode='gelu', start_expert_id=0, expert_size=-1, quant_group_size=128):
        """Run one grouped-quantization fused-MoE case against the reference.

        Builds random smooth-quant inputs, quantizes the (smooth-folded) expert
        weights via QuantByRow with per-group scales, then checks that
        ops.fused_moe matches the Python reference (self.op_impl_base) for the
        (N, T, C) layout, and that the flattened (N*T, C) layout agrees too.

        Args:
            batch, seq: input batch and sequence lengths.
            hidden_size, inner_size: model width and expert FFN width.
            expert_num: total experts; topk: experts routed per token.
            gated: if True, weight1 carries gate+up projections (row dim is
                inner_size*(1+gated)).
            renormalize: renormalize the top-k routing weights.
            data_type: activation dtype (fp16/bf16).
            quant_bit: 4 or 8; for 4-bit the weights are additionally packed
                via PairlyPackInt8.
            act_mode: activation function name ('gelu', ...).
            start_expert_id, expert_size: expert shard for tensor-parallel
                testing; expert_size == -1 means "all experts".
            quant_group_size: channels per quantization group along the
                reduction dimension.
        """
        def get_quant_group(n, quant_group_size):
            # Number of quant groups along a reduction dim; clamp to >= 1 so a
            # dim smaller than the group size still gets one group.
            quant_group = n // quant_group_size
            return quant_group if quant_group >= 1 else 1

        if expert_size == -1:
            expert_size = expert_num
        scale_s = 0.01 # avoid the occurrence of inf
        eps = 0.1 # Avoid the occurrence of nan
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
        residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        # Smooth factors are kept strictly positive (abs + eps) so the weight
        # divisions below cannot produce inf/nan.
        input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps
        act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps
        bias1, bias2 = None, None
        weight1_shape, weight2_shape = weight1.shape, weight2.shape
        # Fold the smooth factors into the weights (smooth-quant style:
        # activation scaling compensated on the weight side).
        weight1 = weight1 / input_smooth.unsqueeze(1)
        weight2 = weight2 / act_smooth.unsqueeze(1)
        w1_quant_group = get_quant_group(hidden_size, quant_group_size)
        w2_quant_group = get_quant_group(inner_size, quant_group_size)
        quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), quant_bit, w1_quant_group)
        quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), quant_bit, w2_quant_group)
        quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
        if quant_bit == 4:
            # Pack pairs of 4-bit values into int8 -- presumably the layout the
            # kernel expects for 4-bit weights (see PairlyPackInt8, common_utils).
            quant_w1, quant_w2 = PairlyPackInt8(quant_w1), PairlyPackInt8(quant_w2)
        # Rearrange scales to (quant_group, expert, rows) so a shard of experts
        # can be sliced out along dim 1 below.
        w1_scale = w1_scale.view(expert_num, -1, w1_quant_group).permute(2, 0, 1).contiguous()
        w2_scale = w2_scale.view(expert_num, -1, w2_quant_group).permute(2, 0, 1).contiguous()
        # split scale and transpose
        def extract_scale(scale, start_expert, expert_size):
            # Slice this shard's experts from the (group, expert, row) scale.
            return scale[:, start_expert:start_expert+expert_size, :].contiguous()

        # Reference output from the pure-Python implementation.
        torch_out = self.op_impl_base(hidden_states,
                                      router_logit,
                                      quant_w1[start_expert_id:start_expert_id+expert_size],
                                      quant_w2[start_expert_id:start_expert_id+expert_size],
                                      bias1,
                                      bias2,
                                      residual,
                                      input_smooth[start_expert_id:start_expert_id+expert_size],
                                      act_smooth[start_expert_id:start_expert_id+expert_size],
                                      extract_scale(w1_scale, start_expert_id, expert_size),
                                      extract_scale(w2_scale, start_expert_id, expert_size),
                                      topk,
                                      renormalize,
                                      gated,
                                      act_mode,
                                      start_expert_id,
                                      0, 0, None, None)
        # (N, T, C)
        tmo_out_1 = ops.fused_moe(hidden_states,
                                  router_logit,
                                  quant_w1[start_expert_id:start_expert_id+expert_size],
                                  quant_w2[start_expert_id:start_expert_id+expert_size],
                                  bias1,
                                  bias2,
                                  residual,
                                  input_smooth[start_expert_id:start_expert_id+expert_size],
                                  act_smooth[start_expert_id:start_expert_id+expert_size],
                                  extract_scale(w1_scale, start_expert_id, expert_size),
                                  extract_scale(w2_scale, start_expert_id, expert_size),
                                  topk,
                                  renormalize,
                                  gated,
                                  act_mode,
                                  start_expert_id)
        # (N*T, C) flattened layout.
        # NOTE(review): bare `fused_moe` here vs `ops.fused_moe` above --
        # presumably a wrapper imported via `from common_utils import *`;
        # confirm it is actually exported there.
        tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size),
                              router_logit.view(-1, expert_num),
                              quant_w1[start_expert_id:start_expert_id+expert_size],
                              quant_w2[start_expert_id:start_expert_id+expert_size],
                              bias1,
                              bias2,
                              residual.view(-1, hidden_size) if residual is not None else None,
                              input_smooth[start_expert_id:start_expert_id+expert_size],
                              act_smooth[start_expert_id:start_expert_id+expert_size],
                              extract_scale(w1_scale, start_expert_id, expert_size),
                              extract_scale(w2_scale, start_expert_id, expert_size),
                              topk,
                              renormalize,
                              gated,
                              act_mode,
                              start_expert_id).view(batch, seq, hidden_size)

        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
        self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
|
|
|
|
    def __run_w4w8_mixed_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode='gelu', start_expert_id=0, expert_size=-1):
        """Run one mixed w4/w8 grouped-quant fused-MoE case against the reference.

        Each (expert, quant-group) independently uses 4- or 8-bit weights, so
        the packed weight buffers are flat int8 tensors addressed via
        cumulative per-expert byte offsets, and per-group bit widths are passed
        to the op as flattened flag lists.

        Args:
            batch, seq, hidden_size, inner_size, expert_num, topk, gated,
            renormalize, data_type, act_mode, start_expert_id, expert_size:
                same meaning as in __run_quant_grouped_case.
            quant_wise: channels per quantization group; hidden_size and
                inner_size are assumed to be multiples of it -- TODO confirm.
        """
        if expert_size == -1:
            expert_size = expert_num
        # Number of quant groups along each reduction dimension.
        w1_quant_group = hidden_size // quant_wise
        w2_quant_group = inner_size // quant_wise
        # Per-(expert, group) bit width: randint(1, 3) * 4 -> 4 or 8 bits.
        w1_quant_flag = torch.randint(1, 3, (expert_num, w1_quant_group), dtype=torch.int32) * 4
        w2_quant_flag = torch.randint(1, 3, (expert_num, w2_quant_group), dtype=torch.int32) * 4
        # Total packed int8 element count: per group, (bits//4)*(quant_wise//2)
        # = quant_wise/2 bytes for 4-bit or quant_wise bytes for 8-bit, times
        # the number of output rows.
        w1_count = (w1_quant_flag.sum().item() // 4) * (quant_wise // 2) * inner_size*(1+gated)
        w2_count = (w2_quant_flag.sum().item() // 4) * (quant_wise // 2) * hidden_size
        w1 = torch.randint(-128, 127, (w1_count,), device="mlu", dtype=torch.int32).to(torch.int8)
        w2 = torch.randint(-128, 127, (w2_count,), device="mlu", dtype=torch.int32).to(torch.int8)
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        # Small positive smooths / small scales keep the dequantized values in
        # a numerically tame range.
        input_smooth = torch.empty(expert_num, hidden_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
        act_smooth = torch.empty(expert_num, inner_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
        bias1, bias2 = None, None
        w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
        w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
        # Cumulative start offset of each expert's packed weights in the flat
        # buffers, padded with a leading 0 so offset[i]:offset[i+1] indexes
        # expert i.
        w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated)
        w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0)
        w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size
        w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0)

        # split scale and transpose
        def extract_scale(scale, start_expert, expert_size):
            # Slice this shard's experts from the (group, expert, row) scale.
            return scale[:, start_expert:start_expert+expert_size, :].contiguous()

        # Same argument list feeds the reference and both op entry points.
        params = [hidden_states, router_logit,
                  w1[w1_offset_cu[start_expert_id]:w1_offset_cu[start_expert_id+expert_size]],
                  w2[w2_offset_cu[start_expert_id]:w2_offset_cu[start_expert_id+expert_size]],
                  bias1, bias2, residual,
                  input_smooth[start_expert_id:start_expert_id+expert_size],
                  act_smooth[start_expert_id:start_expert_id+expert_size],
                  extract_scale(w1_scale, start_expert_id, expert_size),
                  extract_scale(w2_scale, start_expert_id, expert_size),
                  topk, renormalize, gated, act_mode, start_expert_id, 0, 0,
                  w1_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist(),
                  w2_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist()]
        torch_out = self.op_impl_base(*params)
        # (N, T, C)
        tmo_out_1 = ops.fused_moe(*params)
        # NOTE(review): bare `fused_moe` vs `ops.fused_moe` above -- presumably
        # a wrapper from `common_utils import *`; confirm it is exported there.
        tmo_out_2 = fused_moe(*params)
        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True)
        self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True)
|
|
|
|
def test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed(self):
|
|
print("test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed")
|
|
import random
|
|
random.seed(0)
|
|
act_mode = 'gelu'
|
|
case_list = set()
|
|
while (len(case_list) < 100):
|
|
batch = random.randint(1, 10)
|
|
seq = random.randint(1, 10)
|
|
hidden_size = random.randrange(1024, 3072, 512)
|
|
inner_size = random.randrange(1024, 3072, 512)
|
|
expert_num = random.randint(1, 40)
|
|
topk = random.randint(1, expert_num)
|
|
gated = random.choice([True, False])
|
|
renormalize = random.choice([True, False])
|
|
quant_wise = random.choice([128, 256, 512])
|
|
data_type = random.choice([torch.bfloat16, torch.float16])
|
|
if not mlu.is_bf16_supported():
|
|
data_type = torch.float16
|
|
case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode)
|
|
if case in case_list:
|
|
continue
|
|
case_list.add(case)
|
|
print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
|
|
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_wise: {quant_wise}, act_mode: {act_mode} testing...", flush=True)
|
|
self.__run_w4w8_mixed_case(*case, 0, -1)
|
|
|
|
def test_sq_fused_moe_single_quant_group(self):
|
|
print("test_sq_fused_moe_single_quant_group")
|
|
import random
|
|
random.seed(0)
|
|
act_mode = 'gelu'
|
|
batch = 9
|
|
seq = 10
|
|
hidden_size = 1664
|
|
inner_size = 512
|
|
expert_num = 15
|
|
topk = 2
|
|
gated = False
|
|
renormalize = False
|
|
quant_bit = 4
|
|
data_type = torch.float16
|
|
case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode)
|
|
print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
|
|
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True)
|
|
self.__run_quant_grouped_case(*case)
|
|
|
|
def test_inductor(self):
|
|
batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
|
|
expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', torch.float16
|
|
start_expert_id, expert_size = 0, 8
|
|
hidden_states = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type)
|
|
router_logit = torch.randn(batch * seq, expert_num, device="mlu", dtype=torch.float32)
|
|
residual = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type)
|
|
weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
|
|
weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
|
|
args = (hidden_states, router_logit,
|
|
weight1[start_expert_id:start_expert_id+expert_size],
|
|
weight2[start_expert_id:start_expert_id+expert_size],
|
|
None, None, residual, None, None, None, None, None, None,
|
|
topk, renormalize, gated, act_mode, 0, 0, 0)
|
|
self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, args)
|
|
|
|
eps = 1e-5
|
|
input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
|
|
act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
|
|
weight1_shape, weight2_shape = weight1.shape, weight2.shape
|
|
weight1 = weight1 / input_smooth.unsqueeze(1)
|
|
weight2 = weight2 / act_smooth.unsqueeze(1)
|
|
quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
|
|
quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
|
|
quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
|
|
w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
|
|
args = (hidden_states, router_logit,
|
|
quant_w1, quant_w2, None, None, residual,
|
|
input_smooth, act_smooth, w1_scale, w2_scale,
|
|
None, None, topk, renormalize, gated, act_mode, 0, 0, 0)
|
|
self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, args)
|
|
|
|
if __name__ == '__main__':
    # Propagate the runner's result as the process exit status.
    raise SystemExit(run_unittest(TestFusedMOEOp))
|