This commit is contained in:
Chranos
2026-02-04 17:39:32 +08:00
parent 8511fe8530
commit 79dfc69789
299 changed files with 55927 additions and 0 deletions

View File

@@ -0,0 +1,147 @@
import torch
import torch_mlu
import unittest
import math
import torch_mlu_ops as ops
import torch.multiprocessing as mp
import sys
import os
work_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(work_dir))
from common_utils import *
def flash_attn_sq_mm(q, k, v, smooth, quant_weight, weight_scale,
                     bias, softmax_scale, is_causal, world_size = 1):
    """Reference path: flash attention followed by smooth-quant matmul.

    Computes the result twice and returns both for comparison:
      * output1 -- simulated tensor parallelism: attention heads of q/k/v
        (dim 2), the smooth vector, and the matmul weight's reduction dim
        (dim 1) are chunked into ``world_size`` shards; per-shard partial
        outputs are summed, with the bias added on shard 0 only.
      * output2 -- the same computation on the full, unsplit tensors.

    Assumes q/k/v are (batch, seq, heads, head_size) — TODO confirm against
    ops.flash_attention's contract.
    """
    # Shard attention heads and the matmul reduction dimension across TP ranks.
    q_list = q.chunk(world_size, dim=2)
    k_list = k.chunk(world_size, dim=2)
    v_list = v.chunk(world_size, dim=2)
    smooth_list = smooth.chunk(world_size, dim=0)
    quant_weight_list = quant_weight.chunk(world_size, dim=1)
    quant_weight_list = [w.contiguous() for w in quant_weight_list]
    # Accumulator shaped (batch*seq, head_num_q*head_size).
    output1 = torch.zeros(q.size(0) * q.size(1), q.size(2) * q.size(3), dtype=q.dtype).mlu()
    for i in range(world_size):
        attn_output = ops.flash_attention(q_list[i], k_list[i], v_list[i], None, None, None,
            None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
        quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth_list[i], None)
        # Only the first shard contributes the bias, mirroring the TP convention.
        output1 += ops.smooth_quant_matmul(quant_input, input_scale,
            quant_weight_list[i], weight_scale, q.dtype, bias if i == 0 else None)
    # Unsplit reference.
    attn_output = ops.flash_attention(q, k, v, None, None, None,
        None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
    quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
    output2 = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
    return output1, output2
def tp_flash_attn_sq_mm(rank, *args):
    """Per-process worker (mp.spawn entry) for flash_attn_sq_mm_allreduce.

    Each rank runs its shard through the fused attention + smooth-quant
    matmul + allreduce op and compares the result against the precomputed
    ``base_output``, in both pad mode (4-D q/k/v) and pack mode (flattened
    q/k/v with cumulative sequence lengths).
    """
    world_size, q, k, v, smooth, quant_weight, weight_scale, bias, softmax_scale, is_causal, base_output = args
    # Snapshot CPU copies before initializing the device in this subprocess.
    q_cpu, k_cpu, v_cpu, smooth_cpu = q.cpu(), k.cpu(), v.cpu(), smooth.cpu()
    quant_weight_cpu, weight_scale_cpu, bias_cpu = quant_weight.cpu(), weight_scale.cpu(), bias.cpu()
    base_output_cpu = base_output.cpu()
    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    head_num_q = q.size(2)
    head_num_kv = k.size(2)
    seq = q.size(1)
    assert head_num_q % world_size == 0
    assert head_num_kv % world_size == 0
    # Per-rank head counts (currently informational only; not used below).
    head_num_q_tp = head_num_q // world_size
    head_num_kv_tp = head_num_kv // world_size
    # Re-upload the tensors onto this rank's device.
    q = q_cpu.mlu()
    k = k_cpu.mlu()
    v = v_cpu.mlu()
    smooth = smooth_cpu.mlu()
    quant_weight = quant_weight_cpu.mlu()
    weight_scale = weight_scale_cpu.mlu()
    # Note: only tp0 add bias
    bias = bias_cpu.mlu() if rank == 0 else None
    # Same sharding layout as the flash_attn_sq_mm reference.
    q_list = q.chunk(world_size, dim=2)
    k_list = k.chunk(world_size, dim=2)
    v_list = v.chunk(world_size, dim=2)
    smooth_list = smooth.chunk(world_size, dim=0)
    quant_weight_list = quant_weight.chunk(world_size, dim=1)
    quant_weight_list = [w.contiguous() for w in quant_weight_list]
    # test pad mode
    output_pad = ops.flash_attn_sq_mm_allreduce(cncl_comm, q_list[rank], k_list[rank], v_list[rank],
        None, None, None, None,
        smooth_list[rank], quant_weight_list[rank], weight_scale, bias, seq,
        seq, softmax_scale, is_causal)
    assertTensorsEqual(output_pad.cpu().float(), base_output_cpu.float(), 0.006, use_MSE=True, use_RAE=True)
    # test pack mode
    cu_seq_lens_q = torch.tensor([0, seq], dtype=torch.int32).mlu()
    cu_seq_lens_k = torch.tensor([0, seq], dtype=torch.int32).mlu()
    # Pack mode flattens (batch, seq) into one token dimension.
    q_pack = q_list[rank].flatten(0, 1)
    k_pack = k_list[rank].flatten(0, 1)
    v_pack = v_list[rank].flatten(0, 1)
    output_pack = ops.flash_attn_sq_mm_allreduce(cncl_comm, q_pack, k_pack, v_pack,
        cu_seq_lens_q, cu_seq_lens_k, None, None,
        smooth_list[rank], quant_weight_list[rank], weight_scale, bias, seq,
        seq, softmax_scale, is_causal)
    assertTensorsEqual(output_pack.cpu().float(), base_output_cpu.float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
class TestFlashAttnSqMMAllreduce(BtTestCase):
    """Tests for fused flash-attention + smooth-quant matmul (+ allreduce) ops."""

    def op_impl_base(self, *args):
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_flash_attn_sq_mm_split_seq due to ASan issues")
    def test_flash_attn_sq_mm_split_seq(self):
        """Attention + quant matmul over sequence blocks must match the full-sequence result."""
        batch, seq, head_num_q, head_num_kv, head_size, is_causal, block_seq = 16, 1024, 8, 1, 128, True, 4
        dtype = torch.bfloat16 if torch_mlu.mlu.is_bf16_supported() else torch.half
        hidden_size = head_num_q * head_size
        softmax_scale = 1 / math.sqrt(head_size)
        qkv = torch.randn(batch, seq, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
        smooth = torch.zeros(hidden_size, dtype=torch.float).mlu() + 1.0
        bias = torch.zeros(hidden_size, dtype=dtype).mlu()
        weight = torch.randn(hidden_size, hidden_size, dtype=dtype).mlu()
        quant_weight, weight_scale = QuantByRow(weight / smooth, 8)
        q = qkv[:, :, :head_num_q, :]
        k = qkv[:, :, head_num_q:head_num_q+head_num_kv, :]
        v = qkv[:, :, head_num_q+head_num_kv:, :]
        # Full-sequence reference.
        attn_output = ops.flash_attention(q, k, v, None, None, None,
            None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
        quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
        output1 = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
        # Same computation one sequence block at a time (causal: keys up to `end`).
        output_list = []
        block_size = seq // block_seq
        for i in range(block_seq):
            start = i * block_size
            # Last block absorbs any remainder of the sequence.
            end = seq if i == block_seq - 1 else (i + 1) * block_size
            attn_output = ops.flash_attention(q[:,start:end], k[:,:end], v[:,:end], None, None, None,
                None, None, end - start, end, softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
            quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
            out = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
            output_list.append(out)
        output2 = torch.cat(output_list, dim=0)
        # Rearrange (block, batch, block_size, hidden) back to (batch*seq, hidden).
        output2 = output2.reshape(block_seq, batch, block_size, hidden_size).transpose(0, 1).reshape(batch*seq, hidden_size)
        assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.0045, use_MSE=True, use_RAE=True)

    # Fix: the skip message previously referred to a non-existent "test_prevent".
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_flash_attn_sq_mm_allreduce due to ASan issues")
    def test_flash_attn_sq_mm_allreduce(self):
        """Spawn one process per device and compare the fused allreduce op to the reference."""
        world_size = min(torch_mlu.mlu.device_count(), 8)
        batch, seq, head_num_q, head_num_kv, head_size, is_causal = 1, 8192, 64, 8, 128, True
        dtype = torch.bfloat16 if torch_mlu.mlu.is_bf16_supported() else torch.half
        for i in range(1):
            # seq = random.randint(1, 32768)
            # is_causal = bool(random.randint(0, 1))
            print("=============test[{}]: seq = {}, causal = {}===============".format(i, seq, is_causal))
            hidden_size = head_num_q * head_size
            softmax_scale = 1 / math.sqrt(head_size)
            qkv = torch.randn(batch, seq, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
            smooth = torch.zeros(hidden_size, dtype=torch.float).mlu() + 1.0
            bias = torch.randn(hidden_size, dtype=dtype).mlu()
            weight = torch.randn(hidden_size, hidden_size, dtype=dtype).mlu()
            quant_weight, weight_scale = QuantByRow(weight / smooth, 8)
            q = qkv[:, :, :head_num_q, :]
            k = qkv[:, :, head_num_q:head_num_q+head_num_kv, :]
            v = qkv[:, :, head_num_q+head_num_kv:, :]
            # output1 (TP-sharded reference) is the baseline handed to the workers.
            output1, output2 = flash_attn_sq_mm(q, k, v, smooth, quant_weight, weight_scale,
                bias, softmax_scale, is_causal, world_size)
            args = world_size, q, k, v, smooth, quant_weight, weight_scale, bias, softmax_scale, is_causal, output1
            mp.spawn(tp_flash_attn_sq_mm, args, nprocs=world_size, join=True)

    def test_inductor(self):
        return super().test_inductor()
# Run this file's tests directly via the common runner.
if __name__ == '__main__':
    exit(run_unittest(TestFlashAttnSqMMAllreduce))

View File

@@ -0,0 +1,109 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from itertools import product
import torch.multiprocessing as mp
import sys
import os
work_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(work_dir))
from common_utils import *
def mm_split_k(rank, *args):
    """Per-process worker (mp.spawn entry): matmul + allreduce with K split across ranks.

    Each rank keeps its [start_k:end_k) slice of the reduction dimension,
    runs ops.matmul_allreduce, and rank 0 checks the result against a plain
    torch.matmul reference computed on the full K.
    """
    torch.manual_seed(0)
    mat_m, mat_n, mat_k, has_bias, has_res, dtype, world_size, block_m = args
    assert mat_k % world_size == 0, f"mat_k{mat_k} must be divisible by tp{world_size}"
    block_k = mat_k // world_size
    start_k = rank * block_k
    end_k = start_k + block_k
    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    alpha = 0.625
    beta = 1.0 if has_res else 0.
    input = torch.randn((mat_m, mat_k), dtype=dtype, device='mlu')
    residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu') if has_res else None
    weight = torch.randn((mat_n, mat_k), dtype=dtype, device="mlu")
    bias = torch.randn(mat_n, dtype=dtype, device="mlu") if has_bias else None
    # Reference: alpha * (input @ weight.T + bias) + beta * residual.
    pt_output = torch.matmul(input, weight.permute(1, 0))
    if has_bias:
        pt_output = pt_output + bias
    pt_output *= alpha
    if has_res:
        pt_output += beta * residual
    if has_bias:
        # The fused op applies alpha before bias, so pre-scale the bias here.
        bias *= alpha
    # Keep only this rank's K-slice of the activation and weight.
    input = input[..., start_k:end_k].contiguous()
    weight = weight[..., start_k:end_k].contiguous()
    # Note: only rank 0 contributes bias / residual.
    bias = bias if (bias is not None and rank == 0) else None
    residual = residual if (residual is not None and rank == 0) else None
    beta = beta if (residual is not None and rank == 0) else 0.
    output = ops.matmul_allreduce(cncl_comm, input, weight, bias, residual, alpha, beta, block_m)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
class TestMatMulAllReduceOp(BtTestCase):
    """Tests for the fused matmul + allreduce op across all visible devices."""

    def op_impl_base(self, *args):
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_matmul_allreduce due to ASan issues")
    def test_matmul_allreduce(self):
        """Sweep bias/residual/dtype combinations over a fixed shape."""
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            # NOTE(review): warns but still spawns with device_n processes — confirm intended.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [1680]
        has_res_list = [False, True]
        has_bias_list = [False, True]
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        block_m = 4
        args = product(mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, dtype_list)
        for mat_m, mat_n, mat_k, has_bias, has_res, dtype in args:
            print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp={}, block_m={}, testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, dtype, device_n, block_m), flush=True)
            param = [mat_m, mat_n, mat_k, has_bias, has_res, dtype, device_n, block_m]
            mp.spawn(mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skip("not test")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_matmul_allreduce_random due to ASan issues")
    def test_matmul_allreduce_random(self):
        """Randomized shapes / tp sizes; currently disabled via @unittest.skip."""
        import random
        random.seed(0)
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for i in range(10):
            tp_num = random.randint(1, device_n)
            mat_m = random.randint(1, 4096)
            mat_n = random.choice([512, 1024, 2048, 4096])
            # Start at a multiple of tp_num and step by tp_num so mat_k stays divisible.
            k_start = (1024 // tp_num) * tp_num
            mat_k = random.randrange(k_start, 10240, tp_num)
            has_res = random.choice([False, True])
            has_bias = random.choice([False, True])
            dtype = random.choice(dtype_list)
            block_m = random.randint(1, 10)
            assert mat_k % tp_num == 0, f"mat_k{mat_k} must be divisible by tp_num{tp_num}"
            print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, dtype, tp_num, block_m), flush=True)
            param = [mat_m, mat_n, mat_k, has_bias, has_res, dtype, tp_num, block_m]
            mp.spawn(mm_split_k, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        return super().test_inductor()
# Run this file's tests directly via the common runner.
if __name__ == '__main__':
    exit(run_unittest(TestMatMulAllReduceOp))

View File

@@ -0,0 +1,305 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from typing import Union, List, Tuple
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import torch.multiprocessing as mp
import math
import sys
import os
work_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(work_dir))
from common_utils import *
def moe_split_inner_size(rank, *args):
    """Per-process worker (mp.spawn entry): fused MoE with inner_size split across ranks.

    Each rank keeps a [start_k:end_k) slice of the expert FFN inner dimension
    and runs ops.fused_moe three ways, checked for agreement on rank 0:
      * output  -- per-rank fused_moe followed by an explicit all_reduce;
      * output1 -- fused_moe with the cncl communicator doing the allreduce;
      * output2 -- like output1 but with weight2 reshaped into the 4-D
        block-expert layout (only when the layout constraints are met).
    """
    torch.manual_seed(0)
    batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, world_size = args
    assert inner_size % world_size == 0, f"inner_size{inner_size} must be divisible by tp_num{world_size}"
    block_k = inner_size // world_size
    start_k = rank * block_k
    end_k = start_k + block_k
    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    scale_s = 0.01 # avoid the occurrence of inf
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    # Gated MoE doubles weight1's inner rows (gate + up projections).
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=dtype) * scale_s
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=dtype)
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=dtype) * scale_s
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=dtype)
    # Slice this rank's shard of the inner dimension (view handles the gated 2x layout).
    weight1 = weight1.view(expert_num, -1, inner_size, hidden_size)
    weight1 = weight1[:, :, start_k:end_k]
    weight1 = weight1.reshape(expert_num, -1, hidden_size).contiguous()
    weight2 = weight2[..., start_k:end_k].contiguous()
    if bias1 is not None:
        bias1 = bias1.view(expert_num, -1, inner_size)
        bias1 = bias1[..., start_k:end_k]
        bias1 = bias1.reshape(expert_num, -1).contiguous()
    # Note: only rank 0 contributes the residual.
    residual = residual if (residual is not None and rank == 0) else None
    param = [hidden_states, router_logit, weight1, weight2, bias1, bias2, residual,
        None, None, None, None, topk, renormalize, gated, act_mode]
    output = ops.fused_moe(*param)
    all_reduce(output, ReduceOp.SUM, group=pg)
    param = [hidden_states, router_logit, weight1, weight2, bias1, bias2, residual,
        None, None, None, None, topk, renormalize, gated, act_mode, 0, block_n, cncl_comm]
    output1 = ops.fused_moe(*param)
    new_inner_size = weight2.shape[-1]
    # 4-D weight2 layout packs block_e experts so that block_e * new_inner_size == 4096.
    block_e = 4096 // new_inner_size
    if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
        param = [hidden_states, router_logit, weight1, weight2.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
            bias1, bias2, residual,
            None, None, None, None, topk, renormalize, gated, act_mode, 0, block_n, cncl_comm]
        output2 = ops.fused_moe(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), output1.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
        if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
            assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
def sq_moe_split_inner_size(rank, *args):
    """Per-process worker (mp.spawn entry): smooth-quant fused MoE, inner_size split across ranks.

    Same three-way comparison as moe_split_inner_size, but with int8
    smooth-quantized expert weights plus per-expert input/activation smooth
    vectors and per-row weight scales, which must be sliced consistently
    with the inner-dimension shard.
    """
    torch.manual_seed(0)
    batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, world_size = args
    assert inner_size % world_size == 0, f"inner_size{inner_size} must be divisible by tp_num{world_size}"
    block_k = inner_size // world_size
    start_k = rank * block_k
    end_k = start_k + block_k
    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    scale_s = 0.1 # avoid the occurrence of inf
    eps = 0.1 # Avoid the occurrence of nan
    residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    bias1, bias2 = None, None
    hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=dtype)
    weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=dtype)
    weight2 = torch.normal(0, 0.01, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=dtype)
    router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
    # Smooth vectors are kept strictly positive (abs + eps) to avoid div-by-zero.
    input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    # Fold the smooth factors into the weights before row-wise int8 quantization.
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    # Slice this rank's shard of the inner dimension (view handles the gated 2x layout).
    quant_w1 = quant_w1.view(expert_num, -1, inner_size, hidden_size)
    quant_w1 = quant_w1[:, :, start_k:end_k]
    quant_w1 = quant_w1.reshape(expert_num, -1, hidden_size).contiguous()
    quant_w2 = quant_w2[..., start_k:end_k].contiguous()
    if bias1 is not None:
        bias1 = bias1.view(expert_num, -1, inner_size)
        bias1 = bias1[..., start_k:end_k]
        bias1 = bias1.reshape(expert_num, -1).contiguous()
    # w1's per-row scales and the activation smooth follow the same shard.
    if w1_scale is not None:
        w1_scale = w1_scale.view(expert_num, -1, inner_size)
        w1_scale = w1_scale[..., start_k:end_k]
        w1_scale = w1_scale.reshape(expert_num, -1).contiguous()
    if act_smooth is not None:
        act_smooth = act_smooth[..., start_k:end_k].contiguous()
    # Note: only rank 0 contributes the residual.
    residual = residual if (residual is not None and rank == 0) else None
    param = [hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual,
        input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode]
    output = ops.fused_moe(*param)
    all_reduce(output, ReduceOp.SUM, group=pg)
    param = [hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual,
        input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0,
        block_n, cncl_comm]
    output1 = ops.fused_moe(*param)
    new_inner_size = quant_w2.shape[-1]
    # 4-D weight2 layout packs block_e experts so that block_e * new_inner_size == 4096.
    block_e = 4096 // new_inner_size
    if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
        param = [hidden_states, router_logit, quant_w1,
            quant_w2.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
            bias1, bias2, residual, input_smooth, act_smooth, w1_scale, w2_scale, topk,
            renormalize, gated, act_mode, 0, block_n, cncl_comm]
        output2 = ops.fused_moe(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), output1.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
        if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
            assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
class TestFusedMOEAllReduceOp(BtTestCase):
    """Tests for fused MoE + allreduce, in plain and smooth-quant variants."""

    def op_impl_base(self, *args):
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_fused_moe_allreduce due to ASan issues")
    def test_single_fused_moe_allreduce(self):
        """One fixed large-shape case of the plain fused MoE allreduce."""
        print("test_single_fused_moe_allreduce")
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            # NOTE(review): warns but still spawns with device_n processes — confirm intended.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        block_n = 2
        batch, seq, hidden_size, inner_size = 1, 1024, 8192, 8192
        expert_num, topk, gated, renormalize, act_mode, dtype = 8, 2, True, True, 'silu', torch.float16
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
tp_num: {device_n}, block_n: {block_n}, testing...", flush=True)
        param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, device_n]
        mp.spawn(moe_split_inner_size, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_fused_moe_allreduce due to ASan issues")
    def test_random_fused_moe_allreduce(self):
        """Ten unique random configurations of the plain fused MoE allreduce."""
        print("test_random_fused_moe_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        act_mode = 'gelu'
        case_list = set()
        dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
        while(len(case_list) < 10):
            # block_n may be negative or zero; each branch below keeps hidden_size valid.
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(512, 1024, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                # Round hidden_size to a multiple of block_n.
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            # Keep inner_size divisible by tp_num.
            k_start = (512 // tp_num) * tp_num
            inner_size = random.randrange(k_start, 1024, tp_num * 2)
            expert_num = random.randint(1, 32)
            topk = random.randint(1,expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            dtype = random.choice(dtype_list)
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(moe_split_inner_size, param, nprocs=tp_num, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_sq_fused_moe_allreduce due to ASan issues")
    def test_sq_fused_moe_allreduce(self):
        """One fixed case of the smooth-quant fused MoE allreduce."""
        print("test_sq_fused_moe_allreduce")
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        block_n = 2
        batch, seq, hidden_size, inner_size = 5, 9, 8192, 8192
        expert_num, topk, gated, renormalize, act_mode, dtype = 20, 16, False, True, 'gelu', torch.float16
        print(f"sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
tp_num: {device_n}, block_n: {block_n}, testing...", flush=True)
        param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, device_n]
        mp.spawn(sq_moe_split_inner_size, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_fused_moe_allreduce due to ASan issues")
    def test_random_sq_fused_moe_allreduce(self):
        """Ten unique random configurations of the smooth-quant fused MoE allreduce."""
        print("test_random_sq_fused_moe_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        act_mode = 'gelu'
        case_list = set()
        while(len(case_list) < 10):
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(256, 512, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            k_start = (512 // tp_num) * tp_num
            inner_size = random.randrange(k_start, 1024, tp_num * 2)
            expert_num = random.randint(1, 32)
            topk = random.randint(1, expert_num)
            # Single-element choice today; keeps the hook for adding dtypes later.
            dtype = random.choice([ torch.float16])
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            if torch_mlu.mlu.get_device_name() == 'MLU370':
                dtype = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(sq_moe_split_inner_size, param, nprocs=tp_num, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_fused_moe_with_4D_w2_allreduce due to ASan issues")
    def test_random_sq_fused_moe_with_4D_w2_allreduce(self):
        """Random configurations constructed so the 4-D weight2 layout branch is exercised."""
        print("test_random_sq_fused_moe_with_4D_w2_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        act_mode = 'gelu'
        case_list = set()
        while (len(case_list) < 10):
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(512, 4096, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            # Choose shapes so block_e * inner_size_per_tp == 4096 and
            # expert_num is a multiple of block_e (4-D layout precondition).
            inner_size_per_tp = random.choice([256, 512])
            inner_size = inner_size_per_tp * tp_num
            expert_num_base = random.randint(1, 4)
            expert_num_factor = 4096 // inner_size_per_tp
            expert_num = expert_num_base * expert_num_factor
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            dtype = random.choice([ torch.float16])
            if torch_mlu.mlu.get_device_name() == 'MLU370':
                dtype = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(sq_moe_split_inner_size, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        return super().test_inductor()
# Run this file's tests directly via the common runner.
if __name__ == '__main__':
    exit(run_unittest(TestFusedMOEAllReduceOp))

View File

@@ -0,0 +1,157 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import torch.multiprocessing as mp
import sys
import os
work_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(work_dir))
from common_utils import *
def compute_weight_only_scale(weight, quant_bit):
    """Symmetrically quantize *weight* row-wise to signed ``quant_bit``-bit integers.

    Returns ``(weight_int, weight_scale_recip)``: the int8 weight matrix and
    the per-row dequantization scale (float32, one value per row), i.e.
    ``row_absmax / (2**(quant_bit-1) - 1)``.
    """
    qmax = float((1 << (quant_bit - 1)) - 1)
    # Per-row absolute maximum, kept as a column for broadcasting.
    row_absmax = torch.abs(weight).max(dim=1, keepdim=True).values
    # Scale into [-qmax, qmax] and truncate to int8 (cast truncates toward zero).
    quantized = (weight * (qmax / row_absmax)).type(torch.int8)
    # Reciprocal scale used at dequantization time.
    dequant_scale = (row_absmax / qmax).type(torch.float).squeeze()
    return quantized, dequant_scale
def quant_mm_split_k(rank, *args):
    """Per-process worker (mp.spawn entry): weight-only quant matmul + allreduce, K split across ranks."""
    torch.manual_seed(0)
    M, N, K, has_bias, has_res, dtype, block_m, world_size = args
    assert K % world_size == 0
    block_k = K // world_size
    start_k = rank * block_k
    end_k = start_k + block_k
    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    a = torch.randn(M, K, device="mlu", dtype=dtype)
    b = torch.randn(N, K, device="mlu", dtype=dtype)
    b_int, gemm_output_scale = compute_weight_only_scale(b, 8)
    c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None
    bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None
    # Reference on the full K. NOTE(review): bias and residual are not given to
    # the reference; the visible caller only uses has_bias=False, has_res=False.
    torch_quant_matmul = QuantMatmul(b_int, None, None, None, None, gemm_output_scale, dtype)
    pt_output = torch_quant_matmul(a).detach()
    # Keep only this rank's K-slice of the activation and quantized weight.
    a = a[..., start_k:end_k].contiguous()
    b_int = b_int[..., start_k:end_k].contiguous()
    # Note: only rank 0 contributes bias / residual.
    bias = bias if (bias is not None and rank == 0) else None
    c = c if (c is not None and rank == 0) else None
    param = [cncl_comm, a, None, None, b_int, None, None, bias, c, None,
        None, gemm_output_scale, None, "half", "weight_only", "quantize_none",
        "quantize_per_channel", 8, 1.0, 1.0, False, True, block_m]
    # beta = beta if (c is not None and rank == 0) else 0.
    output = ops._ops.quant_matmul_allreduce(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
def sq_mm_split_k(rank, *args):
    """Per-process worker (mp.spawn entry): smooth-quant int8 matmul + allreduce, K split across ranks."""
    torch.manual_seed(0)
    M, N, K, has_bias, has_res, dtype, world_size, block_m = args
    assert K % world_size == 0
    block_k = K // world_size
    start_k = rank * block_k
    end_k = start_k + block_k
    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    # Pre-quantized int8 activation and weight with per-row/per-channel scales.
    a = torch.randint(-10, 10, (M, K), dtype=torch.int8).mlu()
    b = torch.randint(-10, 10, (N, K), dtype=torch.int8).mlu()
    c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None
    bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None
    a_scale = torch.randn(M, device="mlu", dtype=torch.float)
    b_scale = torch.randn(N, device="mlu", dtype=torch.float)
    # Reference on the full K (includes bias and residual when present).
    torch_quant_matmul = QuantMatmul(b, bias, c, a_scale, b_scale, None, dtype)
    pt_output = torch_quant_matmul(a).detach()
    # Keep only this rank's K-slice.
    a = a[..., start_k:end_k].contiguous()
    b = b[..., start_k:end_k].contiguous()
    # Note: only rank 0 contributes bias / residual.
    bias = bias if (bias is not None and rank == 0) else None
    c = c if (c is not None and rank == 0) else None
    # beta = beta if (c is not None and rank == 0) else 0.
    param = [cncl_comm, a, a_scale, b, b_scale, dtype, bias, c, 1.0, 1.0, block_m]
    output = ops.smooth_quant_matmul_allreduce(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
class TestGptQuantMatmulOp(BtTestCase):
    """Tests for quantized matmul + allreduce (weight-only and smooth-quant paths)."""

    def op_impl_base(self, *args):
        return super().op_impl_base(*args)

    # weight only, no bias, no residual, no activation
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_weight_only_matmul_allreduce due to ASan issues")
    def test_single_weight_only_matmul_allreduce(self):
        """One fixed weight-only quant matmul allreduce case across all devices."""
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            # NOTE(review): warns but still spawns with device_n processes — confirm intended.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        block_m = 8
        M, K, N = 32, 256, 128
        has_bias = False
        has_res = False
        dtype = torch.half
        print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
            M, N, K, has_bias, has_res, dtype, device_n, block_m), flush=True)
        assert K % device_n == 0, f"K{K} must be divisible by tp_num{device_n}"
        # NOTE: quant_mm_split_k expects (..., block_m, world_size) ordering.
        param = [M, N, K, has_bias, has_res, dtype, block_m, device_n]
        mp.spawn(quant_mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_sq_matmul_allreduce due to ASan issues")
    def test_single_sq_matmul_allreduce(self):
        """One fixed smooth-quant int8 matmul allreduce case across all devices."""
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        block_m = 1
        M, K, N = 2598, 1024, 1024
        has_bias = False
        has_res = False
        dtype = torch.half
        print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
            M, N, K, has_bias, has_res, dtype, device_n, block_m), flush=True)
        assert K % device_n == 0, f"K{K} must be divisible by tp_num{device_n}"
        # NOTE: sq_mm_split_k expects (..., world_size, block_m) ordering.
        param = [M, N, K, has_bias, has_res, dtype, device_n, block_m]
        mp.spawn(sq_mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skip("not test")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_mm_allreduce due to ASan issues")
    def test_random_sq_mm_allreduce(self):
        """Randomized smooth-quant cases; currently disabled via @unittest.skip."""
        import random
        random.seed(0)
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
        for i in range(10):
            tp_num = random.randint(1, device_n)
            M = random.randint(1, 4096)
            N = random.choice([512, 1024, 2048, 4096])
            # Start at a multiple of tp_num and step by tp_num so K stays divisible.
            k_start = (1024 // tp_num) * tp_num
            K = random.randrange(k_start, 2048, tp_num)
            has_res = random.choice([False, True])
            has_bias = random.choice([False, True])
            dtype = random.choice(dtype_list)
            block_m = random.randint(1, 10)
            print("M={}, N={}, K={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
                M, N, K, has_bias, has_res, dtype, tp_num, block_m), flush=True)
            param = [M, N, K, has_bias, has_res, dtype, tp_num, block_m]
            mp.spawn(sq_mm_split_k, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        return super().test_inductor()
# Run this file's tests directly via the common runner.
if __name__ == '__main__':
    exit(run_unittest(TestGptQuantMatmulOp))