forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import math
|
||||
import torch_mlu_ops as ops
|
||||
import torch.multiprocessing as mp
|
||||
import sys
|
||||
import os
|
||||
|
||||
work_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.dirname(work_dir))
|
||||
from common_utils import *
|
||||
|
||||
def flash_attn_sq_mm(q, k, v, smooth, quant_weight, weight_scale,
                     bias, softmax_scale, is_causal, world_size = 1):
    """Reference pipeline: flash attention -> per-token smooth quant -> int8 matmul.

    Runs the pipeline twice and returns ``(output1, output2)``:
    output1 is computed head-sharded into ``world_size`` slices and summed
    (bias contributed by slice 0 only, mirroring tensor-parallel execution);
    output2 is computed on the full, un-sharded tensors.
    """
    def _slice_pipeline(q_s, k_s, v_s, smooth_s, w_s, bias_s):
        # Attention over one head slice, flattened to (batch*seq, heads*head_size).
        attn = ops.flash_attention(q_s, k_s, v_s, None, None, None,
                                   None, None, q.size(1), k.size(1),
                                   softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
        quantized, token_scale = ops.per_token_smooth_quantize(attn, smooth_s, None)
        return ops.smooth_quant_matmul(quantized, token_scale, w_s,
                                       weight_scale, q.dtype, bias_s)

    # Shard heads (dim 2 of q/k/v), the smooth vector, and the weight columns.
    q_parts = q.chunk(world_size, dim=2)
    k_parts = k.chunk(world_size, dim=2)
    v_parts = v.chunk(world_size, dim=2)
    smooth_parts = smooth.chunk(world_size, dim=0)
    w_parts = [w.contiguous() for w in quant_weight.chunk(world_size, dim=1)]

    # Accumulate the per-slice partial products; only slice 0 adds the bias
    # so it is counted exactly once in the sum.
    output1 = torch.zeros(q.size(0) * q.size(1), q.size(2) * q.size(3),
                          dtype=q.dtype).mlu()
    for idx in range(world_size):
        output1 += _slice_pipeline(q_parts[idx], k_parts[idx], v_parts[idx],
                                   smooth_parts[idx], w_parts[idx],
                                   bias if idx == 0 else None)

    # Un-sharded reference over the full tensors.
    output2 = _slice_pipeline(q, k, v, smooth, quant_weight, bias)
    return output1, output2
|
||||
|
||||
def tp_flash_attn_sq_mm(rank, *args):
    """Per-rank worker (launched via mp.spawn) for the fused
    flash_attn_sq_mm_allreduce kernel.

    Each rank takes its head slice of q/k/v, its slice of the smooth vector
    and quantized weight, runs the fused kernel in both "pad" (batched) and
    "pack" (cu_seqlens) modes, and compares against ``base_output`` computed
    by the single-process reference.

    args: (world_size, q, k, v, smooth, quant_weight, weight_scale, bias,
           softmax_scale, is_causal, base_output)
    """
    world_size, q, k, v, smooth, quant_weight, weight_scale, bias, softmax_scale, is_causal, base_output = args
    # Round-trip inputs through the CPU so each spawned process re-uploads
    # them to its own device rather than sharing the parent's MLU tensors.
    q_cpu, k_cpu, v_cpu, smooth_cpu = q.cpu(), k.cpu(), v.cpu(), smooth.cpu()
    quant_weight_cpu, weight_scale_cpu, bias_cpu = quant_weight.cpu(), weight_scale.cpu(), bias.cpu()
    base_output_cpu = base_output.cpu()
    # setup/get_default_group/cleanup come from common_utils (star import);
    # this initializes the process group for this rank.
    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    # q/k layout assumed (batch, seq, heads, head_size) -- dim 2 is heads.
    head_num_q = q.size(2)
    head_num_kv = k.size(2)
    seq = q.size(1)
    # Heads must shard evenly across the tensor-parallel group.
    assert head_num_q % world_size == 0
    assert head_num_kv % world_size == 0
    # NOTE(review): the two *_tp locals below are computed but never used.
    head_num_q_tp = head_num_q // world_size
    head_num_kv_tp = head_num_kv // world_size
    q = q_cpu.mlu()
    k = k_cpu.mlu()
    v = v_cpu.mlu()
    smooth = smooth_cpu.mlu()
    quant_weight = quant_weight_cpu.mlu()
    weight_scale = weight_scale_cpu.mlu()
    # Note: only tp0 add bias (so the all-reduced sum counts it once).
    bias = bias_cpu.mlu() if rank == 0 else None
    # Shard heads, smooth vector, and weight columns exactly as the
    # reference flash_attn_sq_mm does.
    q_list = q.chunk(world_size, dim=2)
    k_list = k.chunk(world_size, dim=2)
    v_list = v.chunk(world_size, dim=2)
    smooth_list = smooth.chunk(world_size, dim=0)
    quant_weight_list = quant_weight.chunk(world_size, dim=1)
    quant_weight_list = [w.contiguous() for w in quant_weight_list]
    # test pad mode: batched q/k/v, no cu_seqlens.
    output_pad = ops.flash_attn_sq_mm_allreduce(cncl_comm, q_list[rank], k_list[rank], v_list[rank],
                                                None, None, None, None,
                                                smooth_list[rank], quant_weight_list[rank], weight_scale, bias, seq,
                                                seq, softmax_scale, is_causal)
    assertTensorsEqual(output_pad.cpu().float(), base_output_cpu.float(), 0.006, use_MSE=True, use_RAE=True)
    # test pack mode: flattened tokens plus cumulative sequence lengths
    # (a single sequence of length ``seq``).
    cu_seq_lens_q = torch.tensor([0, seq], dtype=torch.int32).mlu()
    cu_seq_lens_k = torch.tensor([0, seq], dtype=torch.int32).mlu()
    q_pack = q_list[rank].flatten(0, 1)
    k_pack = k_list[rank].flatten(0, 1)
    v_pack = v_list[rank].flatten(0, 1)
    output_pack = ops.flash_attn_sq_mm_allreduce(cncl_comm, q_pack, k_pack, v_pack,
                                                 cu_seq_lens_q, cu_seq_lens_k, None, None,
                                                 smooth_list[rank], quant_weight_list[rank], weight_scale, bias, seq,
                                                 seq, softmax_scale, is_causal)
    assertTensorsEqual(output_pack.cpu().float(), base_output_cpu.float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
class TestFlashAttnSqMMAllreduce(BtTestCase):
    """Tests for the fused flash-attention + smooth-quant matmul (+allreduce)
    ops, checked against piecewise references built from the individual ops.
    BtTestCase, QuantByRow, assertTensorsEqual come from common_utils.
    """

    def op_impl_base(self, *args):
        # Delegate to the framework's base implementation hook.
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_flash_attn_sq_mm_split_seq due to ASan issues")
    def test_flash_attn_sq_mm_split_seq(self):
        """Full-sequence pipeline vs. the same pipeline run in causal
        sequence blocks (each block attends to k/v up to its end)."""
        batch, seq, head_num_q, head_num_kv, head_size, is_causal, block_seq = 16, 1024, 8, 1, 128, True, 4
        dtype = torch.bfloat16 if torch_mlu.mlu.is_bf16_supported() else torch.half
        hidden_size = head_num_q * head_size
        softmax_scale = 1 / math.sqrt(head_size)
        # One packed qkv tensor; q/k/v are views sliced out along dim 2.
        qkv = torch.randn(batch, seq, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
        # smooth of all-ones makes the smooth-quant step a plain quantize.
        smooth = torch.zeros(hidden_size, dtype=torch.float).mlu() + 1.0
        bias = torch.zeros(hidden_size, dtype=dtype).mlu()
        weight = torch.randn(hidden_size, hidden_size, dtype=dtype).mlu()
        quant_weight, weight_scale = QuantByRow(weight / smooth, 8)
        q = qkv[:, :, :head_num_q, :]
        k = qkv[:, :, head_num_q:head_num_q+head_num_kv, :]
        v = qkv[:, :, head_num_q+head_num_kv:, :]
        # Reference: single pass over the whole sequence.
        attn_output = ops.flash_attention(q, k, v, None, None, None,
                                          None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
        quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
        output1 = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
        # Split the sequence into block_seq chunks; chunk i attends to all
        # keys/values up to its own end (valid because is_causal=True).
        output_list = []
        block_size = seq // block_seq
        for i in range(block_seq):
            start = i * block_size
            end = seq if i == block_seq - 1 else (i + 1) * block_size
            attn_output = ops.flash_attention(q[:,start:end], k[:,:end], v[:,:end], None, None, None,
                                              None, None, end - start, end, softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
            quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
            out = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
            output_list.append(out)
        # Chunks come out ordered (chunk, batch, block, hidden); transpose
        # back to (batch, seq) token order before comparing.
        output2 = torch.cat(output_list, dim=0)
        output2 = output2.reshape(block_seq, batch, block_size, hidden_size).transpose(0, 1).reshape(batch*seq, hidden_size)
        assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.0045, use_MSE=True, use_RAE=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_flash_attn_sq_mm_allreduce(self):
        """Spawns one process per device and checks the fused allreduce
        kernel against the single-process sharded reference (output1)."""
        world_size = min(torch_mlu.mlu.device_count(), 8)
        batch, seq, head_num_q, head_num_kv, head_size, is_causal = 1, 8192, 64, 8, 128, True
        dtype = torch.bfloat16 if torch_mlu.mlu.is_bf16_supported() else torch.half
        # Single fixed case; the commented lines show how to randomize it.
        for i in range(1):
            # seq = random.randint(1, 32768)
            # is_causal = bool(random.randint(0, 1))
            print("=============test[{}]: seq = {}, causal = {}===============".format(i, seq, is_causal))
            hidden_size = head_num_q * head_size
            softmax_scale = 1 / math.sqrt(head_size)
            qkv = torch.randn(batch, seq, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
            smooth = torch.zeros(hidden_size, dtype=torch.float).mlu() + 1.0
            bias = torch.randn(hidden_size, dtype=dtype).mlu()
            weight = torch.randn(hidden_size, hidden_size, dtype=dtype).mlu()
            # Weight is pre-divided by smooth so smooth-quant round-trips.
            quant_weight, weight_scale = QuantByRow(weight / smooth, 8)
            q = qkv[:, :, :head_num_q, :]
            k = qkv[:, :, head_num_q:head_num_q+head_num_kv, :]
            v = qkv[:, :, head_num_q+head_num_kv:, :]
            output1, output2 = flash_attn_sq_mm(q, k, v, smooth, quant_weight, weight_scale,
                                                bias, softmax_scale, is_causal, world_size)
            # output1 (the sharded sum) is the baseline each rank checks.
            args = world_size, q, k, v, smooth, quant_weight, weight_scale, bias, softmax_scale, is_causal, output1
            mp.spawn(tp_flash_attn_sq_mm, args, nprocs=world_size, join=True)

    def test_inductor(self):
        # Framework-provided inductor smoke test.
        return super().test_inductor()
|
||||
|
||||
if __name__ == '__main__':
    # run_unittest (presumably from common_utils' star import) runs the case
    # and returns an exit status, which exit() propagates to the shell.
    exit(run_unittest(TestFlashAttnSqMMAllreduce))
|
||||
109
torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_matmul_all_reduce.py
Executable file
109
torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_matmul_all_reduce.py
Executable file
@@ -0,0 +1,109 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from itertools import product
|
||||
import torch.multiprocessing as mp
|
||||
import sys
|
||||
import os
|
||||
|
||||
work_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.dirname(work_dir))
|
||||
from common_utils import *
|
||||
|
||||
def mm_split_k(rank, *args):
    """Per-rank worker (mp.spawn target) validating ops.matmul_allreduce.

    Every rank builds the identical full-size operands (same seed), keeps its
    own K-slice, and the kernel all-reduces the partial products. Bias and
    residual are supplied on rank 0 only so they are counted exactly once;
    rank 0 compares against a torch reference of
    ``alpha * (input @ weight.T + bias) + beta * residual``.

    args: (mat_m, mat_n, mat_k, has_bias, has_res, dtype, world_size, block_m)
    """
    torch.manual_seed(0)
    mat_m, mat_n, mat_k, has_bias, has_res, dtype, world_size, block_m = args
    assert mat_k % world_size == 0, f"mat_k{mat_k} must be divisible by tp{world_size}"
    cols_per_rank = mat_k // world_size
    start_k = rank * cols_per_rank
    end_k = start_k + cols_per_rank

    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

    alpha = 0.625
    beta = 1.0 if has_res else 0.
    # Operand creation order is fixed so every rank draws the same values.
    input = torch.randn((mat_m, mat_k), dtype=dtype, device='mlu')
    residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu') if has_res else None
    weight = torch.randn((mat_n, mat_k), dtype=dtype, device="mlu")
    bias = torch.randn(mat_n, dtype=dtype, device="mlu") if has_bias else None

    # Torch reference on the full operands.
    pt_output = torch.matmul(input, weight.permute(1, 0))
    if bias is not None:
        pt_output = pt_output + bias
    pt_output *= alpha
    if residual is not None:
        pt_output += beta * residual
    if bias is not None:
        # Pre-scale the bias in place so the kernel path mirrors the
        # reference's alpha * (matmul + bias).
        bias *= alpha

    # Keep only this rank's slice of the reduction dimension.
    input = input[..., start_k:end_k].contiguous()
    weight = weight[..., start_k:end_k].contiguous()
    # Extras go in on rank 0 only; other ranks contribute pure partial sums.
    bias = bias if (bias is not None and rank == 0) else None
    residual = residual if (residual is not None and rank == 0) else None
    beta = beta if (residual is not None and rank == 0) else 0.

    output = ops.matmul_allreduce(cncl_comm, input, weight, bias, residual, alpha, beta, block_m)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
|
||||
class TestMatMulAllReduceOp(BtTestCase):
    """Drives mm_split_k across all devices for fixed and randomized shapes.
    BtTestCase / run_unittest come from common_utils (star import).
    """

    def op_impl_base(self, *args):
        # Delegate to the framework's base implementation hook.
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_matmul_allreduce due to ASan issues")
    def test_matmul_allreduce(self):
        """Cartesian sweep over bias/residual/dtype at one fixed shape."""
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            # NOTE(review): this only warns and falls through — the sweep
            # still runs with nprocs=device_n. Confirm whether an early
            # return/skip was intended here.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)

        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [1680]
        has_res_list = [False, True]
        has_bias_list = [False, True]
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        block_m = 4
        # All combinations of shapes x bias x residual x dtype.
        args = product(mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, dtype_list)

        for mat_m, mat_n, mat_k, has_bias, has_res, dtype in args:
            print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp={}, block_m={}, testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, dtype, device_n, block_m), flush=True)
            # Order must match mm_split_k's unpacking:
            # (m, n, k, has_bias, has_res, dtype, world_size, block_m)
            param = [mat_m, mat_n, mat_k, has_bias, has_res, dtype, device_n, block_m]
            mp.spawn(mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skip("not test")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_matmul_allreduce_random due to ASan issues")
    def test_matmul_allreduce_random(self):
        """Randomized shape sweep (currently disabled via @unittest.skip)."""
        import random
        random.seed(0)
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            # NOTE(review): warns but does not return — see note above.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for i in range(10):
            tp_num = random.randint(1, device_n)
            mat_m = random.randint(1, 4096)
            mat_n = random.choice([512, 1024, 2048, 4096])
            # Step by tp_num from a tp-aligned start so mat_k % tp_num == 0.
            k_start = (1024 // tp_num) * tp_num
            mat_k = random.randrange(k_start, 10240, tp_num)
            has_res = random.choice([False, True])
            has_bias = random.choice([False, True])
            dtype = random.choice(dtype_list)
            block_m = random.randint(1, 10)
            assert mat_k % tp_num == 0, f"mat_k{mat_k} must be divisible by tp_num{tp_num}"
            print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, dtype, tp_num, block_m), flush=True)
            param = [mat_m, mat_n, mat_k, has_bias, has_res, dtype, tp_num, block_m]
            mp.spawn(mm_split_k, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        # Framework-provided inductor smoke test.
        return super().test_inductor()
|
||||
|
||||
if __name__ == '__main__':
    # run_unittest (presumably from common_utils' star import) runs the case
    # and returns an exit status, which exit() propagates to the shell.
    exit(run_unittest(TestMatMulAllReduceOp))
|
||||
305
torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_moe_all_reduce.py
Executable file
305
torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_moe_all_reduce.py
Executable file
@@ -0,0 +1,305 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from typing import Union, List, Tuple
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import functional as F
|
||||
import torch.multiprocessing as mp
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
work_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.dirname(work_dir))
|
||||
from common_utils import *
|
||||
|
||||
def moe_split_inner_size(rank, *args):
    """Per-rank worker (mp.spawn target) for fused_moe with allreduce,
    sharding the experts' inner dimension across ranks.

    Every rank builds identical full-size weights (same seed), slices its
    inner-dim block, then compares:
      output  -- fused_moe without comm + explicit all_reduce (baseline),
      output1 -- fused_moe with the cncl comm handle fused in,
      output2 -- same but with w2 reshaped to a 4-D blocked-expert layout
                 (only when 4096 divides evenly).
    Rank 0 does the comparisons.

    args: (batch, seq, hidden_size, inner_size, expert_num, topk, gated,
           renormalize, act_mode, dtype, block_n, world_size)
    """
    torch.manual_seed(0)
    batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, world_size = args
    assert inner_size % world_size == 0, f"inner_size{inner_size} must be divisible by tp_num{world_size}"
    block_k = inner_size // world_size
    start_k = rank * block_k
    end_k = start_k + block_k

    # Process-group setup helpers come from common_utils (star import).
    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

    scale_s = 0.01 # avoid the occurrence of inf
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    # Optional inputs kept disabled; the commented expressions show how to
    # enable them.
    residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    # With gated=True w1 stacks gate+up projections, hence the (1+gated) factor.
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=dtype) * scale_s
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=dtype)
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=dtype) * scale_s
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=dtype)

    # Slice this rank's inner-dim block. For w1 the inner dim is interleaved
    # with the gate/up grouping, so expose it via a 4-D view first.
    weight1 = weight1.view(expert_num, -1, inner_size, hidden_size)
    weight1 = weight1[:, :, start_k:end_k]
    weight1 = weight1.reshape(expert_num, -1, hidden_size).contiguous()
    weight2 = weight2[..., start_k:end_k].contiguous()
    if bias1 is not None:
        bias1 = bias1.view(expert_num, -1, inner_size)
        bias1 = bias1[..., start_k:end_k]
        bias1 = bias1.reshape(expert_num, -1).contiguous()
    # Residual contributes once, on rank 0 only.
    residual = residual if (residual is not None and rank == 0) else None

    # Baseline: fused_moe without the comm handle, reduced explicitly.
    param = [hidden_states, router_logit, weight1, weight2, bias1, bias2, residual,
             None, None, None, None, topk, renormalize, gated, act_mode]
    output = ops.fused_moe(*param)
    all_reduce(output, ReduceOp.SUM, group=pg)

    # Fused path: trailing (0, block_n, cncl_comm) enables in-kernel allreduce.
    param = [hidden_states, router_logit, weight1, weight2, bias1, bias2, residual,
             None, None, None, None, topk, renormalize, gated, act_mode, 0, block_n, cncl_comm]
    output1 = ops.fused_moe(*param)

    # Optional 4-D blocked w2 layout: group block_e experts so that
    # block_e * new_inner_size == 4096.
    new_inner_size = weight2.shape[-1]
    block_e = 4096 // new_inner_size
    if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
        param = [hidden_states, router_logit, weight1, weight2.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                 bias1, bias2, residual,
                 None, None, None, None, topk, renormalize, gated, act_mode, 0, block_n, cncl_comm]
        output2 = ops.fused_moe(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), output1.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
        if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
            assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
def sq_moe_split_inner_size(rank, *args):
    """Per-rank worker for the smooth-quant variant of fused_moe allreduce.

    Like moe_split_inner_size, but the expert weights are smooth-quantized
    (divided by per-expert smooth vectors, then row-quantized to int8 via
    QuantByRow) and the quant scales/smooth vectors are sharded alongside
    the inner dimension. Rank 0 compares explicit-allreduce vs fused paths.

    args: (batch, seq, hidden_size, inner_size, expert_num, topk, gated,
           renormalize, act_mode, dtype, block_n, world_size)
    """
    torch.manual_seed(0)
    batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, world_size = args
    assert inner_size % world_size == 0, f"inner_size{inner_size} must be divisible by tp_num{world_size}"
    block_k = inner_size // world_size
    start_k = rank * block_k
    end_k = start_k + block_k

    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

    # NOTE(review): scale_s is unused in this function (weights use
    # torch.normal std directly); kept for symmetry with the fp path.
    scale_s = 0.1 # avoid the occurrence of inf
    eps = 0.1 # Avoid the occurrence of nan
    residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    bias1, bias2 = None, None
    # Small-variance inputs keep the int8 round-trip numerically tame.
    hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=dtype)
    weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=dtype)
    weight2 = torch.normal(0, 0.01, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=dtype)
    router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
    # Smooth vectors are strictly positive (abs + eps) to avoid div-by-zero.
    input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    # Fold the smooth vectors into the weights, then row-quantize to int8.
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    # Shard the inner dimension (see moe_split_inner_size for the w1 view).
    quant_w1 = quant_w1.view(expert_num, -1, inner_size, hidden_size)
    quant_w1 = quant_w1[:, :, start_k:end_k]
    quant_w1 = quant_w1.reshape(expert_num, -1, hidden_size).contiguous()
    quant_w2 = quant_w2[..., start_k:end_k].contiguous()

    if bias1 is not None:
        bias1 = bias1.view(expert_num, -1, inner_size)
        bias1 = bias1[..., start_k:end_k]
        bias1 = bias1.reshape(expert_num, -1).contiguous()
    # w1's per-row scales and the activation smooth vector live on the inner
    # dimension, so they shard with it.
    if w1_scale is not None:
        w1_scale = w1_scale.view(expert_num, -1, inner_size)
        w1_scale = w1_scale[..., start_k:end_k]
        w1_scale = w1_scale.reshape(expert_num, -1).contiguous()
    if act_smooth is not None:
        act_smooth = act_smooth[..., start_k:end_k].contiguous()
    # Residual contributes once, on rank 0 only.
    residual = residual if (residual is not None and rank == 0) else None

    # Baseline: quantized fused_moe without comm + explicit allreduce.
    param = [hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual,
             input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode]
    output = ops.fused_moe(*param)
    all_reduce(output, ReduceOp.SUM, group=pg)

    # Fused path with the cncl comm handle.
    param = [hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual,
             input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0,
             block_n, cncl_comm]
    output1 = ops.fused_moe(*param)

    # Optional 4-D blocked w2 layout (block_e experts per 4096-wide group).
    new_inner_size = quant_w2.shape[-1]
    block_e = 4096 // new_inner_size
    if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
        param = [hidden_states, router_logit, quant_w1,
                 quant_w2.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                 bias1, bias2, residual, input_smooth, act_smooth, w1_scale, w2_scale, topk,
                 renormalize, gated, act_mode, 0, block_n, cncl_comm]
        output2 = ops.fused_moe(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), output1.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
        if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
            assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
|
||||
class TestFusedMOEAllReduceOp(BtTestCase):
    """Drives the fused-MoE allreduce workers (moe_split_inner_size and
    sq_moe_split_inner_size) over fixed and randomized configurations.
    BtTestCase / run_unittest come from common_utils (star import).
    """

    def op_impl_base(self, *args):
        # Delegate to the framework's base implementation hook.
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_fused_moe_allreduce due to ASan issues")
    def test_single_fused_moe_allreduce(self):
        """One fixed large fp configuration across all devices."""
        print("test_single_fused_moe_allreduce")
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            # NOTE(review): warns but does not return — the spawn below still
            # runs with nprocs=device_n; confirm whether a skip was intended.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        block_n = 2

        batch, seq, hidden_size, inner_size = 1, 1024, 8192, 8192
        expert_num, topk, gated, renormalize, act_mode, dtype = 8, 2, True, True, 'silu', torch.float16
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
              topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
              tp_num: {device_n}, block_n: {block_n}, testing...", flush=True)
        # Order must match the worker's unpacking (…, block_n, world_size).
        param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, device_n]
        mp.spawn(moe_split_inner_size, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_fused_moe_allreduce due to ASan issues")
    def test_random_fused_moe_allreduce(self):
        """Ten random distinct fp configurations (seeded, deduplicated)."""
        print("test_random_fused_moe_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            # NOTE(review): warns but does not return — see note above.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)

        act_mode = 'gelu'
        case_list = set()
        dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
        while(len(case_list) < 10):
            # block_n may be negative or zero; zero selects a fallback
            # hidden_size, otherwise hidden_size is rounded to a block_n
            # multiple (and cases with too-small blocks are rejected).
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(512, 1024, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            # Step by tp_num*2 from a tp-aligned start so inner_size shards
            # evenly across ranks.
            k_start = (512 // tp_num) * tp_num
            inner_size = random.randrange(k_start, 1024, tp_num * 2)
            expert_num = random.randint(1, 32)
            topk = random.randint(1,expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            dtype = random.choice(dtype_list)
            # Deduplicate: only spawn each distinct configuration once.
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
                  topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
                  tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(moe_split_inner_size, param, nprocs=tp_num, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_sq_fused_moe_allreduce due to ASan issues")
    def test_sq_fused_moe_allreduce(self):
        """One fixed smooth-quant configuration across all devices."""
        print("test_sq_fused_moe_allreduce")
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            # NOTE(review): warns but does not return — see note above.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        block_n = 2
        batch, seq, hidden_size, inner_size = 5, 9, 8192, 8192
        expert_num, topk, gated, renormalize, act_mode, dtype = 20, 16, False, True, 'gelu', torch.float16
        print(f"sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
              topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
              tp_num: {device_n}, block_n: {block_n}, testing...", flush=True)
        param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, device_n]
        mp.spawn(sq_moe_split_inner_size, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_fused_moe_allreduce due to ASan issues")
    def test_random_sq_fused_moe_allreduce(self):
        """Ten random distinct smooth-quant configurations."""
        print("test_random_sq_fused_moe_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            # NOTE(review): warns but does not return — see note above.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        act_mode = 'gelu'
        case_list = set()
        while(len(case_list) < 10):
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(256, 512, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            k_start = (512 // tp_num) * tp_num
            inner_size = random.randrange(k_start, 1024, tp_num * 2)
            expert_num = random.randint(1, 32)
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            # NOTE(review): the choice list has a single element, so dtype is
            # always float16 and the MLU370 branch below is redundant.
            dtype = random.choice([ torch.float16])
            if torch_mlu.mlu.get_device_name() == 'MLU370':
                dtype = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
                  topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
                  tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(sq_moe_split_inner_size, param, nprocs=tp_num, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_fused_moe_with_4D_w2_allreduce due to ASan issues")
    def test_random_sq_fused_moe_with_4D_w2_allreduce(self):
        """Random smooth-quant cases sized so the 4-D blocked w2 branch of
        the worker (block_e * inner_size_per_tp == 4096) is exercised."""
        print("test_random_sq_fused_moe_with_4D_w2_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            # NOTE(review): warns but does not return — see note above.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        act_mode = 'gelu'
        case_list = set()
        while (len(case_list) < 10):
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(512, 4096, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            # Per-rank inner size divides 4096 exactly, and expert_num is a
            # multiple of 4096 // inner_size_per_tp, so the worker's blocked
            # w2 path is always taken.
            inner_size_per_tp = random.choice([256, 512])
            inner_size = inner_size_per_tp * tp_num
            expert_num_base = random.randint(1, 4)
            expert_num_factor = 4096 // inner_size_per_tp
            expert_num = expert_num_base * expert_num_factor
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            # NOTE(review): single-element choice — dtype is always float16.
            dtype = random.choice([ torch.float16])
            if torch_mlu.mlu.get_device_name() == 'MLU370':
                dtype = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
                  topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
                  tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(sq_moe_split_inner_size, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        # Framework-provided inductor smoke test.
        return super().test_inductor()
|
||||
|
||||
if __name__ == '__main__':
    # run_unittest (presumably from common_utils' star import) runs the case
    # and returns an exit status, which exit() propagates to the shell.
    exit(run_unittest(TestFusedMOEAllReduceOp))
|
||||
@@ -0,0 +1,157 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import functional as F
|
||||
import torch.multiprocessing as mp
|
||||
import sys
|
||||
import os
|
||||
|
||||
work_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.dirname(work_dir))
|
||||
from common_utils import *
|
||||
|
||||
def compute_weight_only_scale(weight, quant_bit):
    """Symmetric per-row (per-output-channel) quantization of a weight matrix.

    Args:
        weight: 2-D float tensor of shape (rows, cols); each row is scaled
            by its own absolute maximum.
        quant_bit: signed-integer bit width (8 -> int_max of 127).

    Returns:
        (weight_int, weight_scale_recip):
        weight_int -- int8 tensor, same shape as ``weight``; values are
            truncated toward zero, matching the original cast semantics.
        weight_scale_recip -- float32 per-row dequantization factor
            ``row_absmax / int_max``, with the kept dim squeezed away.
    """
    int_max = float(2 ** (quant_bit - 1) - 1)
    # Use torch-native dim=/keepdim= (the original's numpy-style axis=/
    # keepdims= only work through compatibility aliases) and .values instead
    # of indexing the (values, indices) namedtuple returned by torch.max.
    row_absmax = torch.max(torch.abs(weight), dim=1, keepdim=True).values
    # NOTE(review): an all-zero row yields an inf scale here; callers in this
    # file only pass randn weights, so it is left unguarded — confirm before
    # reusing elsewhere.
    weight_scale = int_max / row_absmax
    weight_int = (weight * weight_scale).to(torch.int8)  # truncates toward zero
    weight_scale_recip = (row_absmax / int_max).type(torch.float).squeeze()
    return weight_int, weight_scale_recip
|
||||
|
||||
def quant_mm_split_k(rank, *args):
    """Per-rank worker for weight-only quantized matmul with K sharded.

    Each rank computes its K-slice of ``a @ b_int.T`` through the fused
    ``quant_matmul_allreduce`` op; partial products are allreduced and rank 0
    compares against a full single-device PyTorch reference.
    """
    # Identical seed on every rank so all ranks materialize the same tensors.
    torch.manual_seed(0)
    M, N, K, has_bias, has_res, dtype, block_m, world_size = args
    assert K % world_size == 0
    block_k = K // world_size
    start_k = rank * block_k
    end_k = start_k + block_k

    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

    a = torch.randn(M, K, device="mlu", dtype=dtype)
    b = torch.randn(N, K, device="mlu", dtype=dtype)
    b_int, gemm_output_scale = compute_weight_only_scale(b, 8)
    c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None
    bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None
    # FIX: pass bias/residual to the reference as well. Previously they were
    # handed only to the fused op, so any has_bias/has_res=True run would
    # diverge from the reference and falsely fail. Behavior is unchanged for
    # the current False/False callers.
    # (QuantMatmul argument order inferred from sq_mm_split_k — confirm.)
    torch_quant_matmul = QuantMatmul(b_int, bias, c, None, None, gemm_output_scale, dtype)
    pt_output = torch_quant_matmul(a).detach()

    # Each rank keeps only its own K-slice for the distributed op.
    a = a[..., start_k:end_k].contiguous()
    b_int = b_int[..., start_k:end_k].contiguous()
    # bias / residual must be added exactly once across ranks: rank 0 only.
    bias = bias if (bias is not None and rank == 0) else None
    c = c if (c is not None and rank == 0) else None

    param = [cncl_comm, a, None, None, b_int, None, None, bias, c, None,
             None, gemm_output_scale, None, "half", "weight_only", "quantize_none",
             "quantize_per_channel", 8, 1.0, 1.0, False, True, block_m]
    # NOTE(review): the output type is hard-coded to "half" even though `dtype`
    # is a parameter — confirm against the op's expectations.
    # beta = beta if (c is not None and rank == 0) else 0.
    output = ops._ops.quant_matmul_allreduce(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
def sq_mm_split_k(rank, *args):
    """Per-rank worker for smooth-quant int8 matmul with the K dim sharded.

    Runs the fused ``smooth_quant_matmul_allreduce`` op on this rank's
    K-slice; rank 0 compares the allreduced result against a full
    single-device PyTorch reference.
    """
    # Same seed on every rank -> every rank generates identical full tensors.
    torch.manual_seed(0)
    M, N, K, has_bias, has_res, dtype, world_size, block_m = args
    assert K % world_size == 0
    shard = K // world_size
    k_lo = rank * shard
    k_hi = k_lo + shard

    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

    # Generation order matters for cross-rank reproducibility — do not reorder.
    a = torch.randint(-10, 10, (M, K), dtype=torch.int8).mlu()
    b = torch.randint(-10, 10, (N, K), dtype=torch.int8).mlu()
    c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None
    bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None
    a_scale = torch.randn(M, device="mlu", dtype=torch.float)
    b_scale = torch.randn(N, device="mlu", dtype=torch.float)
    # Reference runs on the full (unsharded) operands.
    pt_output = QuantMatmul(b, bias, c, a_scale, b_scale, None, dtype)(a).detach()

    # Keep only this rank's K-slice for the distributed op.
    a = a[..., k_lo:k_hi].contiguous()
    b = b[..., k_lo:k_hi].contiguous()
    # bias / residual must contribute exactly once across ranks: rank 0 only.
    if rank != 0:
        bias = None
        c = None
    # beta = beta if (c is not None and rank == 0) else 0.
    output = ops.smooth_quant_matmul_allreduce(
        cncl_comm, a, a_scale, b, b_scale, dtype, bias, c, 1.0, 1.0, block_m)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006,
                           use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
class TestGptQuantMatmulOp(BtTestCase):
    """Distributed tests for quantized matmul fused with allreduce.

    Each test spawns one worker process per MLU device via ``mp.spawn``; the
    workers (``quant_mm_split_k`` / ``sq_mm_split_k``) shard K, run the fused
    op, and compare against a single-device PyTorch reference on rank 0.
    """

    def op_impl_base(self, *args):
        # Delegate to the base-class implementation unchanged.
        return super().op_impl_base(*args)

    # weight only, no bias, no residual, no activation
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_weight_only_matmul_allreduce due to ASan issues")
    def test_single_weight_only_matmul_allreduce(self):
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # FIX: bail out on <2 devices. Previously control fell through and
            # mp.spawn still ran (crashing outright when device_n == 0).
            return
        block_m = 8
        M, K, N = 32, 256, 128
        has_bias = False
        has_res = False
        dtype = torch.half
        print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
            M, N, K, has_bias, has_res, dtype, device_n, block_m), flush=True)
        assert K % device_n == 0, f"K{K} must be divisible by tp_num{device_n}"
        param = [M, N, K, has_bias, has_res, dtype, block_m, device_n]
        mp.spawn(quant_mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_sq_matmul_allreduce due to ASan issues")
    def test_single_sq_matmul_allreduce(self):
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # FIX: bail out instead of falling through (see above).
            return
        block_m = 1
        M, K, N = 2598, 1024, 1024
        has_bias = False
        has_res = False
        dtype = torch.half
        print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
            M, N, K, has_bias, has_res, dtype, device_n, block_m), flush=True)
        assert K % device_n == 0, f"K{K} must be divisible by tp_num{device_n}"
        # NOTE: sq_mm_split_k unpacks (..., world_size, block_m) — the order
        # differs from quant_mm_split_k's (..., block_m, world_size).
        param = [M, N, K, has_bias, has_res, dtype, device_n, block_m]
        mp.spawn(sq_mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skip("not test")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_mm_allreduce due to ASan issues")
    def test_random_sq_mm_allreduce(self):
        import random
        random.seed(0)
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # FIX: bail out instead of falling through.
            return

        dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
        for i in range(10):
            tp_num = random.randint(1, device_n)
            M = random.randint(1, 4096)
            N = random.choice([512, 1024, 2048, 4096])
            # Smallest K >= ~1024 divisible by tp_num; the step keeps K divisible.
            k_start = (1024 // tp_num) * tp_num
            K = random.randrange(k_start, 2048, tp_num)
            has_res = random.choice([False, True])
            has_bias = random.choice([False, True])
            dtype = random.choice(dtype_list)
            block_m = random.randint(1, 10)

            print("M={}, N={}, K={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
                M, N, K, has_bias, has_res, dtype, tp_num, block_m), flush=True)
            param = [M, N, K, has_bias, has_res, dtype, tp_num, block_m]
            mp.spawn(sq_mm_split_k, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        # Delegate to the shared inductor check on the base class.
        return super().test_inductor()
|
||||
|
||||
if __name__ == '__main__':
    # Run this module's test class and propagate its result as the exit code.
    raise SystemExit(run_unittest(TestGptQuantMatmulOp))
|
||||
Reference in New Issue
Block a user