add ops
This commit is contained in:
17
torch_mlu_ops-v1.3.2/tests/ops_pytest/README.md
Normal file
17
torch_mlu_ops-v1.3.2/tests/ops_pytest/README.md
Normal file
@@ -0,0 +1,17 @@
## BT_OPS测试脚本使用方式

```bash
# 测试所有测例
bash run_test.sh
```

```bash
# 测试单个测例
python3 test_测例名称.py
```

- 必须在 Torch-MLU-Ops docker 容器内运行。

- 测试脚本的命名规则为 `test_测例名称.py`。

- 必须保证 Torch-MLU-Ops whl 包正确安装。
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import math
|
||||
import torch_mlu_ops as ops
|
||||
import torch.multiprocessing as mp
|
||||
import sys
|
||||
import os
|
||||
|
||||
work_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.dirname(work_dir))
|
||||
from common_utils import *
|
||||
|
||||
def flash_attn_sq_mm(q, k, v, smooth, quant_weight, weight_scale,
                     bias, softmax_scale, is_causal, world_size = 1):
    """Reference for the fused flash-attention + smooth-quant matmul pipeline.

    Computes the same pipeline twice and returns both results:
      * ``output1`` — simulated tensor-parallel execution: heads/weights are
        chunked into ``world_size`` shards, each shard runs
        flash_attention -> per_token_smooth_quantize -> smooth_quant_matmul,
        and the partial matmul outputs are summed (emulating an all-reduce).
      * ``output2`` — the same pipeline run once, unsharded.

    Assumes q/k/v are (batch, seq, heads, head_size) — consistent with the
    ``dim=2`` chunking and the ``q.size(1)``/``k.size(1)`` seq-length args
    passed to ``ops.flash_attention`` — TODO confirm against the op docs.

    Returns:
        (output1, output2): both of shape (batch*seq, heads*head_size).
    """
    # Shard along the head dimension (dim=2) for q/k/v.
    q_list = q.chunk(world_size, dim=2)
    k_list = k.chunk(world_size, dim=2)
    v_list = v.chunk(world_size, dim=2)
    # smooth is per-hidden-element, weight is sharded along its input dim.
    smooth_list = smooth.chunk(world_size, dim=0)
    quant_weight_list = quant_weight.chunk(world_size, dim=1)
    # chunk() returns views; the op presumably requires contiguous weights.
    quant_weight_list = [w.contiguous() for w in quant_weight_list]
    # Accumulator for the sharded partial outputs: (batch*seq, heads*head_size).
    output1 = torch.zeros(q.size(0) * q.size(1), q.size(2) * q.size(3), dtype=q.dtype).mlu()
    for i in range(world_size):
        attn_output = ops.flash_attention(q_list[i], k_list[i], v_list[i], None, None, None,
                                          None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
        quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth_list[i], None)
        # Bias must be added exactly once across shards, so only shard 0 gets it.
        output1 += ops.smooth_quant_matmul(quant_input, input_scale,
                                           quant_weight_list[i], weight_scale, q.dtype, bias if i == 0 else None)
    # Unsharded reference path over the full tensors.
    attn_output = ops.flash_attention(q, k, v, None, None, None,
                                      None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
    quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
    output2 = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
    return output1, output2
|
||||
|
||||
def tp_flash_attn_sq_mm(rank, *args):
    """Per-process worker (spawned via ``mp.spawn``) for the tensor-parallel
    flash_attn_sq_mm_allreduce op.

    Compares the fused op's output on this rank's shard (both "pad" and
    "pack" input layouts) against ``base_output`` computed by the host in
    ``flash_attn_sq_mm``.

    Args (packed in *args):
        world_size, q, k, v, smooth, quant_weight, weight_scale, bias,
        softmax_scale, is_causal, base_output
    """
    world_size, q, k, v, smooth, quant_weight, weight_scale, bias, softmax_scale, is_causal, base_output = args
    # Move everything to CPU first, then re-upload after process-group setup —
    # presumably so each spawned process re-places tensors on its own device;
    # TODO confirm against mp.spawn + MLU device semantics.
    q_cpu, k_cpu, v_cpu, smooth_cpu = q.cpu(), k.cpu(), v.cpu(), smooth.cpu()
    quant_weight_cpu, weight_scale_cpu, bias_cpu = quant_weight.cpu(), weight_scale.cpu(), bias.cpu()
    base_output_cpu = base_output.cpu()
    setup(rank, world_size)
    pg = get_default_group()
    # Raw CNCL communicator handle consumed by the fused allreduce op.
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    head_num_q = q.size(2)
    head_num_kv = k.size(2)
    seq = q.size(1)
    # Heads must shard evenly across ranks.
    assert head_num_q % world_size == 0
    assert head_num_kv % world_size == 0
    head_num_q_tp = head_num_q // world_size
    head_num_kv_tp = head_num_kv // world_size
    # Re-upload to this process's MLU device.
    q = q_cpu.mlu()
    k = k_cpu.mlu()
    v = v_cpu.mlu()
    smooth = smooth_cpu.mlu()
    quant_weight = quant_weight_cpu.mlu()
    weight_scale = weight_scale_cpu.mlu()
    # Note: only tp0 adds the bias (it must be applied once across ranks).
    bias = bias_cpu.mlu() if rank == 0 else None
    # Shard q/k/v by heads, smooth/weight by the hidden dimension.
    q_list = q.chunk(world_size, dim=2)
    k_list = k.chunk(world_size, dim=2)
    v_list = v.chunk(world_size, dim=2)
    smooth_list = smooth.chunk(world_size, dim=0)
    quant_weight_list = quant_weight.chunk(world_size, dim=1)
    quant_weight_list = [w.contiguous() for w in quant_weight_list]
    # --- pad mode: batched (batch, seq, heads, head_size) inputs ---
    output_pad = ops.flash_attn_sq_mm_allreduce(cncl_comm, q_list[rank], k_list[rank], v_list[rank],
                                                None, None, None, None,
                                                smooth_list[rank], quant_weight_list[rank], weight_scale, bias, seq,
                                                seq, softmax_scale, is_causal)
    assertTensorsEqual(output_pad.cpu().float(), base_output_cpu.float(), 0.006, use_MSE=True, use_RAE=True)
    # --- pack mode: flattened (total_tokens, heads, head_size) inputs with
    # cumulative sequence-length prefix arrays (single sequence here) ---
    cu_seq_lens_q = torch.tensor([0, seq], dtype=torch.int32).mlu()
    cu_seq_lens_k = torch.tensor([0, seq], dtype=torch.int32).mlu()
    q_pack = q_list[rank].flatten(0, 1)
    k_pack = k_list[rank].flatten(0, 1)
    v_pack = v_list[rank].flatten(0, 1)
    output_pack = ops.flash_attn_sq_mm_allreduce(cncl_comm, q_pack, k_pack, v_pack,
                                                 cu_seq_lens_q, cu_seq_lens_k, None, None,
                                                 smooth_list[rank], quant_weight_list[rank], weight_scale, bias, seq,
                                                 seq, softmax_scale, is_causal)
    assertTensorsEqual(output_pack.cpu().float(), base_output_cpu.float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
class TestFlashAttnSqMMAllreduce(BtTestCase):
    """Tests for the fused flash-attention + smooth-quant matmul (+allreduce) op."""

    def op_impl_base(self, *args):
        # Delegate to the BtTestCase base implementation.
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_flash_attn_sq_mm_split_seq due to ASan issues")
    def test_flash_attn_sq_mm_split_seq(self):
        """Check that running the pipeline on sequence blocks (with causal
        attention over the full prefix) matches a single full-sequence run."""
        batch, seq, head_num_q, head_num_kv, head_size, is_causal, block_seq = 16, 1024, 8, 1, 128, True, 4
        dtype = torch.bfloat16 if torch_mlu.mlu.is_bf16_supported() else torch.half
        hidden_size = head_num_q * head_size
        softmax_scale = 1 / math.sqrt(head_size)
        # One fused qkv tensor; q/k/v are carved out along the head dim below.
        qkv = torch.randn(batch, seq, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
        # smooth of all ones => smoothing is a no-op here.
        smooth = torch.zeros(hidden_size, dtype=torch.float).mlu() + 1.0
        bias = torch.zeros(hidden_size, dtype=dtype).mlu()
        weight = torch.randn(hidden_size, hidden_size, dtype=dtype).mlu()
        quant_weight, weight_scale = QuantByRow(weight / smooth, 8)
        q = qkv[:, :, :head_num_q, :]
        k = qkv[:, :, head_num_q:head_num_q+head_num_kv, :]
        v = qkv[:, :, head_num_q+head_num_kv:, :]
        # Full-sequence reference.
        attn_output = ops.flash_attention(q, k, v, None, None, None,
                                          None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
        quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
        output1 = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
        # Blocked run: each q block attends to the full k/v prefix up to `end`.
        output_list = []
        block_size = seq // block_seq
        for i in range(block_seq):
            start = i * block_size
            # Last block absorbs any remainder of the sequence.
            end = seq if i == block_seq - 1 else (i + 1) * block_size
            attn_output = ops.flash_attention(q[:,start:end], k[:,:end], v[:,:end], None, None, None,
                                              None, None, end - start, end, softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
            quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
            out = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
            output_list.append(out)
        output2 = torch.cat(output_list, dim=0)
        # Reorder from (block, batch, block_seq, H) back to (batch*seq, H).
        output2 = output2.reshape(block_seq, batch, block_size, hidden_size).transpose(0, 1).reshape(batch*seq, hidden_size)
        assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.0045, use_MSE=True, use_RAE=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_flash_attn_sq_mm_allreduce(self):
        """Spawn one process per device and compare the fused TP op against the
        host-side sharded reference from flash_attn_sq_mm."""
        world_size = min(torch_mlu.mlu.device_count(), 8)
        batch, seq, head_num_q, head_num_kv, head_size, is_causal = 1, 8192, 64, 8, 128, True
        dtype = torch.bfloat16 if torch_mlu.mlu.is_bf16_supported() else torch.half
        for i in range(1):
            # seq = random.randint(1, 32768)
            # is_causal = bool(random.randint(0, 1))
            print("=============test[{}]: seq = {}, causal = {}===============".format(i, seq, is_causal))
            hidden_size = head_num_q * head_size
            softmax_scale = 1 / math.sqrt(head_size)
            qkv = torch.randn(batch, seq, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
            smooth = torch.zeros(hidden_size, dtype=torch.float).mlu() + 1.0
            bias = torch.randn(hidden_size, dtype=dtype).mlu()
            weight = torch.randn(hidden_size, hidden_size, dtype=dtype).mlu()
            quant_weight, weight_scale = QuantByRow(weight / smooth, 8)
            q = qkv[:, :, :head_num_q, :]
            k = qkv[:, :, head_num_q:head_num_q+head_num_kv, :]
            v = qkv[:, :, head_num_q+head_num_kv:, :]
            # output1 = sharded reference (simulated allreduce); used as baseline.
            output1, output2 = flash_attn_sq_mm(q, k, v, smooth, quant_weight, weight_scale,
                                                bias, softmax_scale, is_causal, world_size)
            args = world_size, q, k, v, smooth, quant_weight, weight_scale, bias, softmax_scale, is_causal, output1
            mp.spawn(tp_flash_attn_sq_mm, args, nprocs=world_size, join=True)

    def test_inductor(self):
        # Delegate to the BtTestCase base implementation.
        return super().test_inductor()
|
||||
|
||||
if __name__ == '__main__':
    # Propagate the unittest result as the process exit code for CI.
    exit(run_unittest(TestFlashAttnSqMMAllreduce))
|
||||
109
torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_matmul_all_reduce.py
Executable file
109
torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_matmul_all_reduce.py
Executable file
@@ -0,0 +1,109 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from itertools import product
|
||||
import torch.multiprocessing as mp
|
||||
import sys
|
||||
import os
|
||||
|
||||
work_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.dirname(work_dir))
|
||||
from common_utils import *
|
||||
|
||||
def mm_split_k(rank, *args):
    """Per-process worker (spawned via ``mp.spawn``) for matmul_allreduce.

    Splits the K (reduction) dimension of ``input @ weight.T`` across ranks,
    runs the fused op on this rank's K-slice, and on rank 0 compares the
    all-reduced result against a locally computed full-K reference.

    Args (packed in *args):
        mat_m, mat_n, mat_k, has_bias, has_res, dtype, world_size, block_m
    """
    # Same seed on every rank => every rank generates identical full tensors.
    torch.manual_seed(0)
    mat_m, mat_n, mat_k, has_bias, has_res, dtype, world_size, block_m = args
    assert mat_k % world_size == 0, f"mat_k{mat_k} must be divisible by tp{world_size}"
    block_k = mat_k // world_size
    start_k = rank * block_k
    end_k = start_k + block_k

    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

    alpha = 0.625
    beta = 1.0 if has_res else 0.
    input = torch.randn((mat_m, mat_k), dtype=dtype, device='mlu')
    residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu') if has_res else None
    # weight is (N, K); the matmul uses weight transposed.
    weight = torch.randn((mat_n, mat_k), dtype=dtype, device="mlu")
    bias = torch.randn(mat_n, dtype=dtype, device="mlu") if has_bias else None
    # Reference: out = alpha * (input @ weight.T + bias) + beta * residual.
    pt_output = torch.matmul(input, weight.permute(1, 0))
    if has_bias:
        pt_output = pt_output + bias
    pt_output *= alpha
    if has_res:
        pt_output += beta * residual
    if has_bias:
        # The fused op presumably applies bias *after* alpha scaling, so the
        # bias passed to it is pre-scaled here — TODO confirm op semantics.
        bias *= alpha
    # Keep only this rank's K-slice for the distributed run.
    input = input[..., start_k:end_k].contiguous()
    weight = weight[..., start_k:end_k].contiguous()
    # bias/residual must be contributed exactly once, so only rank 0 passes them.
    bias = bias if (bias is not None and rank == 0) else None
    residual = residual if (residual is not None and rank == 0) else None
    beta = beta if (residual is not None and rank == 0) else 0.
    output = ops.matmul_allreduce(cncl_comm, input, weight, bias, residual, alpha, beta, block_m)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
|
||||
class TestMatMulAllReduceOp(BtTestCase):
    """Tests for the fused matmul + allreduce op (K-dimension tensor parallel)."""

    def op_impl_base(self, *args):
        # Delegate to the BtTestCase base implementation.
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_matmul_allreduce due to ASan issues")
    def test_matmul_allreduce(self):
        """Sweep a small fixed grid of shapes/flags/dtypes, one spawn per case."""
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            # NOTE(review): this only warns and still runs with nprocs=device_n;
            # presumably intentional (tp=1 still exercises the op) — confirm.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)

        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [1680]
        has_res_list = [False, True]
        has_bias_list = [False, True]
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        block_m = 4
        args = product(mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, dtype_list)

        for mat_m, mat_n, mat_k, has_bias, has_res, dtype in args:
            print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp={}, block_m={}, testing...".format(
                  mat_m, mat_n, mat_k, has_bias, has_res, dtype, device_n, block_m), flush=True)
            param = [mat_m, mat_n, mat_k, has_bias, has_res, dtype, device_n, block_m]
            mp.spawn(mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skip("not test")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_matmul_allreduce_random due to ASan issues")
    def test_matmul_allreduce_random(self):
        """Randomized shape sweep (currently disabled via @unittest.skip)."""
        import random
        random.seed(0)
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for i in range(10):
            tp_num = random.randint(1, device_n)
            mat_m = random.randint(1, 4096)
            mat_n = random.choice([512, 1024, 2048, 4096])
            # Start K at a multiple of tp_num and step by tp_num so the
            # divisibility assert below always holds.
            k_start = (1024 // tp_num) * tp_num
            mat_k = random.randrange(k_start, 10240, tp_num)
            has_res = random.choice([False, True])
            has_bias = random.choice([False, True])
            dtype = random.choice(dtype_list)
            block_m = random.randint(1, 10)
            assert mat_k % tp_num == 0, f"mat_k{mat_k} must be divisible by tp_num{tp_num}"
            print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
                  mat_m, mat_n, mat_k, has_bias, has_res, dtype, tp_num, block_m), flush=True)
            param = [mat_m, mat_n, mat_k, has_bias, has_res, dtype, tp_num, block_m]
            mp.spawn(mm_split_k, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        # Delegate to the BtTestCase base implementation.
        return super().test_inductor()
|
||||
|
||||
if __name__ == '__main__':
    # Propagate the unittest result as the process exit code for CI.
    exit(run_unittest(TestMatMulAllReduceOp))
|
||||
305
torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_moe_all_reduce.py
Executable file
305
torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_moe_all_reduce.py
Executable file
@@ -0,0 +1,305 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from typing import Union, List, Tuple
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import functional as F
|
||||
import torch.multiprocessing as mp
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
work_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.dirname(work_dir))
|
||||
from common_utils import *
|
||||
|
||||
def moe_split_inner_size(rank, *args):
    """Per-process worker (spawned via ``mp.spawn``) for fused_moe + allreduce.

    Splits the expert FFN inner dimension across ranks and compares:
      * output  — fused_moe without comm, followed by an explicit all_reduce;
      * output1 — fused_moe with the cncl communicator (fused allreduce);
      * output2 — same as output1 but with weight2 reshaped to the 4-D
        block-expert layout (only when the layout constraints hold).

    Args (packed in *args):
        batch, seq, hidden_size, inner_size, expert_num, topk, gated,
        renormalize, act_mode, dtype, block_n, world_size
    """
    # Same seed on every rank => identical full tensors everywhere.
    torch.manual_seed(0)
    batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, world_size = args
    assert inner_size % world_size == 0, f"inner_size{inner_size} must be divisible by tp_num{world_size}"
    block_k = inner_size // world_size
    start_k = rank * block_k
    end_k = start_k + block_k

    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

    scale_s = 0.01 # avoid the occurrence of inf
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    # weight1 holds (1+gated) projections per expert: gate and up when gated.
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=dtype) * scale_s
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=dtype)
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=dtype) * scale_s
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=dtype)

    # Shard weight1 along the inner dimension, keeping the (gate, up) pair
    # structure: view as (E, 1+gated, inner, hidden), slice inner, re-flatten.
    weight1 = weight1.view(expert_num, -1, inner_size, hidden_size)
    weight1 = weight1[:, :, start_k:end_k]
    weight1 = weight1.reshape(expert_num, -1, hidden_size).contiguous()
    weight2 = weight2[..., start_k:end_k].contiguous()
    if bias1 is not None:
        bias1 = bias1.view(expert_num, -1, inner_size)
        bias1 = bias1[..., start_k:end_k]
        bias1 = bias1.reshape(expert_num, -1).contiguous()
    # residual must be contributed exactly once, so only rank 0 passes it.
    residual = residual if (residual is not None and rank == 0) else None

    # Baseline: fused_moe without the communicator, then explicit all-reduce.
    param = [hidden_states, router_logit, weight1, weight2, bias1, bias2, residual,
             None, None, None, None, topk, renormalize, gated, act_mode]
    output = ops.fused_moe(*param)
    all_reduce(output, ReduceOp.SUM, group=pg)

    # Fused path: the op performs the allreduce internally via cncl_comm.
    param = [hidden_states, router_logit, weight1, weight2, bias1, bias2, residual,
             None, None, None, None, topk, renormalize, gated, act_mode, 0, block_n, cncl_comm]
    output1 = ops.fused_moe(*param)

    # Optional 4-D weight2 layout: (E/block_e, hidden, block_e, inner_tp),
    # valid only when block_e*inner_tp == 4096 and experts divide evenly.
    new_inner_size = weight2.shape[-1]
    block_e = 4096 // new_inner_size
    if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
        param = [hidden_states, router_logit, weight1, weight2.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                 bias1, bias2, residual,
                 None, None, None, None, topk, renormalize, gated, act_mode, 0, block_n, cncl_comm]
        output2 = ops.fused_moe(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), output1.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
        if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
            assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
def sq_moe_split_inner_size(rank, *args):
    """Per-process worker for the smooth-quant variant of fused_moe + allreduce.

    Same comparison structure as ``moe_split_inner_size`` (explicit all_reduce
    baseline vs. fused-comm path vs. 4-D weight2 layout), but expert weights
    are smoothed and row-quantized to int8 first.

    Args (packed in *args):
        batch, seq, hidden_size, inner_size, expert_num, topk, gated,
        renormalize, act_mode, dtype, block_n, world_size
    """
    # Same seed on every rank => identical full tensors everywhere.
    torch.manual_seed(0)
    batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, world_size = args
    assert inner_size % world_size == 0, f"inner_size{inner_size} must be divisible by tp_num{world_size}"
    block_k = inner_size // world_size
    start_k = rank * block_k
    end_k = start_k + block_k

    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

    scale_s = 0.1 # avoid the occurrence of inf
    eps = 0.1 # Avoid the occurrence of nan
    residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    bias1, bias2 = None, None
    # Small-variance inputs keep the quantized pipeline numerically stable.
    hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=dtype)
    weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=dtype)
    weight2 = torch.normal(0, 0.01, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=dtype)
    router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
    # Smoothing factors are strictly positive (abs + eps) to avoid div-by-zero.
    input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    # Fold the smoothing factors into the weights before quantization
    # (smooth-quant: activation scale is migrated into the weight).
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    # Row-wise int8 quantization over the flattened 2-D weight views.
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    # Shard along the inner dimension, preserving the (gate, up) pair layout.
    quant_w1 = quant_w1.view(expert_num, -1, inner_size, hidden_size)
    quant_w1 = quant_w1[:, :, start_k:end_k]
    quant_w1 = quant_w1.reshape(expert_num, -1, hidden_size).contiguous()
    quant_w2 = quant_w2[..., start_k:end_k].contiguous()

    if bias1 is not None:
        bias1 = bias1.view(expert_num, -1, inner_size)
        bias1 = bias1[..., start_k:end_k]
        bias1 = bias1.reshape(expert_num, -1).contiguous()
    # Per-row scales of weight1 and the activation smooth are shaped per-inner,
    # so they must be sharded the same way as the weights.
    if w1_scale is not None:
        w1_scale = w1_scale.view(expert_num, -1, inner_size)
        w1_scale = w1_scale[..., start_k:end_k]
        w1_scale = w1_scale.reshape(expert_num, -1).contiguous()
    if act_smooth is not None:
        act_smooth = act_smooth[..., start_k:end_k].contiguous()
    # residual must be contributed exactly once, so only rank 0 passes it.
    residual = residual if (residual is not None and rank == 0) else None

    # Baseline: fused_moe without the communicator, then explicit all-reduce.
    param = [hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual,
             input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode]
    output = ops.fused_moe(*param)
    all_reduce(output, ReduceOp.SUM, group=pg)

    # Fused path: the op performs the allreduce internally via cncl_comm.
    param = [hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual,
             input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0,
             block_n, cncl_comm]
    output1 = ops.fused_moe(*param)

    # Optional 4-D weight2 layout; see moe_split_inner_size for constraints.
    new_inner_size = quant_w2.shape[-1]
    block_e = 4096 // new_inner_size
    if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
        param = [hidden_states, router_logit, quant_w1,
                 quant_w2.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                 bias1, bias2, residual, input_smooth, act_smooth, w1_scale, w2_scale, topk,
                 renormalize, gated, act_mode, 0, block_n, cncl_comm]
        output2 = ops.fused_moe(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), output1.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
        if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
            assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
|
||||
class TestFusedMOEAllReduceOp(BtTestCase):
    """Tests for fused MoE + allreduce (plain and smooth-quant variants)."""

    def op_impl_base(self, *args):
        # Delegate to the BtTestCase base implementation.
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_fused_moe_allreduce due to ASan issues")
    def test_single_fused_moe_allreduce(self):
        """One fixed large-shape case of the plain fused MoE allreduce."""
        print("test_single_fused_moe_allreduce")
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            # NOTE(review): only warns; still spawns with nprocs=device_n below.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        block_n = 2

        batch, seq, hidden_size, inner_size = 1, 1024, 8192, 8192
        expert_num, topk, gated, renormalize, act_mode, dtype = 8, 2, True, True, 'silu', torch.float16
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
tp_num: {device_n}, block_n: {block_n}, testing...", flush=True)
        param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, device_n]
        mp.spawn(moe_split_inner_size, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_fused_moe_allreduce due to ASan issues")
    def test_random_fused_moe_allreduce(self):
        """10 de-duplicated random shape/flag combinations of the plain variant."""
        print("test_random_fused_moe_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)

        act_mode = 'gelu'
        case_list = set()
        dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
        while(len(case_list) < 10):
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(512, 1024, 2)
            if block_n != 0:
                # Require hidden_size/|block_n| >= 256 and round hidden_size
                # to a multiple of block_n.
                if hidden_size // abs(block_n) < 256:
                    continue
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            # Step by tp_num*2 from a tp_num multiple so inner_size is
            # divisible by tp_num (required by the worker's assert).
            k_start = (512 // tp_num) * tp_num
            inner_size = random.randrange(k_start, 1024, tp_num * 2)
            expert_num = random.randint(1, 32)
            topk = random.randint(1,expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            dtype = random.choice(dtype_list)
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(moe_split_inner_size, param, nprocs=tp_num, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_sq_fused_moe_allreduce due to ASan issues")
    def test_sq_fused_moe_allreduce(self):
        """One fixed case of the smooth-quant fused MoE allreduce."""
        print("test_sq_fused_moe_allreduce")
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        block_n = 2
        batch, seq, hidden_size, inner_size = 5, 9, 8192, 8192
        expert_num, topk, gated, renormalize, act_mode, dtype = 20, 16, False, True, 'gelu', torch.float16
        print(f"sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
tp_num: {device_n}, block_n: {block_n}, testing...", flush=True)
        param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, device_n]
        mp.spawn(sq_moe_split_inner_size, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_fused_moe_allreduce due to ASan issues")
    def test_random_sq_fused_moe_allreduce(self):
        """10 de-duplicated random combinations of the smooth-quant variant."""
        print("test_random_sq_fused_moe_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        act_mode = 'gelu'
        case_list = set()
        while(len(case_list) < 10):
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(256, 512, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            k_start = (512 // tp_num) * tp_num
            inner_size = random.randrange(k_start, 1024, tp_num * 2)
            expert_num = random.randint(1, 32)
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            dtype = random.choice([ torch.float16])
            if torch_mlu.mlu.get_device_name() == 'MLU370':
                dtype = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(sq_moe_split_inner_size, param, nprocs=tp_num, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_fused_moe_with_4D_w2_allreduce due to ASan issues")
    def test_random_sq_fused_moe_with_4D_w2_allreduce(self):
        """Random smooth-quant cases constructed so the 4-D weight2 layout
        branch in the worker (block_e * inner_tp == 4096) is exercised."""
        print("test_random_sq_fused_moe_with_4D_w2_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        act_mode = 'gelu'
        case_list = set()
        while (len(case_list) < 10):
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(512, 4096, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            # Choose per-rank inner size so that 4096 / inner_size_per_tp is
            # integral, then make expert_num a multiple of that factor —
            # this guarantees the 4-D layout branch is taken.
            inner_size_per_tp = random.choice([256, 512])
            inner_size = inner_size_per_tp * tp_num
            expert_num_base = random.randint(1, 4)
            expert_num_factor = 4096 // inner_size_per_tp
            expert_num = expert_num_base * expert_num_factor
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            dtype = random.choice([ torch.float16])
            if torch_mlu.mlu.get_device_name() == 'MLU370':
                dtype = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(sq_moe_split_inner_size, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        # Delegate to the BtTestCase base implementation.
        return super().test_inductor()
|
||||
|
||||
if __name__ == '__main__':
    # Propagate the unittest result as the process exit code for CI.
    exit(run_unittest(TestFusedMOEAllReduceOp))
|
||||
@@ -0,0 +1,157 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import functional as F
|
||||
import torch.multiprocessing as mp
|
||||
import sys
|
||||
import os
|
||||
|
||||
work_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.dirname(work_dir))
|
||||
from common_utils import *
|
||||
|
||||
def compute_weight_only_scale(weight, quant_bit):
    """Quantize ``weight`` to signed int8 with one symmetric scale per row.

    Args:
        weight: 2-D float tensor of shape (rows, cols); each row is scaled by
            its own absolute maximum.
        quant_bit: bit width used to derive the integer range, e.g. 8 -> 127.

    Returns:
        (weight_int, weight_scale_recip):
            weight_int — int8 tensor, same shape as ``weight`` (values are
            truncated toward zero, matching ``Tensor.type``'s cast).
            weight_scale_recip — float tensor of per-row dequantization
            factors (row_max / int_max), squeezed to 1-D.
    """
    # Largest representable positive integer for the given bit width.
    int_max = float(2 ** (quant_bit - 1) - 1)
    # Per-row absolute maximum, kept as a column so it broadcasts over rows.
    row_max = weight.abs().max(dim=1, keepdim=True).values
    # Forward scale maps each row into [-int_max, int_max].
    scale = int_max / row_max
    weight_int = (weight * scale).type(torch.int8)
    # Reciprocal scale applied at dequantization time.
    weight_scale_recip = (row_max / int_max).type(torch.float).squeeze()
    return weight_int, weight_scale_recip
|
||||
|
||||
def quant_mm_split_k(rank, *args):
    """Per-process worker: weight-only quantized matmul + allreduce, split on K.

    Spawned via mp.spawn; every rank regenerates the same full inputs (shared
    seed), keeps only its contiguous K-slice, runs tmo's fused
    quant_matmul_allreduce, and rank 0 checks the allreduced result against
    the full (unsplit) torch reference.
    """
    torch.manual_seed(0)
    # M/N/K: GEMM dims; block_m: kernel tiling hint; world_size: #ranks.
    M, N, K, has_bias, has_res, dtype, block_m, world_size = args
    assert K % world_size == 0
    block_k = K // world_size
    start_k = rank * block_k
    end_k = start_k + block_k

    setup(rank, world_size)
    pg = get_default_group()
    # Raw CNCL communicator handle consumed by the fused op.
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

    a = torch.randn(M, K, device="mlu", dtype=dtype)
    b = torch.randn(N, K, device="mlu", dtype=dtype)
    b_int, gemm_output_scale = compute_weight_only_scale(b, 8)
    c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None
    bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None
    # Reference computed on the full K before slicing.
    torch_quant_matmul = QuantMatmul(b_int, None, None, None, None, gemm_output_scale, dtype)
    pt_output = torch_quant_matmul(a).detach()

    a = a[..., start_k:end_k].contiguous()
    b_int = b_int[..., start_k:end_k].contiguous()
    # bias/residual must contribute exactly once across the allreduce,
    # so only rank 0 passes them.
    bias = bias if (bias is not None and rank == 0) else None
    c = c if (c is not None and rank == 0) else None

    param = [cncl_comm, a, None, None, b_int, None, None, bias, c, None,
             None, gemm_output_scale, None, "half", "weight_only", "quantize_none",
             "quantize_per_channel", 8, 1.0, 1.0, False, True, block_m]
    # beta = beta if (c is not None and rank == 0) else 0.
    output = ops._ops.quant_matmul_allreduce(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
def sq_mm_split_k(rank, *args):
    """Per-process worker: smooth-quant (int8 x int8) matmul + allreduce,
    split on K.

    Every rank regenerates the same full inputs (shared seed), keeps its
    K-slice, runs tmo's fused smooth_quant_matmul_allreduce, and rank 0
    compares the allreduced result with the full-K torch reference.
    """
    torch.manual_seed(0)
    # NOTE: argument order differs from quant_mm_split_k (world_size before
    # block_m) — callers must match this unpacking.
    M, N, K, has_bias, has_res, dtype, world_size, block_m = args
    assert K % world_size == 0
    block_k = K // world_size
    start_k = rank * block_k
    end_k = start_k + block_k

    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

    a = torch.randint(-10, 10, (M, K), dtype=torch.int8).mlu()
    b = torch.randint(-10, 10, (N, K), dtype=torch.int8).mlu()
    c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None
    bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None
    # Per-token (a) and per-channel (b) dequantization scales.
    a_scale = torch.randn(M, device="mlu", dtype=torch.float)
    b_scale = torch.randn(N, device="mlu", dtype=torch.float)
    torch_quant_matmul = QuantMatmul(b, bias, c, a_scale, b_scale, None, dtype)
    pt_output = torch_quant_matmul(a).detach()

    a = a[..., start_k:end_k].contiguous()
    b = b[..., start_k:end_k].contiguous()
    # bias/residual contribute exactly once across the allreduce -> rank 0 only.
    bias = bias if (bias is not None and rank == 0) else None
    c = c if (c is not None and rank == 0) else None
    # beta = beta if (c is not None and rank == 0) else 0.
    param = [cncl_comm, a, a_scale, b, b_scale, dtype, bias, c, 1.0, 1.0, block_m]
    output = ops.smooth_quant_matmul_allreduce(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
|
||||
|
||||
class TestGptQuantMatmulOp(BtTestCase):
    """Multi-device tests for the fused quantized matmul + allreduce ops;
    each test spawns one worker process per MLU device."""

    def op_impl_base(self, *args):
        # Concrete override of BtTestCase's abstract hook.
        return super().op_impl_base(*args)

    # weight only, no bias, no residual, no activation
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_weight_only_matmul_allreduce due to ASan issues")
    def test_single_weight_only_matmul_allreduce(self):
        """Fixed-shape weight-only matmul+allreduce, K split across devices."""
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            # NOTE(review): only warns, does not skip — the spawn below still
            # runs with device_n workers; confirm a skip was not intended.
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        block_m = 8
        M, K, N = 32, 256, 128
        has_bias = False
        has_res = False
        dtype = torch.half
        print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
            M, N, K, has_bias, has_res, dtype, device_n, block_m), flush=True)
        assert K % device_n == 0, f"K{K} must be divisible by tp_num{device_n}"
        param = [M, N, K, has_bias, has_res, dtype, block_m, device_n]
        mp.spawn(quant_mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_sq_matmul_allreduce due to ASan issues")
    def test_single_sq_matmul_allreduce(self):
        """Fixed-shape smooth-quant (int8) matmul+allreduce across devices."""
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            # NOTE(review): warns but still runs (see above).
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
        block_m = 1
        M, K, N = 2598, 1024, 1024
        has_bias = False
        has_res = False
        dtype = torch.half
        print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
            M, N, K, has_bias, has_res, dtype, device_n, block_m), flush=True)
        assert K % device_n == 0, f"K{K} must be divisible by tp_num{device_n}"
        param = [M, N, K, has_bias, has_res, dtype, device_n, block_m]
        mp.spawn(sq_mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skip("not test")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_mm_allreduce due to ASan issues")
    def test_random_sq_mm_allreduce(self):
        """Randomized shapes/dtypes/tp sizes (disabled by default)."""
        import random
        random.seed(0)
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)

        dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
        for i in range(10):
            tp_num = random.randint(1, device_n)
            M = random.randint(1, 4096)
            N = random.choice([512, 1024, 2048, 4096])
            # K is drawn on a stride of tp_num from a tp_num-aligned start,
            # so the K-split across ranks is always exact.
            k_start = (1024 // tp_num) * tp_num
            K = random.randrange(k_start, 2048, tp_num)
            has_res = random.choice([False, True])
            has_bias = random.choice([False, True])
            dtype = random.choice(dtype_list)
            block_m = random.randint(1, 10)

            print("M={}, N={}, K={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
                M, N, K, has_bias, has_res, dtype, tp_num, block_m), flush=True)
            param = [M, N, K, has_bias, has_res, dtype, tp_num, block_m]
            mp.spawn(sq_mm_split_k, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        # Delegates to the shared base-class inductor check.
        return super().test_inductor()
|
||||
|
||||
if __name__ == '__main__':
    # Propagate the unittest result as the process exit code (0 = success).
    exit(run_unittest(TestGptQuantMatmulOp))
|
||||
734
torch_mlu_ops-v1.3.2/tests/ops_pytest/common_utils.py
Normal file
734
torch_mlu_ops-v1.3.2/tests/ops_pytest/common_utils.py
Normal file
@@ -0,0 +1,734 @@
|
||||
import sys
|
||||
sys_args = sys.argv
|
||||
sys.argv = [sys_args.pop(0)] # prevent unittest printing help info
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch_mlu
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import List, Tuple, Optional
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.distributed_c10d import _get_default_group as get_default_group, all_reduce, ReduceOp
|
||||
import torch.testing._internal.optests as optests
|
||||
import random
|
||||
import argparse
|
||||
from abc import abstractmethod, ABC
|
||||
import unittest
|
||||
import torch_mlu_ops as tmo
|
||||
import os
|
||||
|
||||
# Maps an activation-mode string (as passed through the op test interfaces)
# to its torch functional implementation; the 'none' case is handled by
# callers (e.g. QuantMatmul sets act to None).
act_mode_dict = {"relu": torch.nn.functional.relu,
                 "gelu": torch.nn.functional.gelu,
                 "silu": torch.nn.functional.silu}
|
||||
|
||||
class BtTestCase(TestCase, ABC):
    """Shared base class for tmo op tests.

    Provides opcheck plumbing, exception assertions, and relative-error
    tensor comparison (MSE / RAE / RMA metrics) used by all op test cases.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Disable the TF32 matmul override so reference comparisons are stable.
        os.environ['TORCH_ALLOW_TF32_CNMATMUL_OVERRIDE'] = '0'

    @abstractmethod
    def op_impl_base(self, *args):
        # Subclasses invoke the op under test here.
        pass

    @abstractmethod
    def test_inductor(self):
        # Subclasses provide a torch.compile (inductor) smoke test.
        pass

    def base_opcheck(self, interface_overload, args):
        """Run torch.testing's opcheck on a custom-op overload and require
        every selected sub-test to report SUCCESS."""
        target_check = ["test_schema", "test_autograd_registration"]
        if torch.__version__ >= '2.3.0':
            # FakeTensor checking only exists/behaves on torch >= 2.3.
            target_check.append("test_faketensor")
        target_status = {key: "SUCCESS" for key in target_check}
        result = optests.opcheck(interface_overload, args, test_utils=target_check)
        self.assertEqual(result, target_status,)

    def assertException(self, error_msg, func, *args, **kwinputs):
        """Assert that func(*args, **kwinputs) raises; when `error_msg` is
        non-empty the exception text must match it exactly."""
        try:
            func(*args, **kwinputs)
            self.assertTrue(False)  # reached only if func did not raise
        except Exception as e:
            if error_msg:
                self.assertTrue(error_msg == str(e))
            else:
                self.assertTrue(True)

    def assertTensorsEqual(self,
                           a,
                           b,
                           prec=None,
                           message='',
                           allow_inf=False,
                           use_MSE=False,
                           use_RAE=False,
                           use_RMA=False):
        '''unittest.TestCase'''
        # Metric selection (first match wins): use_MSE -> relative RMS error
        # vs `a`; use_RAE -> relative absolute error; use_RMA -> relative
        # difference of mean magnitudes; otherwise max absolute error <= prec.
        if a.dtype == torch.bool:
            a = a.float()
        if b.dtype == torch.bool:
            b = b.float()
        epsilon = 1.0 / 16384  # guards the MSE denominator against zero
        self.assertEqual(a.size(), b.size(), message)
        assert (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)), "a and b are need be torch tensor."
        if a.numel() > 0:
            # check that NaNs are in the same locations
            nan_mask = a != a
            self.assertTrue(torch.equal(nan_mask, b != b), message)
            diff = a - b
            diff[nan_mask] = 0
            # Clone before masking so the caller's tensors are untouched.
            a = a.clone()
            b = b.clone()
            a[nan_mask] = 0
            b[nan_mask] = 0
            # inf check if allow_inf=True
            if allow_inf:
                inf_mask = (a == float("inf")) | (a == float("-inf"))
                self.assertTrue(torch.equal(inf_mask,
                                            (b == float("inf")) | (b == float("-inf"))),
                                message)
                diff[inf_mask] = 0
                a[inf_mask] = 0
                b[inf_mask] = 0
            # TODO: implement abs on CharTensor
            if diff.is_signed() and 'CharTensor' not in diff.type():
                diff = diff.abs()
            if use_MSE:
                diff = diff.abs().pow(2).sum()
                a_pow_sum = a.pow(2).sum()
                if diff <= (2 * epsilon) * (2 * epsilon):
                    diff = 0.0
                if a_pow_sum <= epsilon:
                    a_pow_sum = a_pow_sum + epsilon
                diff = torch.div(diff, (a_pow_sum * 1.0))
                self.assertLessEqual(diff.sqrt(), prec, message)
            elif use_RAE:
                diff = diff.abs().sum()
                a_sum = a.abs().sum()
                if a_sum == 0:
                    # Degenerate reference: require exact equality instead.
                    self.assertEqual(a, b, message)
                else:
                    diff = torch.div(diff, a_sum)
                    self.assertLessEqual(diff, prec, message)
            elif use_RMA:
                a_mean = a.abs().mean()
                b_mean = b.abs().mean()
                if a_mean == 0:
                    self.assertEqual(a, b, message)
                else:
                    diff = torch.div((a_mean - b_mean).abs(), a_mean)
                    self.assertLessEqual(diff, prec, message)
            else:
                max_err = diff.max()
                self.assertLessEqual(max_err, prec, message)
|
||||
|
||||
def run_unittest(case) -> int:
    """Run `case` through unittest, honoring an optional ``-k name [...]``
    filter captured from the original argv (module-level `sys_args`).

    Returns 0 when every test passed, 1 otherwise — suitable for exit().
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-k', nargs='+', type=str, default="", help='specify case to run')
    opts = cli.parse_args(sys_args)
    loader = unittest.TestLoader()
    if opts.k != "":
        suite = loader.loadTestsFromNames(opts.k, case)
    else:
        suite = loader.loadTestsFromTestCase(case)
    outcome = unittest.TextTestRunner().run(suite)
    return not outcome.wasSuccessful()
|
||||
|
||||
class TMOTimer:
    """Context manager measuring MLU kernel time via device events.

    Wrap `repeat` launches in the `with` body; afterwards read
    `average_hardware_time` = total event time / repeat.  Units are whatever
    torch_mlu's Event.hardware_time reports — presumably microseconds;
    confirm against torch_mlu documentation.
    """

    def __init__(self, repeat: int = 1):
        self.repeat = repeat

    def __enter__(self):
        # Events must be created with timing enabled to support hardware_time.
        self.notify_start = torch.mlu.Event(enable_timing=True)
        self.notify_end = torch.mlu.Event(enable_timing=True)
        self.notify_start.record()

    def __exit__(self, exc_type, exc_value, traceback):
        self.notify_end.record()
        # Block until the end event has completed before reading the timer.
        self.notify_end.synchronize()
        total_hardware_time = self.notify_start.hardware_time(self.notify_end)
        self.average_hardware_time = total_hardware_time / self.repeat
|
||||
|
||||
def QuantByRow(input: torch.Tensor, quant_bit: int, group_num: int = 1):
    """Row-wise (optionally group-wise) symmetric quantization.

    Flattens `input` to 2-D, scales every row (or each of `group_num`
    contiguous groups inside a row) by its absolute maximum, and rounds to
    the signed `quant_bit` integer range.

    Returns:
        (quant_input, scale): int8 tensor in the caller's original shape,
        and the float dequantization scale with the trailing group axis
        squeezed away.
    """
    original_shape = input.shape
    if input.dim() > 2:
        input = input.view(-1, original_shape[-1])
    if input.dim() == 1:
        input = input.unsqueeze(0)
    assert input.dim() == 2, "input must be 2-D tensor."
    assert quant_bit == 4 or quant_bit == 8, "quant_bit must be 4 or 8."
    assert group_num >= 1, "group_num >= 1."
    upper = float(2 ** (quant_bit - 1) - 1)
    lower = -float(2 ** (quant_bit - 1))
    group_size = input.size(-1) // group_num
    grouped = input if group_num == 1 else input.view(input.size(0), group_num, group_size)
    group_max, _ = grouped.abs().max(dim=-1, keepdim=True)
    scale = group_max.to(torch.float) / upper
    quantized = (grouped / scale).round().clamp(lower, upper).to(torch.int8)
    return quantized.view(original_shape), scale.squeeze(-1)
|
||||
|
||||
def QuantByTensor(input: torch.Tensor, quant_bit: int):
    """Whole-tensor symmetric quantization with a single scalar scale.

    Returns (input_int, input_scale) where input_scale = int_max / max|input|
    and input_int = round(input * input_scale) clamped to the signed
    `quant_bit` range and cast to int8.
    """
    upper = float(2 ** (quant_bit - 1) - 1)
    lower = -float(2 ** (quant_bit - 1))
    abs_max = torch.max(torch.abs(input))
    input_scale = upper / abs_max
    input_int = (input * input_scale).round().clamp(lower, upper).to(torch.int8)
    return input_int, input_scale
|
||||
|
||||
def PairlyPackInt8(input):
    """Pack adjacent int8 pairs into single bytes, halving the last dim.

    Even-index elements become the low nibble, odd-index elements the high
    nibble of each packed byte (reinterpreted back as int8).
    """
    assert input.dtype == torch.int8, "dtype of input must be int8."
    assert input.dim() == 2 or input.dim() == 3, "input must be 2-D or 3-D tensor."
    assert input.size(-1) % 2 == 0, "size(-1) of input must be even."
    packed_shape = list(input.shape)
    packed_shape[-1] //= 2
    flat = input.flatten()
    low = flat[0::2].to(torch.uint8)
    high = flat[1::2].to(torch.uint8)
    packed = (high << 4) + (low & 0x0F)
    return packed.to(torch.int8).reshape(packed_shape)
|
||||
|
||||
def UnpackInt4(input):
    """Unpack nibble-packed int8 data into a flat int8 tensor of twice the
    length: even outputs carry the sign-extended low nibble, odd outputs the
    high nibble.  Always returns a 1-D tensor regardless of input shape.
    """
    assert input.dtype == torch.int8, "dtype of input must be int8."
    flat = input.flatten()
    unpacked = torch.zeros(flat.size(0) * 2, dtype=torch.int8, device=input.device)
    # Arithmetic int8 shifts: <<4 then >>4 sign-extends the low nibble.
    unpacked[0::2] = (flat << 4) >> 4
    unpacked[1::2] = flat >> 4
    return unpacked
|
||||
|
||||
def smooth_quant_matmul(a, a_scale, b, b_scale, out_dtype, bias=None):
    """Dequantized reference matmul for smooth-quantized operands.

    a: (m, k) int8 activations with per-token scales `a_scale` of shape (m,).
    b: (n, k) int8 weights — or (n, k/2) nibble-packed int4 when
       a.size(1) == 2 * b.size(1) — with per-channel scales `b_scale`,
       optionally (quant_group, n) for group-wise quantization along k.
    Returns an (m, n) tensor in `out_dtype`; `bias` is added after the cast.
    """
    assert a.dim() == 2 and b.dim() == 2, "a.dim() == 2 and b.dim() == 2"
    assert a_scale.dim() == 1, "a_scale.dim() == 1"
    assert a.size(0) == a_scale.size(0), "a.size(0) == a_scale.size(0)"
    assert b.size(0) == b_scale.size(-1), "b.size(0) == b_scale.size(-1)"
    m, a_k = a.size(0), a.size(1)
    n, b_k = b.size(0), b.size(1)
    if b_scale.dim() == 1:
        b_scale = b_scale.unsqueeze(0)
    quant_group = b_scale.size(0)
    # Split k into quant_group chunks: (quant_group, m, k/quant_group).
    a = a.view(m, quant_group, -1).transpose(0, 1).contiguous()
    if a_k == b_k * 2:
        # int4 weights arrive nibble-packed; expand to int8 first.
        b = UnpackInt4(b)
    b = b.view(n, quant_group, -1).transpose(0, 1).contiguous()
    out = torch.zeros(m, n, dtype=torch.float, device=a.device)
    for g in range(quant_group):
        # Per-(token, channel) dequantization factor, shape (m, n).
        scale_mn = torch.matmul(a_scale.unsqueeze(1), b_scale[g].unsqueeze(0))
        out += torch.einsum('mk,nk->mn', a[g].to(torch.float), b[g].to(torch.float)) * scale_mn
    out = out.to(out_dtype)
    if bias is not None:
        out += bias
    return out
|
||||
|
||||
def smooth_quant_matmul_w4w8_mixed(a, a_scale, b, b_scale, out_dtype, bias=None, quant_flag=None):
    """Reference matmul for mixed w4/w8 group-quantized weights.

    `b` stores each k-group either nibble-packed int4 (quant_flag[i] == 4,
    half the bytes) or plain int8; groups are unpacked to int8, re-assembled
    into a uniform (n, k) layout, then each group is dequantized and
    accumulated through smooth_quant_matmul.

    a: (m, k) int8; b_scale: (quant_group, n).
    quant_flag: per-group bit widths — assumed non-None and of length
    quant_group when this function is called; TODO confirm with callers.
    """
    m = a.shape[0]
    k = a.shape[1]
    quant_group = b_scale.shape[0]
    group_wise = k // quant_group  # k-elements per quant group
    n = b_scale.shape[1]
    b = b.view(n, -1)
    a = a.view(m, quant_group, -1).transpose(0, 1).contiguous()
    new_b = []
    start = 0
    end = 0
    for i in range(quant_group):
        if quant_flag[i] == 4:
            # int4 groups occupy half the bytes; unpack to int8.
            end += group_wise // 2
            new_b.append(UnpackInt4(b[:, start:end]).view(n, -1))
        else:
            end += group_wise
            new_b.append((b[:, start:end]))
        start = end
    new_b = torch.cat(new_b, 1)
    b = new_b.view(n, quant_group, -1).transpose(0, 1).contiguous()
    out = torch.zeros(m, n, dtype=torch.float, device=a.device)
    for i in range(quant_group):
        # Accumulate per-group dequantized partial products (each already
        # cast to out_dtype by smooth_quant_matmul, summed in float).
        out += smooth_quant_matmul(a[i], a_scale, b[i], b_scale[i], out_dtype)
    out = out.to(out_dtype)
    if bias is not None:
        out += bias
    return out
|
||||
|
||||
def weight_only_quant_matmul(a, b, scale, bias=None):
    """Reference weight-only-quantized matmul.

    Dequantizes `b` row-wise (1-D `scale`) or group-wise along k (2-D
    `scale`), computes a @ b^T in float, casts back to a.dtype, and adds
    the optional bias.
    """
    assert a.dim() == 2 and b.dim() == 2, "a.dim() == 2 and b.dim() == 2"
    assert scale.dim() == 1 or scale.dim() == 2, "scale.dim() == 1 or scale.dim() == 2"
    assert b.size(0) == scale.size(0), "b.size(0) == b_scale.size(0)"
    assert a.size(1) == b.size(1), "a.size(1) == b.size(1)"
    if scale.dim() == 2:
        # Broadcast each group's scale across its k-slice.
        group = b.size(1) // scale.size(1)
        expanded = scale.unsqueeze(-1).repeat(1, 1, group).reshape(b.shape)
    else:
        expanded = scale.unsqueeze(-1)
    deq_b = b * expanded
    out = torch.einsum('mk,nk->mn', a.to(torch.float), deq_b.to(torch.float)).to(a.dtype)
    return out if bias is None else out + bias
|
||||
|
||||
def single_query_cached_kv_attn(q, k_cache, v_cache, block_tables, context_lens, k_cache_quant_scale,
                                v_cache_quant_scale, alibi_slopes, window_size_left, window_size_right, softmax_scale, return_lse):
    """Torch reference for paged-KV-cache attention (decode-style).

    q: (bs, seq_q, num_heads, head_size) queries.
    k_cache: (num_blocks, num_kv_heads, block_size, head_size); v_cache the
        same layout with its own head size.  Caches may be int8-quantized
        with per-channel (2-D) or per-token (3-D) scales.
    block_tables / context_lens: per-batch cache block indices and lengths.
    alibi_slopes: optional per-batch ALiBi slopes.
    window_size_left/right: sliding-window bounds; -1 means unbounded.

    Returns float16 output (bs, seq_q, num_heads, head_size_v), plus the
    float log-sum-exp (bs, num_heads, seq_q) when return_lse is set.
    """
    # Run the whole reference in fp32 for accuracy.
    q = q.float()
    k_cache = k_cache.float()
    v_cache = v_cache.float()
    def masked_attention(query, key, value, alibi_slope, context_len, window_size_left, window_size_right, qk_scale) -> torch.Tensor:
        # (num_heads, seq_q, seq_k)
        qk = torch.einsum('qhd,hkd->hqk', query, key)
        qk = qk * qk_scale
        if alibi_slope is not None:
            # Per-head linear position bias added to every query row.
            alibi_dist = torch.arange(0, context_len, dtype=torch.float32).mlu()
            alibi = alibi_slope[:, None] * alibi_dist
            qk = qk + alibi[:, None, :]

        _, seq_q, seq_k = qk.size()
        if seq_q > 1: #causal mask
            # Mask only the trailing seq_q x seq_q triangle; earlier cache
            # positions stay fully visible.
            ml = torch.zeros((seq_q, seq_k - seq_q), dtype=qk.dtype).mlu()
            ones = torch.ones((seq_q, seq_q), dtype=qk.dtype).mlu() * -torch.inf
            mr = torch.triu(ones, diagonal=1)
            mask = torch.cat((ml, mr), dim=-1)
            qk = qk + mask
        if window_size_left != -1 or window_size_right != -1:
            # Sliding window measured from each query's diagonal position.
            mask_w = torch.full((seq_q, seq_k), -torch.inf, dtype=torch.float, device="mlu")
            for qi in range(seq_q):
                left = max(seq_k - seq_q + qi - window_size_left, 0) if window_size_left != -1 else 0
                right = min(max(seq_k - seq_q + qi + window_size_right + 1, 0), seq_k) if window_size_right != -1 else seq_k
                mask_w[qi, left:right] = 0
            qk += mask_w
        attention = torch.softmax(qk, dim = -1, dtype=qk.dtype)
        qkv = torch.einsum('hqk,hkd->qhd', attention, value)
        return qkv, qk

    if k_cache_quant_scale is not None and v_cache_quant_scale is not None:
        # Dequantize the caches up front, broadcasting the scales to 4-D.
        if k_cache_quant_scale.dim() == 2: # per_channel: [kv_head_num, head_size]
            k_cache_quant_scale = k_cache_quant_scale.reshape(1, k_cache_quant_scale.shape[0], 1, k_cache_quant_scale.shape[1])
            v_cache_quant_scale = v_cache_quant_scale.reshape(1, v_cache_quant_scale.shape[0], 1, v_cache_quant_scale.shape[1])
        elif k_cache_quant_scale.dim() == 3: # per_token: [num_blocks, k_head_num, block_size]
            k_cache_quant_scale = k_cache_quant_scale.reshape(*k_cache_quant_scale.shape, 1)
            v_cache_quant_scale = v_cache_quant_scale.reshape(*v_cache_quant_scale.shape, 1)
        k_cache *= k_cache_quant_scale
        v_cache *= v_cache_quant_scale
    bs, seq_q, num_heads, head_size = q.size()
    head_size_v = v_cache.size(-1)
    num_blocks, num_kv_heads, block_size, _ = k_cache.size()
    output = torch.zeros((bs, seq_q, num_heads, head_size_v), dtype=torch.float16)
    lse = torch.zeros((bs, num_heads, seq_q), dtype=torch.float)

    assert (num_heads % num_kv_heads == 0)
    head_repeats = num_heads // num_kv_heads
    for bs_id in range(bs):
        q_bs = q[bs_id]
        context_len = int(context_lens[bs_id])
        if context_len == 0:
            # Empty context: zero output and lse = -inf (softmax over nothing).
            output[bs_id] = torch.zeros((seq_q, num_heads, head_size_v), device = q.device, dtype=output.dtype)
            lse[bs_id] = lse[bs_id].fill_(-float('inf'))
        else :
            # Gather this sequence's cache blocks and flatten to
            # (num_heads, seq_k, d), trimming the last partial block.
            block_table = block_tables[bs_id]
            table_end = (context_len + block_size - 1) // block_size
            block_ids = block_table[0 : table_end]
            keys, values = k_cache[block_ids], v_cache[block_ids]

            # Expand KV heads to match query heads (GQA/MQA).
            keys = torch.repeat_interleave(keys, head_repeats, dim=1)
            keys = keys.transpose(1, 0).contiguous().view(num_heads, -1, head_size)
            keys = keys[:, 0:context_len, :]

            values = torch.repeat_interleave(values, head_repeats, dim=1)
            values = values.transpose(1, 0).contiguous().view(num_heads, -1, head_size_v)
            values = values[:, 0:context_len, :]

            alibi_slope = alibi_slopes[bs_id] if alibi_slopes is not None else None
            qkv, qk= masked_attention(q_bs, keys, values, alibi_slope, context_len, window_size_left, window_size_right, softmax_scale)
            output[bs_id] = qkv
            lse[bs_id] = torch.logsumexp(qk, dim = -1)
    return (output, lse) if return_lse else output
|
||||
|
||||
|
||||
def update_out_and_lse_torch(out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs):
    """Reference for online-softmax accumulation (blocked/ring attention):
    merge a new attention block (block_out, block_lse) into the running
    (out, lse) with the numerically stable form

        new_out = out - sigmoid(block_lse - lse) * (out - block_out)
        new_lse = lse - logsigmoid(lse - block_lse)

    Padded layout: out is (batch, max_seq, ...), lse is (batch, heads, seq).
    Packed layout (out.dim() == 3): sequences are concatenated along dim 0
    and located via the cu_seqs / block_cu_seqs prefix sums; seq_offsets
    shifts where the block lands inside each output sequence.
    """
    # only pad
    is_pack = out.dim() == 3
    new_out, new_lse = out.clone(), lse.clone()
    batch, max_seq_len, block_seq_len = lse.shape[0], lse.shape[-1], block_lse.shape[-1]

    # Move the seq axis inward and append a singleton so lse broadcasts
    # against out's trailing head/feature dims.
    lse_bsh = lse.transpose(-2, -1).unsqueeze(dim=-1)
    # View of new_lse's storage: writes below land in new_lse itself.
    new_lse_bsh = new_lse.transpose(-2, -1).unsqueeze(dim=-1)
    block_lse_bsh = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    if not is_pack:
        for i in range(batch):
            out_seq_offset = 0 if seq_offsets is None else seq_offsets[i]
            out_i = out[i, out_seq_offset : out_seq_offset + block_seq_len]
            lse_i = lse_bsh[i, out_seq_offset : out_seq_offset + block_seq_len]
            block_out_i = block_out[i, :]
            block_lse_i = block_lse_bsh[i, :]
            new_out[i, out_seq_offset : out_seq_offset + block_seq_len] = out_i - F.sigmoid(block_lse_i - lse_i) * (out_i - block_out_i)
            new_lse_bsh[i, out_seq_offset : out_seq_offset + block_seq_len] = (lse_i - F.logsigmoid(lse_i - block_lse_i))
    else:
        for i in range(batch):
            # Per-sequence extents in the packed block and output tensors.
            block_i_begin = block_cu_seqs[i]
            block_i_end = block_cu_seqs[i + 1]
            block_i_lens = block_i_end - block_i_begin
            out_i_begin = cu_seqs[i]
            out_seq_offset = seq_offsets[i]

            block_out_i = block_out[block_i_begin : block_i_end]
            block_lse_i = block_lse_bsh[i, 0 : block_i_lens]
            out_i = out[out_i_begin + out_seq_offset: out_i_begin + out_seq_offset + block_i_lens]
            lse_i = lse_bsh[i, out_seq_offset: out_seq_offset + block_i_lens]
            new_out_i = out_i - F.sigmoid(block_lse_i - lse_i) * (out_i - block_out_i)
            new_lse_i = (lse_i - F.logsigmoid(lse_i - block_lse_i))
            new_out[out_i_begin + out_seq_offset: out_i_begin + out_seq_offset + block_i_lens] = new_out_i
            new_lse_bsh[i, out_seq_offset: out_seq_offset + block_i_lens] = new_lse_i

    # Undo the broadcast view to return lse in its original (b, h, s) layout.
    return (new_out, new_lse_bsh.squeeze(dim=-1).transpose(-2, -1))
|
||||
|
||||
class QuantMatmul(torch.nn.Module):
    """Torch reference module for the fused quantized matmul ops.

    Computes:
        output = act(((input @ weight^T + bias) * scales) * alpha
                     + residual * beta)
    where input_scale (per-row), weight_scale (per-column) and
    gemm_output_scale are each applied only when provided, and act is the
    activation from act_mode_dict ('none' disables it).
    """

    def __init__(self, weight, bias, residual, input_scale, weight_scale, gemm_output_scale, dtype,
                 alpha: float = 1.0, beta: float = 1.0, act_mode: str = 'none') -> None:
        super().__init__()
        self.dtype = dtype
        self.weight = Parameter(weight.type(dtype))
        self.input_scale = input_scale
        self.weight_scale = weight_scale
        self.gemm_output_scale = gemm_output_scale
        self.bias = None if bias is None else Parameter(bias)
        self.residual = None if residual is None else Parameter(residual)
        self.alpha = alpha
        self.beta = beta
        self.act = None if act_mode == 'none' else act_mode_dict[act_mode]

    # d = (a * b + bias) * alpha + c * beta
    # output = (input * weight + bias) * alpha + residual * beta
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = F.linear(input.type(self.dtype), self.weight, self.bias)
        if self.input_scale is not None:
            # Per-token scale broadcast across the n output columns.
            row_scale = self.input_scale.expand(self.weight_scale.shape[0], -1).transpose(0, 1)
            out = out * row_scale
        if self.weight_scale is not None:
            out = out * self.weight_scale
        if self.gemm_output_scale is not None:
            out = out * self.gemm_output_scale
        out = out * self.alpha
        if self.residual is not None:
            out = out + self.residual * self.beta
        if self.act is not None:
            out = self.act(out)
        return out
|
||||
|
||||
|
||||
# for multiprocessing
|
||||
def assertTensorsEqual( a,
                        b,
                        prec=None,
                        message='',
                        allow_inf=False,
                        use_MSE=False,
                        use_RAE=False,
                        use_RMA=False):
    """Compare tensor `a` against reference `b` within tolerance `prec`.

    Module-level twin of BtTestCase.assertTensorsEqual for use inside
    mp.spawn worker processes (which have no TestCase instance).  Metric is
    chosen by use_MSE (relative RMS error), use_RAE (relative absolute
    error), use_RMA (relative mean-magnitude difference), otherwise plain
    max absolute error.  NaN (and, with allow_inf, +/-inf) locations must
    match exactly and are masked out of the numeric comparison.
    """
    tc = TestCase()
    if a.dtype == torch.bool:
        a = a.float()
    if b.dtype == torch.bool:
        b = b.float()
    epsilon = 1.0 / 16384  # guards the MSE denominator against zero
    tc.assertEqual(a.size(), b.size(), message)
    assert (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)), "a and b are need be torch tensor."
    if a.numel() > 0:
        # check that NaNs are in the same locations
        nan_mask = a != a
        tc.assertTrue(torch.equal(nan_mask, b != b), message)
        diff = a - b
        diff[nan_mask] = 0
        # Clone before masking so the caller's tensors are untouched.
        a = a.clone()
        b = b.clone()
        a[nan_mask] = 0
        b[nan_mask] = 0
        # inf check if allow_inf=True
        if allow_inf:
            inf_mask = (a == float("inf")) | (a == float("-inf"))
            tc.assertTrue(torch.equal(inf_mask,
                                      (b == float("inf")) | (b == float("-inf"))),
                          message)
            diff[inf_mask] = 0
            a[inf_mask] = 0
            b[inf_mask] = 0
        # TODO: implement abs on CharTensor
        if diff.is_signed() and 'CharTensor' not in diff.type():
            diff = diff.abs()
        if use_MSE:
            diff = diff.abs().pow(2).sum()
            a_pow_sum = a.pow(2).sum()
            if diff <= (2 * epsilon) * (2 * epsilon):
                diff = 0.0
            if a_pow_sum <= epsilon:
                a_pow_sum = a_pow_sum + epsilon
            diff = torch.div(diff, (a_pow_sum * 1.0))
            tc.assertLessEqual(diff.sqrt(), prec, message)
        elif use_RAE:
            diff = diff.abs().sum()
            a_sum = a.abs().sum()
            if a_sum == 0:
                # Degenerate reference: require exact equality instead.
                tc.assertEqual(a, b, message)
            else:
                diff = torch.div(diff, a_sum)
                tc.assertLessEqual(diff, prec, message)
        elif use_RMA:
            a_mean = a.abs().mean()
            b_mean = b.abs().mean()
            if a_mean == 0:
                tc.assertEqual(a, b, message)
            else:
                diff = torch.div((a_mean - b_mean).abs(), a_mean)
                tc.assertLessEqual(diff, prec, message)
        else:
            max_err = diff.max()
            tc.assertLessEqual(max_err, prec, message)
|
||||
|
||||
def setup(rank, world_size, backend='cncl'):
    """Initialize single-node torch.distributed for a spawned worker and
    bind it to MLU device `rank`.

    Uses a fixed localhost rendezvous port, so tests that call this must
    not run concurrently on the same machine.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '3458'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    torch_mlu.mlu.set_device(rank)
|
||||
|
||||
def cleanup():
    """Synchronize all ranks, then tear down the process group."""
    dist.barrier()
    dist.destroy_process_group()
|
||||
|
||||
def generate_token_count(num_expert,
                         total_token_count):
    """Randomly partition `total_token_count` tokens across `num_expert` experts.

    Draws random per-expert weights, rescales them so they roughly sum to
    the total, then repairs integer-truncation drift by clamping the
    cumulative sum to the total and pinning the last entry — so the returned
    counts always sum exactly to `total_token_count` and are non-negative.

    Returns:
        (cusum_token_count, token_count): int32 tensors of shape
        (num_expert + 1,) and (num_expert,), with cusum[0] == 0 and
        cusum[-1] == total_token_count.
    """
    token_count = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32)
    # Renamed from `sum` to avoid shadowing the builtin.
    weight_sum = torch.sum(token_count, dim=-1) * 1.0
    token_count *= total_token_count / weight_sum.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    cap = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    # Clamp to the total, then force the tail so the counts sum exactly.
    cusum_token_count = torch.minimum(cusum_token_count, cap)
    cusum_token_count[-1] = total_token_count
    return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1]
|
||||
|
||||
def generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype,
                             quant_mode=False, offline=False, invalid_batch_size=0):
    """Build randomized inputs for the KV-cache update/copy op tests.

    Generates a fused QKV "context" tensor (packed varlen layout when
    `packed` > 0, otherwise padded batch layout), a K/V cache pair
    (optionally int8-quantized with offline per-channel or online per-token
    scales), random batch->cache-row mappings (some forced invalid via -1
    when `invalid_batch_size` > 0), and a paged-attention slot_mapping.

    Returns the positional argument list consumed by the cache op under test.
    """
    q_heads = 1
    total_heads = q_heads + num_heads * 2  # fused layout: [q | k | v]
    max_bs = batch_size + 1

    context_lens = torch.randint(size=(batch_size, ), low=1,
                                 high=cache_memory_len // 2,
                                 dtype=torch.int32, device='mlu')
    max_context_len = context_lens.max().item()
    max_seq_offset = max_context_len // 3 + 1

    # Shuffled batch->cache row mapping; -1 marks rows the op must skip.
    cache_bs_id = random.sample([*range(0, batch_size)], batch_size)
    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
    if invalid_batch_size > 0:
        cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch_size)] = -1
    context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset,
                                        dtype=torch.int32, device='mlu')
    cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1,
                                      high=(cache_memory_len - max_context_len) // 3 + 1,
                                      dtype=torch.int32, device='mlu')

    cu_context_lens = torch.cumsum(context_lens, dim=-1)
    cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
    total_seqlen = cu_context_lens[-1]
    if packed > 0:
        # Varlen: all sequences concatenated, addressed via cu_context_lens.
        context = torch.randn((total_seqlen, total_heads, head_size),
                              dtype=torch.float, device='mlu')
        context_seq_offsets = None
    else:
        # Padded: per-batch rows with room for a random leading offset;
        # plain context_lens is passed instead of the prefix sums.
        context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size),
                              dtype=torch.float, device='mlu')
        cu_context_lens = context_lens
    context = context.to(dtype)
    key = context[..., q_heads:q_heads + num_heads, :]
    value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :]

    cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu')
    cache_scale = None
    if quant_mode:
        # Spread values over the int8 range before truncating.
        cache = (cache - 0.5) * 256
        cache = cache.to(torch.int8)
        if offline:
            # Offline: per (k/v, head, channel) scales.
            cache_scale = torch.randn((2, cache.shape[2], cache.shape[4]), dtype=torch.float, device='mlu')
        else:
            # Online: per (k/v, batch, head, position) scales.
            cache_scale = torch.randn((2, max_bs, num_heads, cache_memory_len), dtype=torch.float, device='mlu')
    else:
        cache = cache.to(dtype)

    block_size = 16 if "MLU3" not in torch.mlu.get_device_name() else max_context_len
    min_blocks = (total_seqlen + block_size - 1) // block_size
    num_blocks = min(min_blocks + 10, 2 * min_blocks)
    num_slots = num_blocks * block_size
    # One unique random slot per token; last entry forced invalid (-1).
    slot_mapping = random.sample(range(num_slots), total_seqlen.item())
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
    slot_mapping[-1] = -1

    return [key, value, cache[0], cache[1], cu_context_lens, max_context_len,
            packed > 0, context_seq_offsets, cache_bs_id, cache_seq_offsets,
            cache_scale, slot_mapping]
|
||||
|
||||
def fused_moe(hidden_states: torch.Tensor,
              gating_output: torch.Tensor,
              w1: torch.Tensor,
              w2: torch.Tensor,
              bias1: Optional[torch.Tensor],
              bias2: Optional[torch.Tensor],
              residual: Optional[torch.Tensor],
              input_smooth: Optional[torch.Tensor],
              act_smooth: Optional[torch.Tensor],
              w1_scale: Optional[torch.Tensor],
              w2_scale: Optional[torch.Tensor],
              topk: int,
              renormalized: bool,
              gated: bool,
              act_mode: str,
              start_expert_id: int = 0,
              block_n: int = 0,
              cncl_comm: int = 0,
              w1_quant_flag: Optional[List] = None,
              w2_quant_flag: Optional[List] = None):
    """Reference mixture-of-experts pipeline composed from individual tmo kernels.

    Pipeline: softmax top-k routing -> index generation -> (optional per-token
    smooth-quantization) -> grouped GEMM with w1 -> bias + activation ->
    (optional re-quantization) -> grouped GEMM with w2 -> combine expert
    outputs with routing weights, residual and bias2.

    Quantized mode is enabled only when input_smooth, act_smooth, w1_scale and
    w2_scale are ALL provided; providing only some of them raises ValueError.
    The output has the same shape as ``hidden_states``.

    NOTE(review): expected tensor shapes (e.g. w1 layout, scale layouts) are
    determined by the tmo kernels and are not visible here — confirm against
    the torch_mlu_ops operator docs before relying on them.
    """
    dtype = hidden_states.dtype
    ori_input_shape = hidden_states.shape
    # Flatten any leading batch/sequence dims into a single token dimension.
    hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
    tokens = hidden_states.size(0)
    gating_output = gating_output.reshape(-1, gating_output.size(-1))
    residual = residual.reshape(-1, residual.size(-1)) if residual is not None else None
    expert_num = gating_output.size(-1)
    # With quant flags, w1 is packed and the expert count comes from the scale tensor.
    expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1)

    per_token_sq = False
    # check quant
    check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
    if all(x is not None for x in check_list):
        per_token_sq = True

    if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
        raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present "
                         "and absent at the same time.")

    # softmax_topk
    reduce_weight, expert_id = tmo.moe_softmax_topk(gating_output, topk, renormalized)
    # gen_idx
    expand_idx, combine_idx, token_count, cusum_token_count = tmo.moe_gen_idx(expert_id, expert_num)

    if per_token_sq:
        # MLU370 has no fused expand+quantize path, so expand first, then quantize.
        if torch.mlu.get_device_name() == 'MLU370':
            expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx,
                cusum_token_count, start_expert_id, expert_size)
            quant_input, input_scale = tmo.moe_quantize(expand_hidden_states,
                input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size])
        else:
            # Newer devices quantize directly from the un-expanded input using expand_idx.
            quant_input, input_scale = tmo.moe_quantize(hidden_states,
                input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx,
                cusum_token_count[start_expert_id].unsqueeze(0))
    else:
        expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx,
            cusum_token_count, start_expert_id, expert_size)

    # group gemm
    if per_token_sq:
        gemm1_out = tmo.smooth_quant_group_gemm(quant_input,
            w1,
            token_count[start_expert_id:start_expert_id+expert_size],
            None, None, None, None,
            input_scale, w1_scale, dtype, tokens, quant_flag = w1_quant_flag)
    else:
        gemm1_out = tmo.group_gemm(expand_hidden_states,
            w1,
            token_count[start_expert_id:start_expert_id+expert_size],
            None,
            None,
            None,
            None, tokens)
    # add_bias_active
    act_out = tmo.moe_active(gemm1_out, act_mode, gated, None, bias1, cusum_token_count, start_expert_id, expert_size)

    # Re-quantize the activation output before the second grouped GEMM.
    if per_token_sq:
        quant_input, input_scale = tmo.moe_quantize(act_out, act_smooth, None,
            token_count[start_expert_id:start_expert_id+expert_size])

    if cncl_comm > 0:
        # Fused communication/computation is not implemented in this reference path.
        raise ValueError("not support communication and computing fusion currently.")
    else:
        if per_token_sq:
            gemm2_out = tmo.smooth_quant_group_gemm(quant_input,
                w2, token_count[start_expert_id:start_expert_id+expert_size],
                None, None, None, None, input_scale, w2_scale, dtype, tokens, quant_flag = w2_quant_flag)
        else:
            gemm2_out = tmo.group_gemm(act_out,
                w2,
                token_count[start_expert_id:start_expert_id+expert_size],
                None, None, None, None, tokens)

    # Scatter expert outputs back to token order, weight by routing scores,
    # and fold in residual / bias2.
    output = tmo.moe_combine_result(gemm2_out, reduce_weight, combine_idx,
        residual, cusum_token_count, start_expert_id,
        expert_size, bias2)
    return output.reshape(ori_input_shape)
|
||||
|
||||
def min_mem_size(shape, stride):
    """Return the minimal element count a buffer must hold for a view.

    For a contiguous tensor (``stride is None``) this is simply the element
    count of *shape*. For a strided view it is the linear offset of the last
    addressable element plus one: ``1 + sum((dim - 1) * step)``.
    """
    if stride is None:
        return shape.numel()
    return 1 + sum((dim - 1) * step for dim, step in zip(shape, stride))
|
||||
|
||||
def create_tensor(shape, dtype, is_contiguous, device, stride = None, mean=0, var=1, is_uniform=False, low=0, high=1):
    """Create a randomly initialized tensor for tests.

    Integer dtypes (int8/uint8) are filled from randint(-128, 127); float
    dtypes draw either uniformly from [low, high) or from a normal
    distribution (``var`` is passed into torch.normal's std slot — the name
    is historical).  When ``is_contiguous`` is False, a flat buffer of the
    minimal size is generated and viewed through ``as_strided(shape, stride)``.
    """
    def _random_fill(size):
        # Single place that issues RNG calls so both layouts share one code path.
        if dtype in (torch.int8, torch.uint8):
            return torch.randint(-128, 127, size, device=device).to(dtype)
        if is_uniform:
            return torch.empty(size, dtype=dtype, device=device).uniform_(low, high)
        return torch.normal(mean, var, size, dtype=dtype, device=device)

    if is_contiguous:
        return _random_fill(shape)
    flat = _random_fill((min_mem_size(shape, stride),))
    return flat.as_strided(shape, stride)
|
||||
|
||||
def create_tensor_from_dic(dic: dict, mean=0, var=1, is_uniform=False, low=0, high=1):
    """Materialize a random tensor from a serialized tensor-spec dict.

    ``dic`` carries 'data', 'shape', 'dtype', 'is_contiguous', 'device' and
    'stride' keys; a spec whose 'data' is None yields None.  The remaining
    keyword arguments are forwarded to create_tensor unchanged.
    """
    if dic['data'] is None:
        return None
    return create_tensor(dic['shape'], dic['dtype'], dic['is_contiguous'],
                         dic['device'], dic['stride'],
                         mean, var, is_uniform, low, high)
|
||||
|
||||
def create_op_param(dic: dict):
    """Rebuild one operator parameter from its serialized spec dict.

    Containers recurse; tensor specs are regenerated as random tensors except
    for integer tensors, whose recorded values are reused (they are typically
    indices/lengths that must stay exact).  Scalars and other plain values are
    returned as recorded.
    """
    if dic['type'] in (list, tuple):
        # 'has_compound' marks containers whose elements are themselves spec
        # dicts; otherwise the recorded container is already plain data.
        return [create_op_param(elem) for elem in dic['data']] if dic['has_compound'] else dic['data']
    elif dic['type'] is dict:
        return {k:create_op_param(v) for k,v in dic['data'].items()}
    elif dic['type'] is torch.Tensor:
        if dic['data'] is None:
            return None
        else:
            if dic['dtype'] in (torch.int16, torch.int32, torch.int64):
                # Integer tensors carry semantic values (indices, lengths):
                # reuse the recorded data instead of regenerating randomly.
                return dic['data']
            else:
                return create_tensor(dic['shape'], dic['dtype'], dic['is_contiguous'], dic['device'], dic['stride'])
    else:
        return dic['data']
|
||||
151
torch_mlu_ops-v1.3.2/tests/ops_pytest/run_gen_case.py
Normal file
151
torch_mlu_ops-v1.3.2/tests/ops_pytest/run_gen_case.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Disable case dumping before torch_mlu_ops is imported anywhere: this script
# *replays* previously generated cases and must not re-record them.
import os
os.environ['TMO_GEN_CASE'] = '0'

import sys
# Keep the user-supplied CLI arguments for our own argparse call in main(),
# and hand unittest an argv containing only the program name.
sys_args = sys.argv
sys.argv = [sys_args.pop(0)] # prevent unittest printing help info

import copy
import argparse
import torch
import torch_mlu
import importlib
import random
from common_utils import create_op_param
|
||||
|
||||
def assert_tensor_equal(a: torch.Tensor, b: torch.Tensor, threshold) -> None:
    """Assert that two tensors are numerically close.

    Both tensors are flattened and compared in double precision with two
    relative error metrics:

    * diff1 — L1 error relative to ``a``'s L1 norm,
    * diff2 — RMS error relative to ``a``'s L2 norm.

    NaNs are allowed but must occupy exactly the same positions in both
    tensors; matching NaNs are zeroed out before the metrics are computed.

    Raises:
        AssertionError: on shape mismatch, NaN-position mismatch, or when
            either metric exceeds ``threshold``.
    """
    assert a.size() == b.size()
    a_ = a.cpu().reshape(-1).double()
    b_ = b.cpu().reshape(-1).double()
    nan_mask = torch.isnan(a_)
    # The check is positional (element-for-element), not a count comparison —
    # the previous message ("different number of nan") was misleading.
    assert torch.equal(nan_mask, torch.isnan(b_)), "tensor a and tensor b have nan at different positions"
    diff = a_ - b_
    # Zero out matching NaNs so they do not poison the error sums.
    diff[nan_mask] = 0
    a_[nan_mask] = 0
    b_[nan_mask] = 0
    eps = 1e-10  # guards against division by zero for an all-zero reference
    diff1 = diff.abs().sum() / (a_.abs().sum() + eps)
    diff2 = torch.sqrt((diff**2).sum() / ((a_**2).sum() + eps))
    print(f"[torch_mlu_ops] diff1: {diff1}, diff2: {diff2}")
    assert diff1 <= threshold, f"diff1: {diff1} <= threshold: {threshold}"
    assert diff2 <= threshold, f"diff2: {diff2} <= threshold: {threshold}"
|
||||
|
||||
def check_value_equal(x, y, threshold):
    """Recursively compare two values of the same exact type.

    Tensors are compared numerically via assert_tensor_equal, sequences
    element-wise, and any other non-None value with plain equality.  Two
    ``None`` values compare equal by the type check alone.
    """
    assert type(x) == type(y)
    if type(x) is torch.Tensor:
        assert_tensor_equal(x, y, threshold)
    elif type(x) in (list, tuple):
        for idx, item in enumerate(x):
            check_value_equal(item, y[idx], threshold)
    elif x is not None:
        assert x == y
|
||||
|
||||
def check_equal(a, b, threshold):
    """Compare two op results (possibly multi-output tuples/lists).

    Sequences must have equal length and are compared element-wise;
    everything else is delegated directly to check_value_equal.
    """
    assert type(a) == type(b)
    if type(a) not in (tuple, list):
        check_value_equal(a, b, threshold)
        return
    assert len(a) == len(b)
    for lhs, rhs in zip(a, b):
        check_value_equal(lhs, rhs, threshold)
|
||||
|
||||
def get_base_obj(module_name, class_name):
    """Import *module_name* and return a fresh instance of *class_name*.

    On a failed import or a missing attribute a diagnostic is printed and
    None is returned instead of raising.
    """
    try:
        return getattr(importlib.import_module(module_name), class_name)()
    except ImportError as err:
        print(f"Failed to import class '{class_name}' from module '{module_name}': {err}")
    except AttributeError as err:
        print(f"Module '{module_name}' does not have a class named '{class_name}': {err}")
    return None
|
||||
|
||||
def get_tmo_func(func_name):
    """Resolve an operator entry point on the torch_mlu_ops package by name.

    Imported lazily so the module is only required when a case is replayed.
    """
    import torch_mlu_ops
    return getattr(torch_mlu_ops, func_name)
|
||||
|
||||
|
||||
# Registry of replayable operators:
#   op name -> [test module name, test class name, comparison threshold
#               passed to check_equal].
# Commented-out entries are ops whose replay is currently disabled.
op_map = {
    "active": ["test_active", "TestActive", 0.004],
    "apply_rotary": ["test_apply_rotary", "TestApplyRotaryOp", 0.003],
    "attention_project": ["test_attn_proj", "TestAttnProjOp", 0.003],
    "batch_matmul": ["test_batch_matmul", "TestBatchMatMulOp", 0.004],
    "copy_blocks": ["test_copy_blocks", "TestCopyBlocksOp", 0],
    "dequant_from_linear_cache": ["test_dequant_from_linear_cache", "TestDequantFromLinearCache", 0.001],
    "dequant_from_paged_cache": ["test_dequant_from_paged_cache", "TestDequantFromPagedCache", 0.001],
    "ffn": ["test_ffn", "TestFFNOp", 0.005],
    "flash_attention": ["test_flash_attention", "TestFlashAttnOp", 0.005],
    # "flash_attn_sq_mm_allreduce": ["test_flash_attn_sq_mm_allreduce", "TestFlashAttnSqMMAllreduce", 0.003],
    "fused_layer_norm": ["test_fused_layernorm", "TestFuseLayerNormOp", 0.003],
    "fused_moe": ["test_moe", "TestFusedMOEOp", 0.006],
    "fused_norm_attention_project": ["test_fused_attn_proj", "TestFusedNormAttnProjOp", 0.003],
    "fused_norm_residual_ffn": ["test_fused_ffn", "TestFusedNormResidualFFNoP", 0.003],
    "fused_rms_norm": ["test_fused_rmsnorm", "TestFuseRmsNormOp", 0.0032],
    "fused_rope": ["test_fused_rope", "TestFusedRopeOp", 0.003],
    "group_gemm": ["test_group_gemm", "TestGroupGemmOp", 0.006],
    "matmul": ["test_matmul", "TestMatMulOp", 0.003],
    # "matmul_allreduce": ["test_matmul_all_reduce", "TestMatMulAllReduceOp", 0.006],
    "moe_active": ["test_moe_add_bias_activation", "TestMoeActiveKernel", 0.003],
    "moe_cast_gating": ["test_moe_cast_gating", "TestMoeCastGating", 0.0001],
    "moe_combine_result": ["test_moe_combine_result", "TestCombineResult", 0.003],
    "moe_expand_input": ["test_moe_expand_input", "TestExpandInput", 0],
    "moe_gen_idx": ["test_moe_gen_idx", "TestGenIdx", 0],
    "moe_quantize": ["test_smooth_quant", "TestSmoothQuantOp", 0.01],
    "moe_softmax_topk": ["test_moe_softmax_topk", "TestSoftmaxTopkOp", 0.003],
    "offline_quant_to_linear_cache": ["test_offline_quant_to_linear_cache", "TestOfflineQuantToLinearCache", 0.03],
    "offline_quant_to_paged_cache": ["test_offline_quant_to_paged_cache", "TestOfflineQuantToPagedCache", 0.03],
    "per_token_smooth_quantize": ["test_per_token_smooth_quantize", "TestPerTokenSmoothQuantizeOp", 0.003],
    # "preload": ["test_preload", "TestPreloadOp", 1],
    "quant_to_linear_cache": ["test_quant_to_linear_cache", "TestQuantToLinearCache", 0.003],
    "quant_to_paged_cache": ["test_quant_to_paged_cache", "TestQuantToPagedCache", 0.009],
    "quantize": ["test_quantize", "TestQuantizeOp", 0.003],
    "reshape_linear_cache": ["test_reshape_linear_cache", "TestReshapeLinearCache", 0],
    "reshape_paged_cache": ["test_reshape_paged_cache", "TestReshapePagedCacheOp", 0],
    "single_query_cached_kv_attn": ["test_single_query_cached_kv_attn", "TestSingleQueryAttnOp", 0.003],
    "single_query_mixed_cached_kv_attn": ["test_single_query_mixed_cached_kv_attn", "TestSingleQueryMixedKVAttnOp", 0.003],
    "smooth_quant_group_gemm": ["test_smooth_quant_group_gemm", "TestSmoothQuantGroupGemmOp", 0.006],
    "smooth_quant_matmul": ["test_smooth_quant_matmul", "TestSmoothQuantMatmulOp", 0.006],
    # "smooth_quant_matmul_allreduce": ["test_quant_matmul_all_reduce", "TestGptQuantMatmulOp", 0.006],
    "swap_blocks": ["test_swap_blocks", "TestSwapBlocksOp", 0],
    "update_out_and_lse": ["test_update_out_and_lse", "TestUpdateOutAndLse", 0.005],
    "weight_only_quant_matmul": ["test_weight_only_quant_matmul", "TestWeightOnlyQuantMatmulOp", 0.004],
}
|
||||
|
||||
|
||||
def run_case(pt_case):
    """Replay one recorded case dict (mutates ``pt_case`` by popping keys).

    The 'op' key selects the registered test via op_map.  If the test object
    defines its own ``run_gen_case`` hook, replay is delegated to it entirely;
    otherwise the parameters are rebuilt (from specs or recorded data), the
    tmo operator and the pure-torch reference are both run, and the results
    are compared within the op's registered threshold.
    """
    op_name = pt_case.pop('op')
    op_obj = get_base_obj(op_map[op_name][0], op_map[op_name][1])
    # Tests needing custom replay logic expose a run_gen_case hook.
    if hasattr(op_obj, "run_gen_case"):
        op_obj.run_gen_case(pt_case)
    else:
        dump_data = pt_case.pop('dump_data')
        if dump_data:
            # The case was recorded with real tensor data: use it as-is.
            params = pt_case
        else:
            # Only tensor specs were recorded: regenerate random inputs.
            params = dict()
            for k,v in pt_case.items():
                params[k] = create_op_param(v)
        # Deep copy so in-place ops cannot let the two runs see different inputs.
        params_bak = copy.deepcopy(params)
        result_tmo = get_tmo_func(op_name)(**params)
        result_base = op_obj.op_impl_base(*params_bak.values())
        check_equal(result_tmo, result_base, op_map[op_name][-1])
|
||||
|
||||
def main():
    """CLI entry point: load a recorded .pt case file and replay it.

    ``--case_path`` selects the case file; ``--detail`` only prints the
    case's content and exits without replaying.  Seeds are fixed so
    regenerated random inputs are reproducible across runs.
    """
    random.seed(0)
    torch.manual_seed(0)
    parser = argparse.ArgumentParser()
    parser.add_argument('--case_path', required=True, type=str, help='specify the case path')
    parser.add_argument('--detail', action="store_true", help="show content of pt file")
    # sys_args holds the original CLI args saved before unittest saw argv.
    args = parser.parse_args(args=sys_args)
    # NOTE(review): torch.load unpickles arbitrary objects — only replay case
    # files from trusted sources.
    pt_case = torch.load(args.case_path)
    if args.detail:
        for k,v in pt_case.items():
            print(f"{k}: {v}")
        exit(0)
    print(f"[torch_mlu_ops] run {args.case_path} ...")
    run_case(pt_case)
    print(f"[torch_mlu_ops] run {args.case_path} successfully")


if __name__ == "__main__":
    main()
|
||||
22
torch_mlu_ops-v1.3.2/tests/ops_pytest/run_test.sh
Executable file
22
torch_mlu_ops-v1.3.2/tests/ops_pytest/run_test.sh
Executable file
@@ -0,0 +1,22 @@
|
||||
#!/bin/bash
# Run every ops_pytest test script in this directory.
# Usage: run_test.sh [coverage]
#   With "coverage", scripts run under `coverage run -a` (accumulating one
#   combined report); otherwise each script's output goes to /tmp/<name>.log.

SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
tmo_ops_case=$(find "${SCRIPT_DIR}" -name "test*.py")
coverage=${1}

for sc in ${tmo_ops_case}
do
    echo -n "${sc} "
    echo -n "Testing...."
    if [ "${coverage}" = "coverage" ];then
        # -a appends so results from all scripts accumulate into one report.
        coverage run -a "${sc}"
    else
        python3 "${sc}" > "/tmp/$(basename "${sc}").log" 2>&1
    fi
    # Capture the status immediately so later commands cannot clobber it,
    # and use the portable -eq test instead of the non-POSIX `==`.
    status=$?
    if [ "${status}" -eq 0 ];then
        echo -e "\033[32m success \033[0m"
    else
        echo -e "\033[31m failed \033[0m"
    fi
done
echo "End of pytest..."
|
||||
88
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_active.py
Normal file
88
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_active.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from common_utils import *
|
||||
import random
|
||||
from itertools import product
|
||||
import time
|
||||
import os
|
||||
|
||||
class TestActive(BtTestCase):
    """Tests the tmo `active` activation op against a pure-torch reference.

    NOTE(review): this file imports `torch_mlu_ops as ops` yet calls
    `tmo.active` below — `tmo` presumably arrives via
    `from common_utils import *`; confirm it is exported there.
    """

    def run_gen_case(self, dic):
        # Replay hook used by run_gen_case.py: either the case carries real
        # tensor data (launch directly) or only specs (rebuild the input).
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            input = create_tensor_from_dic(dic['input'])
            act_mode = dic['act_mode']['data']
            is_gated = dic['is_gated']['data']
            active_coef = dic['active_coef']['data']
            self.launch(input, act_mode, is_gated, active_coef)

    def launch(self, *args):
        # Run reference and device op on the same args and compare.
        torch_out = self.op_impl_base(*args)
        tmo_out = tmo.active(*args)
        self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                0.004, use_MSE=True, use_RAE=True)

    def op_impl_base(self, *args):
        """Pure-torch reference: gelu or swish-family activation, optionally
        gated (first half activated, multiplied by the second half)."""
        input, act_mode, is_gated, active_coef = args
        channel = input.size(-1)
        if act_mode == "gelu":
            if is_gated:
                out = torch.nn.functional.gelu(input[..., :channel//2])
                out *= input[..., channel//2:]
            else:
                out = torch.nn.functional.gelu(input)
        else:
            # silu and quick_gelu are swish with fixed coefficients; any
            # other mode ("swish") uses the caller-supplied coefficient.
            if act_mode == "silu":
                active_coef = 1.0
            elif act_mode == "quick_gelu":
                active_coef = 1.702

            def swish(input, coef):
                return input * torch.sigmoid(coef * input)

            if is_gated:
                out = swish(input[..., :channel//2], active_coef) * input[..., channel//2:]
            else:
                out = swish(input, active_coef)
        return out

    def test_active_random(self):
        # Fuzz over random shapes, dtypes, gating and activation modes.
        for _ in range(500):
            dtype_list = [torch.float, torch.half]
            if torch_mlu.mlu.is_bf16_supported():
                dtype_list.append(torch.bfloat16)
            input_dtype = random.choice(dtype_list)
            batch = random.randint(1, 10)
            seq = random.randint(1, 2048)
            # Even hidden size so the gated path can split it in half.
            hidden_size = random.randrange(2, 8192, 2)
            is_gated = random.choice([True, False])
            act_mode = random.choice(['gelu', 'silu', 'quick_gelu', 'swish'])
            if act_mode == 'silu':
                active_coef = 1.0
            elif act_mode == 'quick_gelu':
                active_coef = 1.702
            else:
                active_coef = random.uniform(0, 1)
            print("input_shape: {}, is_gated: {}, act_mode: {}, dtype: {} testing...".format( \
                [batch, seq, hidden_size], is_gated, act_mode, input_dtype), flush=True)
            input = torch.randn(batch, seq, hidden_size, dtype=input_dtype, device="mlu")
            self.launch(input, act_mode, is_gated, active_coef)

    def test_inductor(self):
        # Schema/opcheck validation of the raw custom op for inductor support.
        input = torch.randn(3, 4, 12, dtype=torch.half, device="mlu")
        output = torch.empty(3, 4, 12, dtype=torch.half, device="mlu")
        is_gated = True
        act_mode = 'silu'
        coef = 1.0
        args = (input, output, None, None, act_mode, is_gated, 0, 0, coef)
        self.base_opcheck(torch.ops.torch_mlu_ops.active, args)
|
||||
152
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_apply_rotary.py
Executable file
152
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_apply_rotary.py
Executable file
@@ -0,0 +1,152 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
|
||||
def gen_args(bs, seq_len, q_heads, kv_heads, head_size, rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype):
    """Generate random MLU inputs for the apply_rotary op.

    ``packed`` selects a varlen layout (tokens concatenated, with cumulative
    sequence offsets) versus a padded (bs, seq, heads, head) layout.
    ``discrete`` yields one position id per token; otherwise one starting
    position per batch.  ``dynamic_ntk`` adds a per-batch leading dim to the
    sin/cos caches.  Returns the argument tuple expected by ops.apply_rotary.
    """
    cu_context_lens = None
    total_seq_len = bs * seq_len
    max_context_len = seq_len
    if packed:
        # Random per-batch lengths; prefix-sum (with a leading 0) gives the
        # token offsets of each sequence in the packed layout.
        context_lens = torch.randint(size=(bs, ), low=1, high=seq_len+1, dtype=torch.int32, device='mlu')
        total_seq_len = context_lens.sum().item()
        max_context_len = context_lens.max().item()
        cu_context_lens = torch.cumsum(context_lens, dim=-1, dtype=torch.int32)
        cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0)

    # One extra head is allocated and then sliced off so qk is a strided
    # (non-contiguous) view, exercising that code path.
    context_shape = (total_seq_len, q_heads + kv_heads + 1, head_size) if packed else \
        (bs, seq_len, q_heads + kv_heads + 1, head_size)
    context = torch.randn(size=context_shape, dtype=dtype).mlu()
    qk = context[..., 0 : q_heads + kv_heads, :]

    position_id = None
    if discrete:
        # One position per token.
        position_id = torch.randint(0, max_context_len, size=(total_seq_len,), dtype=torch.int32, device="mlu")
    else:
        # One starting position per batch element.
        position_id = torch.randint(0, max_context_len, size=(bs,), dtype=torch.int32, device="mlu")

    # Cache longer than any sequence so position offsets stay in range.
    rope_seqlen = seq_len * 2
    cos_shape = (bs, rope_seqlen, rope_dim) if dynamic_ntk else (rope_seqlen, rope_dim)
    cos_cache = torch.randn(size=cos_shape, dtype=dtype, device="mlu")
    sin_cache = torch.randn(size=cos_shape, dtype=dtype, device="mlu")

    return (qk, sin_cache, cos_cache, position_id, cu_context_lens, \
        interleaved, discrete, dynamic_ntk, max_context_len)
|
||||
|
||||
class TestApplyRotaryOp(BtTestCase):
    """Tests ops.apply_rotary (rotary position embedding) against a
    pure-torch in-place reference over packed/padded, discrete/continuous,
    interleaved/half-split and dynamic-NTK layout combinations."""

    def op_impl_base(self, *args):
        """In-place reference RoPE: for each sequence, select the sin/cos
        rows addressed by its positions and apply
        rotate(x)*sin + x*cos to the first rope_dim channels of every head.
        Mutates and returns ``input``."""
        def rotate(x: torch.Tensor, interleaved: bool):
            # Half-split rotation: (x1, x2) -> (-x2, x1);
            # interleaved rotation does the same on even/odd channel pairs.
            if not interleaved:
                x1, x2 = x.chunk(2, dim=-1)
                return torch.cat((-x2, x1), dim=-1)
            else:
                y = torch.empty_like(x)
                x1, x2 = x[..., ::2], x[..., 1::2]
                y[..., ::2], y[..., 1::2] = -x2, x1
                return y
        input, sin_cache, cos_cache, position_ids, cu_seqlen, interleaved, discrete, dynamic_ntk, max_seqlen = args
        # 3-D input means packed (tokens, heads, head_size) varlen layout.
        packed = input.dim() == 3
        rope_dim = sin_cache.shape[-1]
        batch_size = cu_seqlen.shape[0] - 1 if packed else input.shape[0]
        # Compute in float32 for accuracy, cast back at the end.
        sin_cache_float = sin_cache.float()
        cos_cache_float = cos_cache.float()
        cu_seqlen_cpu = cu_seqlen.cpu() if cu_seqlen is not None else None
        position_ids_cpu = position_ids.cpu() if position_ids is not None else None
        for i in range(batch_size):
            input_i = input[cu_seqlen_cpu[i] : cu_seqlen_cpu[i + 1]] if packed else input[i]
            # Only the first rope_dim channels are rotated; the rest pass through.
            input_i = input_i[..., 0:rope_dim]
            # dynamic_ntk caches carry a per-batch leading dimension.
            sin_cache_i = sin_cache_float[i] if dynamic_ntk else sin_cache_float
            cos_cache_i = cos_cache_float[i] if dynamic_ntk else cos_cache_float
            seq = input_i.shape[0]
            if discrete:
                # Per-token positions index cache rows directly.
                if packed:
                    position_id_i = position_ids_cpu[cu_seqlen_cpu[i] : cu_seqlen_cpu[i + 1]]
                else:
                    position_id_i = position_ids_cpu.view(batch_size, -1)[i]
                sin_cache_i = sin_cache_i[position_id_i]
                cos_cache_i = cos_cache_i[position_id_i]
            else:
                # Continuous positions: either [0, seq) or a per-batch offset window.
                if position_ids_cpu is None:
                    sin_cache_i = sin_cache_i[:seq]
                    cos_cache_i = cos_cache_i[:seq]
                else:
                    pos_id = position_ids_cpu[i].item()
                    sin_cache_i = sin_cache_i[pos_id : seq + pos_id]
                    cos_cache_i = cos_cache_i[pos_id : seq + pos_id]
            rot = rotate(input_i.float(), interleaved)
            # unsqueeze(1) broadcasts the (seq, rope_dim) caches over heads.
            output_i = rot * sin_cache_i.unsqueeze(1) + input_i * cos_cache_i.unsqueeze(1)
            # Write back in place (input_i is a view into input).
            input_i[:] = output_i.to(input.dtype)
        return input

    def test_apply_rotary(self):
        # Sweep the full cross product of layout/dtype parameters.
        bs_list = [1, 8]
        seq_len_list = [1, 128, 1024]
        q_heads_list = [8, 32]
        kv_heads_list = [1]
        head_size_list = [128, 256]
        rope_dim_list = [256, 128, 64, 24]
        is_interleaved_list = [True, False]
        discrete_list = [True, False]
        dynamic_ntk_list = [True, False]
        packed_list = [True, False]
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        args = product(bs_list, seq_len_list, q_heads_list, kv_heads_list, head_size_list, \
            rope_dim_list, is_interleaved_list, discrete_list, dynamic_ntk_list, packed_list, dtype_list)
        for bs, seq_len, q_heads, kv_heads, head_size, rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype in args:
            print("bs: {}, seq_len: {}, q_heads: {}, kv_heads: {}, head_size: {}, "
                  "rope_dim: {}, interleaved: {}, discrete: {}, dynamic_ntk: {}, packed: {}, dtype: {} testing...".format( \
                  bs, seq_len, q_heads, kv_heads, head_size, rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype), flush=True)
            # Invalid combination: the rotated slice cannot exceed the head.
            if rope_dim > head_size:
                print("rope_dim = {}, head_size = {},rope_dim should less than head_size".format(rope_dim, head_size))
                continue

            qk, sin_cache, cos_cache, position_id, cu_context_lens, \
                interleaved, _, _, max_context_len = gen_args(bs, seq_len, q_heads, kv_heads, head_size,
                rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype)

            # Clone: both implementations rotate in place.
            qk_base = qk.clone()
            self.op_impl_base(qk_base, sin_cache, cos_cache, position_id, cu_context_lens, \
                interleaved, discrete, dynamic_ntk, max_context_len)

            qk_out = ops.apply_rotary(qk, sin_cache, cos_cache, position_id, cu_context_lens, \
                interleaved, discrete, dynamic_ntk, max_context_len)

            self.assertTensorsEqual(qk_out.cpu().float(), qk_base.cpu().float(), 0.003, use_MSE=True, use_RAE=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        # Negative tests: the op must reject invalid argument combinations
        # with the expected error messages.
        bs, seq, head_num, head_size, dynamic_ntk, rotary_seqlen, rotary_dim, discrete = 1, 1024, 8, 128, False, 512, 128, True

        q = torch.randn(bs, seq, head_num, head_size, dtype=torch.half, device="mlu")
        sin = torch.randn(rotary_seqlen, rotary_dim, dtype=torch.half, device="mlu")
        cos = torch.randn(rotary_seqlen, rotary_dim, dtype=torch.half, device="mlu")
        # discrete positions require explicit position ids.
        self.assertException("discrete must be false if position ids is null.", ops.apply_rotary,
                             q, sin, cos, None, None, False, discrete, dynamic_ntk, seq)

        discrete = False
        # seq (1024) exceeds the rope cache length (512) -> must be rejected.
        self.assertException("max_seqlen must less than or equal to rope_seqlen.", ops.apply_rotary,
                             q, sin, cos, None, None, False, discrete, dynamic_ntk, seq)
        position_ids = torch.zeros(bs, dtype=torch.int32, device="mlu")
        self.assertException("max_seqlen must less than or equal to rope_seqlen.", ops.apply_rotary,
                             q, sin, cos, position_ids, None, False, discrete, dynamic_ntk, seq)

    def test_inductor(self):
        # Schema/opcheck validation of the raw custom op for inductor support.
        is_interleaved_list = [True, False]
        discrete_list = [True, False]
        dynamic_ntk_list = [True, False]
        packed_list = [True, False]
        params = product(is_interleaved_list, discrete_list, dynamic_ntk_list, packed_list)
        bs, seq_len, q_heads, kv_heads, head_size, rope_dim, dtype = 8, 1024, 8, 1, 256, 24, torch.float16
        for interleaved, discrete, dynamic_ntk, packed in params:
            print(f"==== check apply_rotary interleaved: {interleaved}, discrete: {discrete}, dynamic_ntk: {dynamic_ntk}, packed: {packed} ====")
            args = gen_args(bs, seq_len, q_heads, kv_heads, head_size,
                            rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype)
            self.base_opcheck(torch.ops.torch_mlu_ops.apply_rotary, args)
|
||||
44
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_attn_proj.py
Executable file
44
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_attn_proj.py
Executable file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class TestAttnProjOp(BtTestCase):
    """Tests ops.attention_project against a torch linear+residual reference."""

    def op_impl_base(self, *args):
        """Pure-torch reference: alpha * linear(input, weight, bias) + beta * residual."""
        input, weight, bias, residual, alpha, beta = args
        proj = F.linear(input, weight, bias)
        output = alpha * proj + beta * residual
        return output

    def test_attn_proj(self):
        # Fixed QKV-style projection shape (hidden_size * 3) over all dtypes.
        N, T, input_size, hidden_size, alpha, beta = 32, 129, 2048, 4096, 0.5, 0.1
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for dtype in dtype_list:
            print("N: {}, T: {}, input_size: {}, hidden_size: {}, testing...".format(
                N, T, input_size, hidden_size), flush=True)
            input = torch.randn(N, T, input_size, dtype=dtype, device="mlu")
            weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu")
            bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu")
            residual = torch.randn(N, T, hidden_size * 3, dtype=dtype, device="mlu")
            torch_out = self.op_impl_base(input, weight, bias, residual, alpha, beta)
            tmo_out = ops.attention_project(input, weight, bias, residual, alpha, beta)
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        # Schema/opcheck validation of the raw custom op for inductor support.
        N, T, input_size, hidden_size, dtype = 32, 129, 2048, 4096, torch.half
        input = torch.randn(N, T, input_size, dtype=dtype, device="mlu")
        weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu")
        bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu")
        args = (input, weight, bias, None, None, None,
                None, None, None, None, "nthc", 1,
                1e-5, 1., 0., False)
        self.base_opcheck(torch.ops.torch_mlu_ops.attention_project, args)
|
||||
83
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_batch_matmul.py
Normal file
83
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_batch_matmul.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
import numpy as np
|
||||
|
||||
class TestBatchMatMulOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
|
||||
a, b, c, alpha, beta, a_scale, b_scale, trans_a, trans_b = args
|
||||
if trans_a:
|
||||
a = a.transpose(1, 2)
|
||||
if trans_b:
|
||||
b = b.transpose(1, 2)
|
||||
if c is None:
|
||||
output_dtype = a.dtype
|
||||
else:
|
||||
output_dtype = c.dtype
|
||||
if a_scale is not None:
|
||||
a = torch.div(a, a_scale).to(output_dtype)
|
||||
b = torch.div(b, b_scale).to(output_dtype)
|
||||
output = alpha * torch.bmm(a, b)
|
||||
if c is not None:
|
||||
output += beta * c
|
||||
c.copy_(output)
|
||||
return output
|
||||
|
||||
def test_batch_matmul(self):
|
||||
batch_list = [5]
|
||||
mat_m_list = [32]
|
||||
mat_n_list = [256]
|
||||
mat_k_list = [128]
|
||||
has_res_list = [False, True]
|
||||
trans_a_list = [False, True]
|
||||
trans_b_list = [False, True]
|
||||
|
||||
dtype_list = [torch.half, torch.float]
|
||||
if torch_mlu.mlu.is_bf16_supported():
|
||||
dtype_list.append(torch.bfloat16)
|
||||
alpha = 0.625
|
||||
beta = 1.0
|
||||
args = product(batch_list, mat_m_list, mat_n_list, mat_k_list, has_res_list, dtype_list, trans_a_list, trans_b_list)
|
||||
|
||||
for batch, mat_m, mat_n, mat_k, has_res, dtype, trans_a, trans_b in args:
|
||||
print("batch={}, m={}, n={}, k={}, has_res={}, dtype={}, trans_a={}, trans_b={} testing...".format(
|
||||
batch, mat_m, mat_n, mat_k, has_res, dtype, trans_a, trans_b), flush=True)
|
||||
shape_a, shape_b = (batch, mat_m, mat_k), (batch, mat_k, mat_n)
|
||||
if trans_a:
|
||||
shape_a = (batch, mat_k, mat_m)
|
||||
if trans_b:
|
||||
shape_b = (batch, mat_n, mat_k)
|
||||
input = torch.randn(shape_a, dtype=dtype, device='mlu')
|
||||
weight = torch.randn(shape_b, dtype=dtype, device='mlu')
|
||||
input8, a_scale = QuantByTensor(input, 8)
|
||||
weight8, b_scale = QuantByTensor(weight, 8)
|
||||
residual = torch.randn((batch, mat_m, mat_n), dtype=dtype, device='mlu')
|
||||
res_bak = residual.clone()
|
||||
tmo_output_int8 = torch.zeros((batch, mat_m, mat_n), dtype=dtype, device='mlu')
|
||||
torch_output_int8 = tmo_output_int8.clone()
|
||||
output = self.op_impl_base(input, weight,
|
||||
residual if has_res else None, alpha, beta, 1.0, 1.0, trans_a, trans_b)
|
||||
tmo_output = ops.batch_matmul(input, weight,
|
||||
res_bak if has_res else None, alpha, beta, 1.0, 1.0, trans_a, trans_b)
|
||||
|
||||
self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(),
|
||||
0.004, use_MSE=True, use_RAE=True)
|
||||
if dtype != torch.bfloat16:
|
||||
self.op_impl_base(input8, weight8, torch_output_int8, alpha, beta, a_scale.item(), b_scale.item(), trans_a, trans_b)
|
||||
ops.batch_matmul(input8, weight8, tmo_output_int8, alpha, beta, a_scale.item(), b_scale.item(), trans_a, trans_b)
|
||||
self.assertTensorsEqual(tmo_output_int8.cpu().float(), torch_output_int8.cpu().float(),
|
||||
0.004, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_inductor(self):
    """Run the torch.compile/inductor opcheck on the registered batch_matmul op.

    Uses one fixed fp16 configuration; correctness of the math is covered by
    the main test, this only validates the custom-op registration.
    """
    num_batches = 6
    rows, cols, inner = 64, 256, 128
    alpha, beta = 0.8, 0.3
    dtype = torch.float16
    lhs = torch.randn((num_batches, rows, inner), dtype=dtype, device='mlu')
    # rhs is laid out (batch, n, k), hence trans_b=True below.
    rhs = torch.randn((num_batches, cols, inner), dtype=dtype, device='mlu')
    residual = torch.randn((num_batches, rows, cols), dtype=dtype, device='mlu')
    op_args = (lhs, rhs, residual, alpha, beta, 1.0, 1.0, False, True)
    self.base_opcheck(torch.ops.torch_mlu_ops.batch_matmul, op_args)
||||
# Script entry point: run this file's test class and propagate the
# unittest result as the process exit status (run_unittest is from common_utils).
if __name__ == '__main__':
    exit(run_unittest(TestBatchMatMulOp))
167
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_copy_blocks.py
Executable file
167
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_copy_blocks.py
Executable file
@@ -0,0 +1,167 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
from typing import List, Tuple
|
||||
import os
|
||||
import copy
|
||||
|
||||
class TestCopyBlocksOp(BtTestCase):
    """Tests for ops.copy_blocks: copying KV-cache blocks from source block
    indices to (here, two) destination block indices per mapping entry,
    compared against a pure-PyTorch reference on CPU clones."""

    def create_kv_caches(
        self,
        num_blocks: int,
        block_size: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        seed: int
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Create per-layer key/value cache tensors on the MLU device.

        Integer dtypes are filled with random values over the full dtype range;
        float dtypes are filled uniformly in [-head_size**-0.5, head_size**-0.5].
        Returns (key_caches, value_caches), each a list of num_layers tensors.
        """
        torch.random.manual_seed(seed)
        # NOTE(review): this seeds CUDA, not the MLU device — presumably a
        # leftover from the vLLM original; confirm whether MLU seeding is needed.
        torch.cuda.manual_seed(seed)

        scale = head_size**-0.5
        # vllm scale
        # x = 16 // torch.tensor([], dtype=dtype).element_size()
        # key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
        key_cache_shape = (num_blocks, num_heads, block_size, head_size)
        key_caches = []
        for _ in range(num_layers):
            if dtype in {torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}:
                info = torch.iinfo(dtype)
                key_cache = torch.randint(info.min, info.max, size=key_cache_shape, dtype=dtype).mlu()
            else:
                key_cache = torch.empty(size=key_cache_shape, dtype=dtype).mlu()
                key_cache.uniform_(-scale, scale)
            key_caches.append(key_cache)

        # Value caches use the transposed per-block layout (head_size, block_size).
        value_cache_shape = (num_blocks, num_heads, head_size, block_size)
        value_caches = []
        for _ in range(num_layers):
            if dtype in {torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}:
                info = torch.iinfo(dtype)
                value_cache = torch.randint(info.min, info.max, size=value_cache_shape, dtype=dtype).mlu()
            else:
                value_cache = torch.empty(size=value_cache_shape, dtype=dtype).mlu()
                value_cache.uniform_(-scale, scale)
            value_caches.append(value_cache)

        return key_caches, value_caches

    def create_block_mapping(self, num_blocks, num_mappings, seed = 0):
        """Build a {src_block: [dst_block_1, dst_block_2]} mapping with all
        source and destination blocks pairwise distinct (hence the
        3 * num_mappings <= num_blocks precondition)."""
        random.seed(seed)
        torch.random.manual_seed(seed)
        assert 3 * num_mappings <= num_blocks
        block_mapping = {}
        src_blocks = random.sample(range(num_blocks), num_mappings)
        remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
        for i in range(num_mappings):
            src = src_blocks[i]
            dst1 = dst_blocks[2 * i]
            dst2 = dst_blocks[2 * i + 1]
            block_mapping[src] = [dst1, dst2]
        return block_mapping

    def op_impl_base(self, *args):
        """Reference implementation: index-copy each source block to all of its
        destinations, in every layer's key cache (and value cache when given).
        Mutates the caches in place; v_caches may be None (key-only mode)."""
        k_caches, v_caches, block_mapping = args
        for src, dsts in block_mapping.items():
            srcs = [src for i in range(len(dsts))]
            srcs_ind = torch.tensor(srcs, dtype=torch.int64)
            dsts_ind = torch.tensor(dsts, dtype=torch.int64)
            for key_cache in k_caches:
                key_cache[dsts_ind] = key_cache[srcs_ind]
            if v_caches is not None:
                for value_cache in v_caches:
                    value_cache[dsts_ind] = value_cache[srcs_ind]
        return (k_caches, v_caches) if v_caches is not None else k_caches

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_copy_blocks due to ASan issues")
    def test_copy_blocks(self):
        """Sweep dtype/shape configurations; compare the MLU kernel's in-place
        result with the CPU reference, requiring exact (tolerance 0) equality."""
        dtype_list = [torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        num_tokens_list = [83]
        num_heads_list = [8]
        head_size_list = [64, 512]
        num_blocks_list = [3600]
        block_size_list = [8]
        num_layers_list = [1, 6]
        num_mappings_list = [128, 600]
        seeds_list = [0]
        only_key_cache_list = [True, False]

        args = product(num_tokens_list, num_heads_list, head_size_list, num_blocks_list, block_size_list, dtype_list,
                       num_layers_list, num_mappings_list, seeds_list)
        for num_tokens, num_heads, head_size, num_blocks, block_size, dtype, num_layers, num_mappings, seed in args:
            print("num_tokens: {}, num_heads: {}, head_size: {}, num_blocks: {}, block_size: {}, dtype: {}, num_layers: {}, \
num_mappings: {}, seed: {}, testing...".format(
                num_tokens, num_heads, head_size, num_blocks, block_size, dtype, num_layers, num_mappings, seed), flush=True)
            block_mapping = self.create_block_mapping(num_blocks, num_mappings, seed)
            # Randomly exercise key-only mode (value_caches=None path).
            only_key_cache = random.choice(only_key_cache_list)
            # Create the KV caches.
            key_caches, value_caches = self.create_kv_caches(num_blocks, block_size,
                                                             num_layers, num_heads,
                                                             head_size, dtype, seed)
            # Clone the KV caches.
            cloned_key_caches = [key_cache.cpu().clone() for key_cache in key_caches]
            cloned_value_caches = [value_cache.cpu().clone() for value_cache in value_caches]

            # Call the copy blocks kernel.
            if only_key_cache:
                value_caches = None
                cloned_value_caches = None
            ops.copy_blocks(key_caches, value_caches, block_mapping)

            # Run the reference on the CPU clones.
            self.op_impl_base(cloned_key_caches, cloned_value_caches, block_mapping)

            # Compare the results.
            for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
                self.assertTensorsEqual(key_cache.cpu().float(), cloned_key_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
            if not only_key_cache:
                for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
                    self.assertTensorsEqual(value_cache.cpu().float(), cloned_value_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_inductor due to ASan issues")
    def test_prevent(self):
        """Verify the op's argument-validation error messages: wrong rank,
        mismatched cache list sizes, mismatched dtypes, and mismatched
        block_size between key and value caches."""
        print("test_copy_block: test_prevent...")
        num_blocks, block_size, head_num, head_size, block_mapping = 384, 6, 32, 128, 128
        # 3-D cache on purpose: the op requires 4-D per-layer caches.
        k_cache = torch.randn(num_blocks * head_num, block_size, head_size, dtype=torch.half, device="mlu")
        v_cache = torch.randn(num_blocks * head_num, head_size, block_size + 1, dtype=torch.float, device="mlu")
        key_caches = [k_cache,]
        value_caches = None
        block_mapping = self.create_block_mapping(num_blocks, block_mapping)
        self.assertException("every layer k_cache must be 4d.", ops.copy_blocks,
                             key_caches, value_caches, block_mapping)
        k_cache = k_cache.reshape(num_blocks, head_num, block_size, head_size)
        key_caches = [k_cache, k_cache,]
        value_caches = [v_cache]
        self.assertException("k_caches size must equal to v_caches size if v_caches is not none.",
                             ops.copy_blocks, key_caches, value_caches, block_mapping)
        key_caches = [k_cache]
        value_caches = [v_cache]
        self.assertException("the data type of k_caches and v_caches are not the same.",
                             ops.copy_blocks, key_caches, value_caches, block_mapping)
        value_caches[0] = value_caches[0].to(torch.half)
        self.assertException("every layer k_cache dim must equal to v_cache dim.",
                             ops.copy_blocks, key_caches, value_caches, block_mapping)
        value_caches[0] = value_caches[0].reshape(num_blocks, head_num, head_size, block_size + 1)
        self.assertException("the block_size of k_caches and v_caches are not the same.",
                             ops.copy_blocks, key_caches, value_caches, block_mapping)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_inductor due to ASan issues")
    def test_inductor(self):
        """Opcheck the registered copy_blocks custom op with one small config."""
        num_heads, head_size, num_blocks, block_size, num_layers, num_mappings = 8, 64, 384, 8, 1, 128
        key_caches, value_caches = self.create_kv_caches(num_blocks, block_size,
                                                         num_layers, num_heads,
                                                         head_size, torch.float16, 0)
        block_mapping = self.create_block_mapping(num_blocks, num_mappings)
        self.base_opcheck(torch.ops.torch_mlu_ops.copy_blocks, (key_caches, value_caches, block_mapping))
||||
# Script entry point: run this file's test class and propagate the
# unittest result as the process exit status (run_unittest is from common_utils).
if __name__ == '__main__':
    exit(run_unittest(TestCopyBlocksOp))
||||
280
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_dequant_from_linear_cache.py
Executable file
280
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_dequant_from_linear_cache.py
Executable file
@@ -0,0 +1,280 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
import torch_mlu_ops as ops
|
||||
|
||||
from common_utils import *
|
||||
|
||||
def gen_args(max_batch_size,
             batch_size,
             max_context_len,
             head_num_q,
             head_num_kv,
             cache_mem_len,
             head_size,
             group_size,
             use_seq_offset,
             dtype,
             quant_mode,
             quant_bit,
             has_value = True):
    """Generate random inputs for dequant_from_linear_cache.

    Builds a packed QKV "context" tensor and views K/V out of it, a quantized
    int8 (or packed int4) key/value cache with matching scales, per-batch
    context lengths/offsets, and cache batch-id / sequence-offset indirections.
    quant_mode 0 = per-channel scales, otherwise per-head/token scales.
    Returns the positional argument list expected by the op (value entries are
    None when has_value is False).
    """
    # Preprocess arguments
    assert max_batch_size >= batch_size, \
        "max_batch_size should greater than or equal to batch_size."
    assert cache_mem_len >= max_context_len, \
        "cache_mem_len should greater then or equal to max_context_len."
    assert head_size % group_size == 0, \
        "head_size should be a multiply of groupwise."
    total_heads = head_num_q + head_num_kv * 2
    # NOTE(review): if cache_mem_len == max_context_len this is 0 and the
    # randint(high=0) below would fail for use_seq_offset=True — callers
    # appear to always pass cache_mem_len > max_context_len; confirm.
    max_seq_offset = cache_mem_len - max_context_len
    # Generates key and cache from context
    context_lens = torch.randint(size=[batch_size], low=1, high=max_context_len + 1,
                                 dtype=torch.int32, device="mlu")
    if use_seq_offset:
        context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
                                         dtype=torch.int32, device="mlu")
    else:
        context_paddings = torch.zeros_like(context_lens)
    # context_seq_offset[i] = start row of batch i in the packed context tensor.
    cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
    total_seqlen = cu_context_lens[-1]
    context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
    context_seq_offset[1:] = cu_context_lens[:-1]
    context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
    # K and V are non-contiguous views into the packed [Q | K | V] head layout.
    key = context[..., head_num_q:head_num_q + head_num_kv, :]
    value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]

    # Generates key_cache and value_cache
    cache_bs_id = torch.IntTensor(random.sample([*range(0, batch_size + 1)], batch_size)).mlu()
    # low=-1: a negative offset marks a batch entry that the op should skip.
    cache_seq_offset = torch.randint(low=-1, high=max_seq_offset, size=[batch_size],
                                     dtype=torch.int32, device="mlu")
    if quant_bit == 4:
        # int4 packing: keys pack two nibbles per byte along head_size,
        # values pack along the sequence dimension.
        key_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size,
            head_num_kv, cache_mem_len, head_size // 2), device="mlu")
        value_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size,
            head_num_kv, cache_mem_len // 2, head_size), device="mlu")
        key_cache, value_cache = key_cache.to(torch.int8), value_cache.to(torch.int8)
    else:
        cache = torch.randint(size=(2, max_batch_size, head_num_kv, cache_mem_len, head_size),
                              low=-128, high=127, dtype=torch.int32, device="mlu")
        cache = cache.to(torch.int8)
        key_cache, value_cache = cache[[0, 1]]

    # Generates key_cache_scale and value_cache_scale
    if quant_mode == 0: # quant_mode == 0 is per channel
        cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu")
    else: # quant_mode != 1 (== 1 for extend) is per head
        cache_scale = torch.randn((2, max_batch_size, head_num_kv, cache_mem_len),
                                  dtype=torch.float, device="mlu")
    key_cache_scale, value_cache_scale = cache_scale[[0, 1]]

    # Prepare arguments
    if has_value == False:
        value = None
        value_cache = None
        value_cache_scale = None
    args = [key, value, key_cache, value_cache, key_cache_scale, value_cache_scale]
    args += [context_lens, max_context_len, context_seq_offset if use_seq_offset else None,
             cache_bs_id, cache_seq_offset]
    args += [quant_mode, quant_bit]
    return args
|
||||
class TestDequantFromLinearCache(BtTestCase):
    """Tests for ops.dequant_from_linear_cache: dequantizing int8/int4 KV-cache
    entries back into the (in-place mutated) key/value context tensors,
    compared against a pure-PyTorch reference."""

    def run_gen_case(self, dic):
        """Entry point for externally generated cases: either run with the raw
        values in `dic` (dump_data mode) or rebuild tensors from their dict
        descriptions via create_tensor_from_dic, then launch."""
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            key = create_tensor_from_dic(dic['key'])
            value = create_tensor_from_dic(dic['value'])
            key_cache = create_tensor_from_dic(dic['key_cache'])
            value_cache = create_tensor_from_dic(dic['value_cache'])
            key_cache_quant_scale = create_tensor_from_dic(dic['key_cache_quant_scale'])
            value_cache_quant_scale = create_tensor_from_dic(dic['value_cache_quant_scale'])
            context_lengths = dic['context_lengths']['data']
            max_context_len = dic['max_context_len']['data']
            context_seq_offset = dic['context_seq_offset']['data']
            cache_bs_id = dic['cache_bs_id']['data']
            cache_seq_offset = dic['cache_seq_offset']['data']
            quant_mode = dic['quant_mode']['data']
            quant_bit = dic['quant_bit']['data']

            self.launch(key, value, key_cache, value_cache, key_cache_quant_scale,
                        value_cache_quant_scale, context_lengths, max_context_len,
                        context_seq_offset, cache_bs_id, cache_seq_offset, quant_mode, quant_bit)

    def launch(self, *args):
        """Run the MLU op, snapshot its in-place result, then run the reference
        on the same tensors and compare (0.001 MSE tolerance). The value-side
        comparison is skipped when any value argument is None."""
        key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \
            context_lengths, max_context_len, context_seq_offset, cache_bs_id, \
            cache_seq_offset, quant_mode, quant_bit = args
        if value is None or value_cache is None or value_cache_scale is None:
            has_value = False
        else:
            has_value = True
        # The op writes dequantized cache contents into key/value in place.
        ops.dequant_from_linear_cache(key, value, key_cache, value_cache, key_cache_scale,
                                      value_cache_scale, context_lengths, max_context_len,
                                      context_seq_offset, cache_bs_id, cache_seq_offset,
                                      quant_mode, quant_bit)
        # Snapshot the op's output before the reference overwrites key/value.
        key_clone = key.clone()
        if has_value:
            value_clone = value.clone()
        self.op_impl_base(key, value, key_cache, value_cache, key_cache_scale,
                          value_cache_scale, context_lengths, max_context_len,
                          context_seq_offset, cache_bs_id, cache_seq_offset, quant_mode,
                          quant_bit)
        self.assertTensorsEqual(key_clone.cpu().float(), key.cpu().float(), 0.001, use_MSE=True)
        if has_value:
            self.assertTensorsEqual(value_clone.cpu().float(), value.cpu().float(), 0.001,
                                    use_MSE=True)

    def op_impl_base(self, *args):
        """Pure-PyTorch reference: per batch, slice the quantized cache (unpacking
        int4 nibble pairs when quant_bit == 4), multiply by the matching scale
        layout, and write the result into the key/value views in place."""
        def dequant_from_cache(quant_data: torch.Tensor,
                               scale_data: torch.Tensor,
                               quant_mode: int):
            # Broadcast the scale against the quantized data and return fp32.
            quant_data_fp32 = quant_data.clone().to(torch.float)
            scale_data_fp32 = scale_data.clone().to(torch.float)
            if quant_mode == 0: # per channel
                scale_data_fp32 = scale_data[..., None, :]
            else: # per head/token
                scale_data_fp32 = scale_data[..., None]
            dequant_data_fp32 = quant_data_fp32 * scale_data_fp32

            return dequant_data_fp32

        key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \
            context_lengths, max_context_len, context_seq_offset, cache_bs_id, \
            cache_seqlen_offset, quant_mode, quant_bit = args
        batch_size = context_lengths.size(0)
        if context_seq_offset is None:
            # Derive packed-tensor row offsets from cumulative context lengths.
            cu_seq_offset = torch.cumsum(context_lengths, dim=-1)
            context_seq_offset = torch.zeros_like(cu_seq_offset)
            context_seq_offset[1:] = cu_seq_offset[:-1]

        total = 0
        cache_mem_len = key_cache.size(2)
        for i in range(batch_size):
            context_len = context_lengths[i].item()
            seq_begin = context_seq_offset[i].item()
            seq_end = seq_begin + context_len
            total += context_len
            cache_id = cache_bs_id[i] if cache_bs_id is not None else i
            cache_seq_begin = cache_seqlen_offset[i] if cache_seqlen_offset is not None else 0
            cache_seq_end = cache_seq_begin + context_len
            key_i = key[seq_begin:seq_end].transpose(1, 0)
            if quant_bit == 4:
                # Unpack two int4 nibbles per int8 byte along head_size;
                # << 4 >> 4 sign-extends the low nibble, >> 4 takes the high one.
                key_cache_i_temp = key_cache[cache_id, :, cache_seq_begin:cache_seq_end]
                cache_size = list(key_cache_i_temp.size())
                cache_size[-1] *= 2
                key_cache_i = torch.zeros(cache_size, dtype=torch.int8, device="mlu")
                key_cache_i[...,::2] = key_cache_i_temp << 4 >> 4
                key_cache_i[...,1::2] = key_cache_i_temp >> 4
            else:
                key_cache_i = key_cache[cache_id, :, cache_seq_begin:cache_seq_end]
            # We use negative cache_seq_offset to skip unused batch
            if cache_seq_begin < 0 or cache_seq_end > cache_mem_len:
                continue
            # dequant key from cache
            if quant_mode == 0:
                key_cache_scale_i = key_cache_scale
            else:
                key_cache_scale_i = key_cache_scale[cache_id, :, cache_seq_begin:cache_seq_end]

            dequant_key_i = dequant_from_cache(key_cache_i, key_cache_scale_i, quant_mode)
            key_i[...] = dequant_key_i.to(key_i.dtype)

            # dequant value from cache
            if not (value_cache is None or value is None or value_cache_scale is None):
                value_i = value[seq_begin:seq_end].transpose(1, 0)
                if quant_bit == 4:
                    # Values are nibble-packed along the sequence dim; unpack a
                    # byte-aligned window and trim the odd front/back tokens.
                    pad_front = cache_seq_begin % 2
                    pad_back = cache_seq_end % 2
                    cache_seq_begin_temp = int(cache_seq_begin // 2)
                    cache_seq_end_temp = int(math.ceil(cache_seq_end / 2.0))

                    value_cache_i_temp = value_cache[cache_id, :,
                                                     cache_seq_begin_temp:cache_seq_end_temp]
                    cache_size = list(value_cache_i_temp.size())
                    cache_size[-2] *= 2
                    value_cache_i = torch.zeros(cache_size, dtype=torch.int8, device="mlu")
                    value_cache_i[...,::2,:] = value_cache_i_temp << 4 >> 4
                    value_cache_i[...,1::2,:] = value_cache_i_temp >> 4
                    if pad_front:
                        value_cache_i = value_cache_i[...,1:,:]
                    if pad_back:
                        value_cache_i = value_cache_i[...,:-1,:]
                else:
                    value_cache_i = value_cache[cache_id, :, cache_seq_begin:cache_seq_end]
                if quant_mode == 0:
                    value_cache_scale_i = value_cache_scale
                else:
                    value_cache_scale_i = value_cache_scale[cache_id, :,
                                                            cache_seq_begin:cache_seq_end]
                dequant_value_i = dequant_from_cache(value_cache_i, value_cache_scale_i,
                                                     quant_mode)
                value_i[...] = dequant_value_i.to(value.dtype)

    def test_dequant_from_linear_cache(self):
        """Randomized sweep of 100 configurations (shapes, quant mode/bit,
        offsets, dtype), skipping >2GB tensors on MLU300-series devices."""
        test_cases = 100
        head_size_times = 16
        mlu_name = torch.mlu.get_device_name()
        max_batch_size_list = torch.randint(low=32, high=64, size=[test_cases],
                                            dtype=torch.int32)
        batch_size_list = torch.randint(low=1, high=32, size=[test_cases], dtype=torch.int32)
        max_context_len_list = torch.randint(low=2, high=2048, size=[test_cases],
                                             dtype=torch.int32)
        head_num_q_list = torch.randint(low=1, high=64, size=[test_cases], dtype=torch.int32)
        head_num_kv_list = torch.randint(low=1, high=64, size=[test_cases], dtype=torch.int32)
        head_size_list = torch.randint(low=1, high=16, size=[test_cases],
                                       dtype=torch.int32) * head_size_times
        # *2 keeps cache_mem_len even (int4 values pack two tokens per byte)
        # and strictly above max_context_len.
        cache_mem_len_list = torch.randint(low=1024, high=2048, size=[test_cases],
                                           dtype=torch.int32) * 2
        quant_mode_list = np.random.choice([0, 1], test_cases)
        quant_bit_list = np.random.choice([4, 8], test_cases)
        use_offset_list = np.random.choice([False, True], test_cases)
        has_value_list = np.random.choice([False, True], test_cases)
        dtype_list = [torch.half, torch.bfloat16]
        # bf16 is excluded on MLU300-series devices.
        dtype_list = dtype_list[:-1] if "MLU3" in mlu_name else dtype_list
        dtype_list = np.random.choice(dtype_list, test_cases)
        for i in range(test_cases):
            max_batch_size = max_batch_size_list[i].item()
            batch_size = batch_size_list[i].item()
            head_num_q = head_num_q_list[i].item()
            max_context_len = max_context_len_list[i].item()
            head_num_kv = head_num_kv_list[i].item()
            head_size = head_size_list[i].item()
            cache_mem_len = cache_mem_len_list[i].item()
            quant_mode = quant_mode_list[i]
            quant_bit = quant_bit_list[i]
            use_seq_offset = use_offset_list[i]
            has_value = has_value_list[i]
            dtype = dtype_list[i]
            # Skip configs whose combined K+V cache would exceed INT32_MAX elements.
            if "MLU3" in mlu_name and (2 * cache_mem_len * max_batch_size * head_num_kv \
                    * head_size >= 2**31 - 1):
                print("large tensor is not support on {}, skip".format(mlu_name))
                continue

            print("batch_size={}, head_num={}, head_size={}, max_context_len={}, quant_mode={}, "
                  "quant_bit={}, dtype={} testing...".format(batch_size, head_num_kv, head_size,
                  max_context_len, quant_mode, quant_bit, dtype))

            torch.manual_seed(2766)
            args = gen_args(max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv,
                            cache_mem_len, head_size, head_size, use_seq_offset, dtype, quant_mode,
                            quant_bit, has_value)
            self.launch(*args)

    def test_inductor(self):
        """Opcheck the registered op with and without context_seq_offset."""
        max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, \
            head_size, dtype = 16, 8, 1024, 16, 32, 2048, 128, torch.float16
        quant_mode, quant_bit = 0, 8
        args = gen_args(max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv,
                        cache_mem_len, head_size, head_size, 1, dtype, quant_mode, quant_bit)
        self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_linear_cache, args)
        args = gen_args(max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv,
                        cache_mem_len, head_size, head_size, 0, dtype, quant_mode, quant_bit)
        self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_linear_cache, args)
||||
|
||||
# Script entry point: run this file's test class and propagate the
# unittest result as the process exit status (run_unittest is from common_utils).
if __name__ == '__main__':
    exit(run_unittest(TestDequantFromLinearCache))
||||
286
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_dequant_from_paged_cache.py
Executable file
286
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_dequant_from_paged_cache.py
Executable file
@@ -0,0 +1,286 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
import torch_mlu_ops as ops
|
||||
|
||||
from common_utils import *
|
||||
|
||||
def gen_args(batch_size,
             max_context_len,
             head_num_q,
             head_num_kv,
             cache_mem_len,
             head_size,
             group_size,
             block_size,
             use_seq_offset,
             dtype,
             quant_mode,
             quant_bit,
             has_value = True):
    """Generate random inputs for dequant_from_paged_cache.

    Like the linear-cache variant, but the int8 caches are organised in pages
    of `block_size` tokens addressed through a per-batch block_tables tensor
    of shape (batch_size, max_block_num). quant_mode 0 = per-channel scales,
    otherwise per-head/token scales. Returns the op's positional argument list
    (value entries are None when has_value is False).
    """
    # Preprocess arguments
    assert cache_mem_len >= max_context_len, \
        "cache_mem_len should greater then or equal to max_context_len."
    assert head_size % group_size == 0, \
        "head_size should be a multiply of groupwise."
    total_heads = head_num_q + head_num_kv * 2
    max_seq_offset = cache_mem_len - max_context_len
    max_block_num = int(math.ceil(max_context_len / block_size))
    total_blocks = int(math.ceil(cache_mem_len / block_size)) * batch_size
    # Distinct random block ids so no two table entries alias the same page.
    block_tables = random.sample(range(0, total_blocks), batch_size * max_block_num)
    block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch_size,
                                                                            max_block_num)
    # Generates key and cache from context
    context_lens = torch.randint(size=[batch_size], low=1, high=max_context_len + 1,
                                 dtype=torch.int32, device="mlu")
    if use_seq_offset:
        context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
                                         dtype=torch.int32, device="mlu")
    else:
        context_paddings = torch.zeros_like(context_lens)
    # context_seq_offset[i] = start row of batch i in the packed context tensor.
    cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
    total_seqlen = cu_context_lens[-1]
    context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
    context_seq_offset[1:] = cu_context_lens[:-1]
    context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
    # K and V are non-contiguous views into the packed [Q | K | V] head layout.
    key = context[..., head_num_q:head_num_q + head_num_kv, :]
    value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]
    # Generates key_cache and value_cache
    cache = torch.randint(size=(2, total_blocks, head_num_kv, block_size, head_size),
                          low=-128, high=127, dtype=torch.int32, device="mlu")
    cache = cache.to(torch.int8)
    key_cache, value_cache = cache[[0, 1]]

    # Generates key_cache_scale and value_cache_scale
    if quant_mode == 0: # quant_mode == 0 is per channel
        cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu")
    else: # quant_mode != 1 (== 1 for extend) is per head
        cache_scale = torch.randn((2, total_blocks, head_num_kv, block_size),
                                  dtype=torch.float, device="mlu")
    key_cache_scale, value_cache_scale = cache_scale[[0, 1]]

    # Prepare arguments
    if has_value == False:
        value = None
        value_cache = None
        value_cache_scale = None
    args = [key, value, key_cache, value_cache, key_cache_scale, value_cache_scale]
    args += [context_lens, max_context_len, context_seq_offset if use_seq_offset else None,
             block_tables]
    args += [quant_mode, quant_bit]
    return args
||||
|
||||
class TestDequantFromPagedCache(BtTestCase):
|
||||
def run_gen_case(self, dic):
    """Entry point for externally generated cases: either run with the raw
    values in `dic` (dump_data mode) or rebuild tensors from their dict
    descriptions via create_tensor_from_dic, then launch."""
    dump_data = dic.pop('dump_data')
    if dump_data:
        self.launch(*dic.values())
    else:
        key = create_tensor_from_dic(dic['key'])
        value = create_tensor_from_dic(dic['value'])
        key_cache = create_tensor_from_dic(dic['key_cache'])
        value_cache = create_tensor_from_dic(dic['value_cache'])
        key_cache_quant_scale = create_tensor_from_dic(dic['key_cache_quant_scale'])
        value_cache_quant_scale = create_tensor_from_dic(dic['value_cache_quant_scale'])
        context_lengths = dic['context_lengths']['data']
        max_context_len = dic['max_context_len']['data']
        context_seq_offset = dic['context_seq_offset']['data']
        block_tables = dic['block_tables']['data']
        quant_mode = dic['quant_mode']['data']
        quant_bit = dic['quant_bit']['data']

        self.launch(key, value, key_cache, value_cache, key_cache_quant_scale,
                    value_cache_quant_scale, context_lengths, max_context_len,
                    context_seq_offset, block_tables, quant_mode, quant_bit)
||||
def launch(self, *args):
    """Run the MLU op (which writes dequantized cache contents into key/value
    in place), snapshot its result, then run the reference implementation on
    the same tensors and compare with 0.001 MSE tolerance. The value-side
    comparison is skipped when any value argument is None."""
    key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \
        context_lengths, max_context_len, context_seq_offset, block_tables, \
        quant_mode, quant_bit = args
    if value is None or value_cache is None or value_cache_scale is None:
        has_value = False
    else:
        has_value = True
    ops.dequant_from_paged_cache(key, value, key_cache, value_cache, key_cache_scale,
                                 value_cache_scale, context_lengths, max_context_len,
                                 context_seq_offset, block_tables, quant_mode, quant_bit)
    # Snapshot the op's output before the reference overwrites key/value.
    key_clone = key.clone()
    if has_value:
        value_clone = value.clone()
    self.op_impl_base(key, value, key_cache, value_cache, key_cache_scale,
                      value_cache_scale, context_lengths, max_context_len,
                      context_seq_offset, block_tables, quant_mode, quant_bit)
    self.assertTensorsEqual(key_clone.cpu().float(), key.cpu().float(), 0.001, use_MSE=True)
    if has_value:
        self.assertTensorsEqual(value_clone.cpu().float(), value.cpu().float(), 0.001,
                                use_MSE=True)
||||
def op_impl_base(self, *args):
    """Pure-PyTorch reference for dequant_from_paged_cache: per batch, gather
    the sequence's pages via block_tables (full pages plus a trimmed partial
    last page), multiply by the matching scale layout, and write the result
    into the key/value views in place."""
    def dequant_from_cache(quant_data: torch.Tensor,
                           scale_data: torch.Tensor,
                           quant_mode: int):
        # Broadcast the scale against the quantized data and return fp32.
        quant_data_fp32 = quant_data.clone().to(torch.float)
        scale_data_fp32 = scale_data.clone().to(torch.float)
        if quant_mode == 0: # per channel [head_num, 1, head_size]
            scale_data_fp32 = scale_data[..., None, :]
        else: # per head/token [head_num, context_len, 1]
            scale_data_fp32 = scale_data[..., None]
        dequant_data_fp32 = quant_data_fp32 * scale_data_fp32

        return dequant_data_fp32

    key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \
        context_lengths, max_context_len, context_seq_offset, block_tables, \
        quant_mode, quant_bit = args
    batch_size = context_lengths.size(0)
    if context_seq_offset is None:
        # Derive packed-tensor row offsets from cumulative context lengths.
        cu_seq_offset = torch.cumsum(context_lengths, dim=-1)
        context_seq_offset = torch.zeros_like(cu_seq_offset)
        context_seq_offset[1:] = cu_seq_offset[:-1]

    total_seqlen = 0
    block_size = key_cache.size(2)
    for i in range(batch_size):
        context_len = context_lengths[i].item()
        seq_begin = context_seq_offset[i].item()
        seq_end = seq_begin + context_len
        total_seqlen += context_len
        full_block_num = context_len // block_size
        rem_token_num = context_len % block_size
        key_i = key[seq_begin:seq_end].transpose(1, 0)
        # [head_num, seq_num, head_size]
        key_cache_i = torch.concat(
            [key_cache[block_tables[i, j], ...] for j in range(full_block_num)] +
            ([key_cache[block_tables[i, full_block_num], :, :rem_token_num, :]] \
                if rem_token_num > 0 else []), dim=-2
        )

        # dequant key from cache
        if quant_mode == 0:
            # [head_num, head_size]
            key_cache_scale_i = key_cache_scale
        else:
            # [head_num, seq_num]
            key_cache_scale_i = torch.concat(
                [key_cache_scale[block_tables[i, j],...] for j in range(full_block_num)] +
                ([key_cache_scale[block_tables[i, full_block_num], :, :rem_token_num]] \
                    if rem_token_num > 0 else []), dim=-1
            )

        dequant_key_i = dequant_from_cache(key_cache_i, key_cache_scale_i, quant_mode)
        key_i[...] = dequant_key_i.to(key_i.dtype)
        # dequant value from cache
        if not (value_cache is None or value is None or value_cache_scale is None):
            value_i = value[seq_begin:seq_end].transpose(1, 0)
            # [head_num, seq_num, head_size]
            value_cache_i = torch.concat(
                [value_cache[block_tables[i, j], ...] for j in range(full_block_num)] +
                ([value_cache[block_tables[i, full_block_num], :, :rem_token_num, :]] \
                    if rem_token_num > 0 else []), dim=-2
            )

            # dequant value from cache
            if quant_mode == 0:
                # [head_num, head_size]
                value_cache_scale_i = value_cache_scale
            else:
                # [head_num, seq_num]
                value_cache_scale_i = torch.concat(
                    [value_cache_scale[block_tables[i, j],...] for j in range(full_block_num)] +
                    ([value_cache_scale[block_tables[i, full_block_num], :, :rem_token_num]] \
                        if rem_token_num > 0 else []), dim=-1
                )
            dequant_value_i = dequant_from_cache(value_cache_i, value_cache_scale_i, quant_mode)
            value_i[...] = dequant_value_i.to(value_i.dtype)
||||
@unittest.skipIf('TMO_MEM_CHECK' in os.environ or "MLU3" in torch.mlu.get_device_name(),
                 "Skipping test_prevent due to ASan issues or in MLU300 series.")
def test_prevent(self):
    """Verify the op's argument-validation error messages: wrong key dtype,
    unsupported quant mode/bit, max_context_len exceeding the block-table
    capacity, and a non-contiguous value tensor."""
    batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, \
        head_size, block_size = 8, 1024, 16, 32, 2048, 128, 16
    dtype, quant_mode, quant_bit = torch.float16, 0, 8
    default_args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                            head_size, head_size, block_size, 1, dtype, quant_mode, quant_bit)
    # fp32 key is rejected: only half / bfloat16 are supported.
    args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                    head_size, head_size, block_size, 1, torch.float32, quant_mode, quant_bit)
    self.assertException("Tensor key type should be half or bfloat16.",
                         torch.ops.torch_mlu_ops.dequant_from_paged_cache, *args)
    args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                    head_size, head_size, block_size, 1, dtype, 2, quant_bit)
    self.assertException("quantization mode support 0 and 1.",
                         torch.ops.torch_mlu_ops.dequant_from_paged_cache, *args)
    args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                    head_size, head_size, block_size, 1, dtype, quant_mode, 4)
    self.assertException("quantization bit width only supports 8.",
                         torch.ops.torch_mlu_ops.dequant_from_paged_cache, *args)
    # args[-5] is max_context_len in the op's argument list.
    default_args[-5] = 10240
    self.assertException("max_context_len should smaller than or equal to block_size * max_block_num.",
                         torch.ops.torch_mlu_ops.dequant_from_paged_cache, *default_args)
    default_args[-5] = max_context_len
    # args[1] is value; transposing breaks last-dim contiguity.
    default_args[1] = default_args[1].transpose(-1, -2)
    self.assertException("Tensor value last dim must be contiguous.",
                         torch.ops.torch_mlu_ops.dequant_from_paged_cache, *default_args)
||||
    @unittest.skipIf("MLU3" in torch.mlu.get_device_name(),
                     "Skipping test_dequant_from_paged_cache in MLU300 series")
    def test_dequant_from_paged_cache(self):
        """Randomized sweep: run self.launch on 100 random shape/dtype/quant
        configurations of dequant_from_paged_cache."""
        test_cases = 100
        # head_size is always a multiple of 16.
        head_size_times = 16
        mlu_name = torch.mlu.get_device_name()  # NOTE(review): unused
        batch_size_list = torch.randint(low=1, high=32, size=[test_cases], dtype=torch.int32)
        max_context_len_list = torch.randint(low=2, high=1024, size=[test_cases],
                                             dtype=torch.int32)
        head_num_q_list = torch.randint(low=1, high=64, size=[test_cases], dtype=torch.int32)
        head_num_kv_list = torch.randint(low=1, high=8, size=[test_cases], dtype=torch.int32)
        head_size_list = torch.randint(low=1, high=16, size=[test_cases],
                                       dtype=torch.int32) * head_size_times
        block_size_list = torch.randint(low=1, high=32, size=[test_cases], dtype=torch.int32)
        # Cache length kept even (doubled draw from [512, 1024)).
        cache_mem_len_list = torch.randint(low=512, high=1024, size=[test_cases],
                                           dtype=torch.int32) * 2
        quant_mode_list = np.random.choice([0, 1], test_cases)
        quant_bit_list = np.random.choice([8], test_cases)
        use_offset_list = np.random.choice([False, True], test_cases)
        has_value_list = np.random.choice([False, True], test_cases)
        dtype_list = [torch.half, torch.bfloat16]
        dtype_list = np.random.choice(dtype_list, test_cases)
        for i in range(test_cases):
            batch_size = batch_size_list[i].item()
            head_num_q = head_num_q_list[i].item()
            max_context_len = max_context_len_list[i].item()
            head_num_kv = head_num_kv_list[i].item()
            head_size = head_size_list[i].item()
            block_size = block_size_list[i].item()
            cache_mem_len = cache_mem_len_list[i].item()
            quant_mode = quant_mode_list[i]
            quant_bit = quant_bit_list[i]
            use_seq_offset = use_offset_list[i]
            has_value = has_value_list[i]
            dtype = dtype_list[i]
            print("batch_size={}, head_num={}, head_size={}, max_context_len={}, quant_mode={}, "
                  "quant_bit={}, dtype={} testing...".format(batch_size, head_num_kv, head_size,
                  max_context_len, quant_mode, quant_bit, dtype))

            # Fixed seed so each case's random tensors are reproducible.
            torch.manual_seed(2766)
            args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                            head_size, head_size, block_size, use_seq_offset, dtype, quant_mode,
                            quant_bit, has_value)
            self.launch(*args)
|
||||
|
||||
@unittest.skipIf("MLU3" in torch.mlu.get_device_name(),
|
||||
"Skipping test_dequant_from_paged_cache in MLU300 series")
|
||||
def test_inductor(self):
|
||||
batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, \
|
||||
head_size, block_size = 8, 1024, 16, 32, 2048, 128, 16
|
||||
dtype, quant_mode, quant_bit = torch.float16, 0, 8
|
||||
args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
|
||||
head_size, head_size, block_size, 1, dtype, quant_mode, quant_bit)
|
||||
self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_paged_cache, args)
|
||||
args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
|
||||
head_size, head_size, block_size, 0, dtype, quant_mode, quant_bit)
|
||||
self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_paged_cache, args)
|
||||
|
||||
if __name__ == '__main__':
    # Run the suite through the shared runner and propagate its status to
    # the shell exit code (non-zero on failure).
    exit(run_unittest(TestDequantFromPagedCache))
|
||||
299
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_faketensor.py
Normal file
299
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_faketensor.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
from itertools import product
|
||||
import torch_mlu_ops as ops
|
||||
import random
|
||||
import torch.testing._internal.optests as optests
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
# Map torch dtypes to the string names these custom ops take as their
# output-dtype argument.
dtype_dict = {
    torch.half: "half",
    torch.bfloat16: "bfloat16",
    torch.float: "float",
}

# Shared fake-tensor mode. Fallback kernels are disallowed so every op
# exercised below must provide its own fake/meta implementation -- which
# is exactly what this file verifies.
fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
|
||||
|
||||
class FakeTensorTest(TestCase):
|
||||
    def test_matmul(self):
        """Fake-tensor shape check for the matmul op, covering all four
        transpose combinations, float/half, and the int8 quantized path."""
        with fake_tensor_mode:
            mat_m, mat_n, mat_k, alpha, beta, act_mode, fast_act, approximate = 32, 256, 128, 0.8, 0.3, 'silu', True, True
            trans_a_list = [True, False]
            trans_b_list = [True, False]
            dtype_list = [torch.half, torch.float]
            args = product(trans_a_list, trans_b_list, dtype_list)
            for trans_a, trans_b, dtype in args:
                print(f"matmul... trans_a: {trans_a}, trans_b: {trans_b}, dtype: {dtype}")
                # Operand shapes are swapped when the corresponding trans flag
                # is set, so the logical product is always (m, k) x (k, n).
                shape_a, shape_b = (mat_m, mat_k), (mat_k, mat_n)
                if trans_a:
                    shape_a = (mat_k, mat_m)
                if trans_b:
                    shape_b = (mat_n, mat_k)
                a = torch.randn(shape_a, dtype=dtype, device='mlu')
                b = torch.randn(shape_b, dtype=dtype, device='mlu')
                bias = torch.randn((mat_n), dtype=dtype, device='mlu')
                c = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
                # Floating-point path: no output-dtype string, unit scales.
                output = torch.ops.torch_mlu_ops.matmul(a, b, None, bias, c, None, act_mode, alpha, beta, fast_act, approximate, 1.0, 1.0, trans_a, trans_b)
                self.assertEqual(output.shape, (mat_m, mat_n))

                # Quantized path: int8 operands with per-tensor scales and an
                # explicit output dtype name.
                a8 = torch.randint(-128, 127, shape_a).to(torch.int8).mlu()
                b8 = torch.randint(-128, 127, shape_b).to(torch.int8).mlu()
                a_scale = 280.8
                b_scale = 190.8
                str_dtype = dtype_dict[dtype]
                output = torch.ops.torch_mlu_ops.matmul(a8, b8, None, bias, c, str_dtype, act_mode, alpha, beta, fast_act, approximate, a_scale, b_scale, trans_a, trans_b)
                self.assertEqual(output.shape, (mat_m, mat_n))
|
||||
|
||||
    def test_weight_only_quant_matmul(self):
        """Fake-tensor shape check for quant_matmul in weight-only mode,
        sweeping group-wise vs per-channel weight quantization, dtypes,
        and both transpose flags."""
        with fake_tensor_mode:
            M, K, N, group_num = 2, 256, 32, 4
            quant_bit_size, act_mode, use_hp_active, act_coef, alpha, beta = 8, 'none', True, 1., 0.8, 0.3
            group_quant_list = [True, False]
            trans_a_list = [True, False]
            trans_b_list = [True, False]
            dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
            args = product(group_quant_list, dtype_list, trans_a_list, trans_b_list)
            for group_quant, dtype, trans_a, trans_b in args:
                print(f"weight_only_quant_matmul... group_quant: {group_quant}, dtype: {dtype}, trans_a: {trans_a}, trans_b: {trans_b}")
                # Activations stay floating point; only the weight is int8.
                a = torch.randn((M, K), dtype=dtype).mlu()
                b = torch.randint(-128, 127, (N, K), dtype=torch.int8).mlu()
                c = torch.randn(M, N, device="mlu", dtype=dtype)
                bias = torch.randn(N, device="mlu", dtype=dtype)
                group_wise_scale = torch.randn((N, group_num), device="mlu", dtype=dtype)
                # Group-wise quantization carries its scale in b_scale;
                # per-channel instead supplies a gemm output scale.
                b_quant_layout = "quantize_group_wise" if group_quant else "quantize_per_channel"
                b_scale = group_wise_scale if group_quant else None
                gemm_output_scale = None if group_quant else torch.randn(N, device="mlu", dtype=torch.float)
                a_scale, a_zero, b_zero, c_zero = None, None, None, None
                c_scale, gemm_output_zero = None, None
                quant_algo, a_quant_layout = "weight_only", "quantize_none"
                str_dtype = dtype_dict[dtype]
                output = torch.ops.torch_mlu_ops.quant_matmul(a, a_scale, a_zero,
                                                              b, b_scale, b_zero,
                                                              bias, c, c_scale, c_zero,
                                                              gemm_output_scale, gemm_output_zero,
                                                              str_dtype, None, quant_algo,
                                                              a_quant_layout, b_quant_layout,
                                                              quant_bit_size, act_mode, use_hp_active, act_coef,
                                                              alpha, beta, trans_a, trans_b,)
                self.assertEqual(output.shape, (M, N))
|
||||
|
||||
    def test_smooth_quant_matmul(self):
        """Fake-tensor shape check for quant_matmul in smooth-quant mode
        (int8 activations per-token, int8 weights per-channel), sweeping
        activation modes, dtypes, and transpose flags."""
        with fake_tensor_mode:
            act_mode_list = ["none", "silu", "gelu"]
            dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
            trans_a_list = [True, False]
            trans_b_list = [True, False]
            M, K, N = 2, 16, 32
            quant_bit_size, use_hp_active, act_coef, alpha, beta = 8, True, 1., 0.8, 0.3
            arg = product(act_mode_list, dtype_list, trans_a_list, trans_b_list)
            for act_mode, dtype, trans_a, trans_b in arg:
                print(f"smooth_quant_matmul... act_mode: {act_mode}, dtype: {dtype}, trans_a: {trans_a}, trans_b: {trans_b}")
                # Both operands are int8; per-token scale for a, per-channel
                # scale for b.
                a = torch.randint(-128, 127, (M, K), dtype=torch.int8).mlu()
                b = torch.randint(-128, 127, (N, K), dtype=torch.int8).mlu()
                c = None
                bias = torch.randn(N, device="mlu", dtype=dtype)
                a_scale = torch.randn(M, device="mlu", dtype=torch.float)
                b_scale = torch.randn(N, device="mlu", dtype=torch.float)
                a_zero, b_zero, c_zero = None, None, None
                c_scale, gemm_output_scale, gemm_output_zero = None, None, None
                quant_algo, a_quant_layout, b_quant_layout = "smooth_quant", "quantize_per_token", "quantize_per_channel"
                str_dtype = dtype_dict[dtype]
                output = torch.ops.torch_mlu_ops.quant_matmul(a, a_scale, a_zero,
                                                              b, b_scale, b_zero,
                                                              bias, c, c_scale, c_zero,
                                                              gemm_output_scale, gemm_output_zero,
                                                              str_dtype, None, quant_algo,
                                                              a_quant_layout, b_quant_layout,
                                                              quant_bit_size, act_mode, use_hp_active, act_coef,
                                                              alpha, beta, trans_a, trans_b)
                self.assertEqual(output.shape, (M, N))
|
||||
|
||||
    def test_group_gemm(self):
        """Fake-tensor shape check for group_gemm (MoE expert batched GEMM),
        sweeping dtype, gather-index mode, and bias presence."""
        with fake_tensor_mode:
            batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5
            dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
            idx_list = [False, True]
            has_bias_list = [True, False]
            args = product( dtype_list, idx_list, has_bias_list)
            for data_type, idx_mode, has_bias in args:
                print(f"group_gemm... has_bias: {has_bias}, idx_mode: {idx_mode}, dtype: {data_type}")
                bs = batch * seq
                token_topk = bs * topk
                # Sort token->expert assignments and derive the per-row gather
                # index back into the un-expanded activation.
                expert_id = torch.randint(experts_num, (token_topk,), device="mlu")
                sorted_expert_id, indices = expert_id.sort()
                gather_idx = indices // topk
                gather_idx = gather_idx.to(torch.int32)
                token_count = torch.randint(0, token_topk, (experts_num,)).to(torch.int32)
                a = torch.randn(bs, k, device="mlu", dtype=data_type)
                # When the op is not given the gather index, pre-expand the
                # activation rows here instead.
                if not idx_mode:
                    a = a[gather_idx]
                b = torch.randn(experts_num, n, k, device="mlu", dtype=data_type)
                c = torch.randn(token_topk, n, device="mlu", dtype=data_type)
                alpha = torch.randn(experts_num, device="mlu", dtype=torch.float32)
                beta = torch.randn(experts_num, device="mlu", dtype=torch.float32)
                a_scale = None
                b_scale = None
                bias = None
                if has_bias:
                    bias = torch.randn(experts_num, n, device="mlu", dtype=data_type)
                gather_idx_ = gather_idx if idx_mode else None
                output = torch.ops.torch_mlu_ops.group_gemm(a, b, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, bias, None, None, None, bs)
                self.assertEqual(output.shape, (token_topk, n))
|
||||
|
||||
    def test_smoothquant_group_gemm(self):
        """Fake-tensor shape check for group_gemm on the smooth-quant (int8)
        path, sweeping dtype, gather-index mode, and bias presence."""
        with fake_tensor_mode:
            batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5
            dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
            idx_list = [False, True]
            has_bias_list = [True, False]
            args = product( dtype_list, idx_list, has_bias_list)
            for data_type, idx_mode, has_bias in args:
                print(f"smoothquant_group_gemm... has_bias: {has_bias}, idx_mode: {idx_mode}, dtype: {data_type}")
                bs = batch * seq
                token_topk = bs * topk
                # Token->expert assignment and gather index, as in
                # test_group_gemm above.
                expert_id = torch.randint(experts_num, (token_topk,), device="mlu")
                sorted_expert_id, indices = expert_id.sort()
                gather_idx = indices // topk
                gather_idx = gather_idx.to(torch.int32)
                token_count = torch.randint(0, token_topk, (experts_num,)).to(torch.int32)
                # int8 operands with per-token (a) and per-expert-channel (b)
                # float32 scales.
                a8 = torch.randint(-128, 127, (bs, k)).to(torch.int8).mlu()
                b8 = torch.randint(-128, 127, (experts_num, n, k)).to(torch.int8).mlu()
                a_scale = torch.randn(token_topk, dtype=torch.float32).mlu()
                b_scale = torch.randn(experts_num, n, dtype=torch.float32).mlu()
                c = torch.randn(token_topk, n, device="mlu", dtype=data_type)
                if not idx_mode:
                    a8 = a8[gather_idx]
                alpha = torch.randn(experts_num, device="mlu", dtype=torch.float32)
                beta = torch.randn(experts_num, device="mlu", dtype=torch.float32)
                bias = None
                if has_bias:
                    bias = torch.randn(experts_num, n, device="mlu", dtype=data_type)
                gather_idx_ = gather_idx if idx_mode else None
                str_dtype = dtype_dict[data_type]
                output = torch.ops.torch_mlu_ops.group_gemm(a8, b8, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, bias, str_dtype, None, None, bs)
                self.assertEqual(output.shape, (token_topk, n))
|
||||
|
||||
    def test_moe_expand_input(self):
        """Fake-tensor shape check for moe_expand_input: expanding
        (token_num, hidden) rows to (token_num * topk, hidden)."""
        with fake_tensor_mode:
            token_num, hidden_size, expert_num, topk, start_expert_id, expert_size = 2048, 4096, 32, 8, 3, 20
            dtype_list = [torch.half, torch.float, torch.int8, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half, torch.float, torch.int8]
            for dtype in dtype_list:
                print(f"moe_expand_input... token_num: {token_num}, expert_num: {expert_num}, topk: {topk}, dtype: {dtype}")
                input = torch.randn(token_num, hidden_size, device='mlu').to(dtype)
                # One source-row index per expanded row.
                gather_idx = torch.randint(low=0, high=token_num, size=(token_num * topk,), dtype=torch.int32, device='mlu')
                # Cumulative per-expert token counts (length expert_num + 1).
                cusum_token_count = torch.zeros(expert_num + 1, dtype=torch.int32).mlu()
                output=torch.ops.torch_mlu_ops.moe_expand_input(input, gather_idx, cusum_token_count, start_expert_id, expert_size)
                self.assertEqual(output.shape, (token_num*topk, hidden_size))
|
||||
|
||||
    def test_moe_gen_idx(self):
        """Fake-tensor shape check for moe_gen_idx: from a (token, topk)
        expert assignment to expand/combine indices plus per-expert counts."""
        with fake_tensor_mode:
            token_num, expert_num, topk = 2048, 32, 8
            print(f"moe_gen_idx... token_num: {token_num}, expert_num: {expert_num}, topk: {topk}")
            expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk)).to(torch.int32).to('mlu')
            expand_idx, combine_idx, token_count, cusum_token_count =torch.ops.torch_mlu_ops.moe_gen_idx(expert_id, expert_num)
            self.assertEqual(expand_idx.shape, (token_num * topk,))
            self.assertEqual(combine_idx.shape, (token_num * topk,))
            self.assertEqual(token_count.shape, (expert_num,))
            # Cumulative counts carry a leading zero, hence the +1.
            self.assertEqual(cusum_token_count.shape, (expert_num + 1,))
|
||||
|
||||
    def test_moe_combine_result(self):
        """Fake-tensor shape check for moe_combine_result: reducing
        (tokens * topk, hidden) expert outputs back to (tokens, hidden)."""
        with fake_tensor_mode:
            has_bias_list = [True, False]
            num_tokens, hidden_size, num_expert, topk, start_expert_id = 1, 2048, 8, 2, 0
            expert_size_list = [5, 8]
            dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
            args = product(has_bias_list, expert_size_list, dtype_list)
            for has_bias, expert_size, dtype in args:
                print(f"moe_combine_result... has_bias: {has_bias}, expert_size: {expert_size}, dtype: {dtype}")
                input = torch.randn((num_tokens * topk, hidden_size), dtype=dtype, device='mlu')
                reduce_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device='mlu')
                gather_ids = torch.randperm(num_tokens * topk, dtype=torch.int32, device='mlu')
                bias = None
                residual = None
                cusum_token_count = None
                if has_bias:
                    bias = torch.randn((num_expert, hidden_size), dtype=dtype, device='mlu')
                    residual = torch.randn((num_tokens, hidden_size), dtype=dtype, device='mlu')
                # Cumulative counts are only required with bias or when only a
                # slice of the experts participates.
                if has_bias or expert_size < num_expert:
                    cusum_token_count =torch.zeros(num_expert + 1, dtype=torch.int32).mlu()
                output=torch.ops.torch_mlu_ops.moe_combine_result(input, reduce_weight, gather_ids, residual, cusum_token_count, start_expert_id, expert_size, bias)
                self.assertEqual(output.shape, (num_tokens, hidden_size))
|
||||
|
||||
def test_moe(self):
|
||||
with fake_tensor_mode:
|
||||
act_mode = 'gelu'
|
||||
case_list = set()
|
||||
while (len(case_list) < 100):
|
||||
batch = random.randint(1, 10)
|
||||
seq = random.randint(1, 10)
|
||||
hidden_size = random.randrange(1024, 3072, 512)
|
||||
inner_size = random.randrange(1024, 3072, 512)
|
||||
expert_num = random.randint(1, 40)
|
||||
topk = random.randint(1, expert_num)
|
||||
gated = random.choice([True, False])
|
||||
renormalize = random.choice([True, False])
|
||||
quant_mode = random.choice(["no_quant", "w4", "w8", "w4w8"])
|
||||
quant_wise = random.choice([128, 256, 512])
|
||||
data_type = random.choice([torch.bfloat16, torch.float16])
|
||||
if not torch_mlu.mlu.is_bf16_supported():
|
||||
data_type = torch.float16
|
||||
case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_mode, quant_wise, act_mode)
|
||||
if case in case_list:
|
||||
continue
|
||||
case_list.add(case)
|
||||
print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
|
||||
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_mode: {quant_mode}, quant_wise: {quant_wise}, act_mode: {act_mode} testing...", flush=True)
|
||||
# if expert_size == -1:
|
||||
# expert_size = expert_num
|
||||
if quant_mode == "no_quant":
|
||||
w1 = torch.randn((expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
|
||||
w2 = torch.randn((expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
|
||||
w1_quant_flag = None
|
||||
w2_quant_flag = None
|
||||
elif quant_mode == "w4":
|
||||
w1_quant_group = hidden_size // quant_wise
|
||||
w2_quant_group = inner_size // quant_wise
|
||||
w1_quant_flag = None
|
||||
w2_quant_flag = None
|
||||
w1 = torch.randint(-128, 127, (expert_num, inner_size*(1+gated), hidden_size // 2), device="mlu", dtype=torch.int32).to(torch.int8)
|
||||
w2 = torch.randint(-128, 127, (expert_num, hidden_size, inner_size // 2), device="mlu", dtype=torch.int32).to(torch.int8)
|
||||
w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
|
||||
w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
|
||||
elif quant_mode == "w8":
|
||||
w1_quant_flag = None
|
||||
w2_quant_flag = None
|
||||
w1 = torch.randint(-128, 127, (expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=torch.int32).to(torch.int8)
|
||||
w2 = torch.randint(-128, 127, (expert_num, hidden_size, inner_size), device="mlu", dtype=torch.int32).to(torch.int8)
|
||||
w1_scale = torch.empty((expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
|
||||
w2_scale = torch.empty((expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
|
||||
elif quant_mode == "w4w8":
|
||||
w1_quant_group = hidden_size // quant_wise
|
||||
w2_quant_group = inner_size // quant_wise
|
||||
w1_quant_flag = random.choices([4,8], k=expert_num * w1_quant_group)
|
||||
w2_quant_flag = random.choices([4,8], k=expert_num * w2_quant_group)
|
||||
w1_count = (sum(w1_quant_flag) // 4) * (quant_wise // 2) * inner_size*(1+gated)
|
||||
w2_count = (sum(w2_quant_flag) // 4) * (quant_wise // 2) * hidden_size
|
||||
w1 = torch.randint(-128, 127, (w1_count,), device="mlu", dtype=torch.int32).to(torch.int8)
|
||||
w2 = torch.randint(-128, 127, (w2_count,), device="mlu", dtype=torch.int32).to(torch.int8)
|
||||
w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
|
||||
w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
|
||||
hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
|
||||
residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
|
||||
router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
|
||||
input_smooth = None if quant_mode == "no_quant" else torch.empty(expert_num, hidden_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
|
||||
act_smooth = None if quant_mode == "no_quant" else torch.empty(expert_num, inner_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
|
||||
bias1, bias2 = None, None
|
||||
output = torch.ops.torch_mlu_ops.fused_moe(hidden_states, router_logit, w1, w2, bias1, bias2, residual,
|
||||
input_smooth, act_smooth, w1_scale, w2_scale, w1_quant_flag,
|
||||
w2_quant_flag, topk, renormalize, gated, act_mode, 0,
|
||||
0, 0)
|
||||
self.assertEqual(output.shape, hidden_states.shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fake-tensor support for these custom ops requires torch >= 2.3.
    # NOTE: a plain string comparison (`torch.__version__ >= '2.3.0'`) is
    # lexicographic and wrongly rejects e.g. '2.10.0', so compare the parsed
    # (major, minor) tuple instead.
    _version = torch.__version__.split("+")[0].split(".")
    if (int(_version[0]), int(_version[1])) >= (2, 3):
        run_tests()
|
||||
65
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_ffn.py
Executable file
65
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_ffn.py
Executable file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
|
||||
class TestFFNOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
|
||||
input, w1, bias1, w2, bias2, w3, bias3, act_mode = args
|
||||
up = F.linear(input, w1, bias1)
|
||||
act = act_mode_dict[act_mode](up.float()).to(input.dtype)
|
||||
if w3 is not None:
|
||||
gate = F.linear(input, w3, bias3)
|
||||
act = act * gate
|
||||
output = F.linear(act, w2, bias2)
|
||||
return output
|
||||
|
||||
    def test_ffn(self):
        """Compare ops.ffn against the reference op_impl_base, and against a
        decomposition into ops.matmul + ops.active, over a few shape/gating
        configurations and supported dtypes."""
        input_size_list = [128, 256, 512]
        hidden_size = 1024
        seq_len_list = [10, 16, 20]
        bool_value_list = [True, False]
        batch = 5
        dtype_list = [torch.half]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        # zip truncates at the shortest list, so only the first two (or three
        # paired) configurations are exercised; bool_value drives both the
        # bias and the gating flag.
        for input_size, seq_len, bool_value in zip(input_size_list,
                                                   seq_len_list,
                                                   bool_value_list):
            print("input_size={}, seq_len={}, bias={}, gated={}, testing...".format(
                input_size, seq_len, bool_value, bool_value), flush=True)
            use_gate = bool_value
            for dtype in dtype_list:
                w1 = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu")
                b1 = torch.randn((hidden_size), dtype=dtype, device="mlu")
                w2 = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu")
                b2 = torch.randn((input_size), dtype=dtype, device="mlu")
                w3 = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu") if use_gate else None
                b3 = torch.randn((hidden_size), dtype=dtype, device="mlu") if use_gate else None
                input = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu")
                args = (input, w1, b1, w2, b2, w3, b3, 'silu')
                output = self.op_impl_base(*args)
                # Fused kernel vs reference implementation.
                tmo_output1 = ops.ffn(*args)
                self.assertTensorsEqual(output.cpu().float(), tmo_output1.cpu().float(),
                                        0.005, use_MSE=True, use_RAE=True)
                # use matmul to implement ffn: concatenate up/gate weights
                # into one GEMM, apply the activation, then the down GEMM.
                f1_weight = torch.cat((w1, w3), dim=0) if use_gate else w1
                f1_bias = torch.cat((b1, b3), dim=0) if use_gate else b1
                pre_gemm_out = ops.matmul(input.view(-1, input_size), f1_weight, f1_bias, None, "none", 1.0, 0)
                act_out = ops.active(pre_gemm_out, 'silu', use_gate)
                tmo_output2 = ops.matmul(act_out, w2, b2, None, 'none', 1.0, 0)
                tmo_output2 = tmo_output2.view(batch, seq_len, input_size)
                self.assertTensorsEqual(output.cpu().float(), tmo_output2.cpu().float(),
                                        0.005, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_inductor(self):
|
||||
batch, seq_len, input_size, hidden_size, act_mode, dtype = 1, 10, 128, 1024, 'silu', torch.half
|
||||
input = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu")
|
||||
up_fc_weight = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu")
|
||||
down_proj_weight = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu")
|
||||
args = (input, up_fc_weight, None, down_proj_weight, None, None, None, None, None, act_mode, "none", 1e-5, 1., 0.)
|
||||
self.base_opcheck(torch.ops.torch_mlu_ops.ffn, args)
|
||||
|
||||
if __name__ == '__main__':
    # Propagate the unittest result to the process exit code.
    exit(run_unittest(TestFFNOp))
|
||||
357
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_flash_attention.py
Executable file
357
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_flash_attention.py
Executable file
@@ -0,0 +1,357 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
|
||||
def gen_args(seq_q, seq_k, head_num_q, head_num_k, head_size, has_alibi, has_mask, is_causal, use_block, return_lse, dtype):
    """Build a randomized, packed (varlen) flash-attention argument tuple.

    seq_q / seq_k are per-batch sequence lengths; q/k/v are packed along the
    token dimension with int32 cumulative-length prefix tensors. When
    use_block is set, k/v are paged caches of shape
    (num_blocks, head_num_k, block_size, head_size) addressed through a
    per-batch block table. Returns the positional argument tuple expected by
    torch.ops.torch_mlu_ops flash attention (trailing entries: window sizes
    -1/-1, compute dtype "float", return_lse).
    """
    batch = len(seq_q)
    max_seq_q = max(seq_q)
    max_seq_k = max(seq_k)
    # Cumulative prefix sums: [0, s0, s0+s1, ...].
    cu_seq_len_q = [0]
    cu_seq_len_k = [0]
    for sq, sk in zip(seq_q, seq_k):
        cu_seq_len_q.append(cu_seq_len_q[-1] + sq)
        cu_seq_len_k.append(cu_seq_len_k[-1] + sk)

    softmax_scale = 1 / math.sqrt(head_size)
    # Per-head ALiBi slopes / dense attention bias, only when requested.
    alibi_slope = None if not has_alibi else torch.zeros((head_num_q)).uniform_(0, 0.1).to(torch.float32).mlu()
    attn_bias = None if not has_mask else torch.randn((batch, head_num_q, max_seq_q, max_seq_k), dtype=dtype).mlu()
    total_seq_q = sum(seq_q)
    total_seq_k = sum(seq_k)
    q = torch.randn(total_seq_q, head_num_q, head_size, dtype=dtype, device="mlu")
    block_tables = None
    if use_block:
        block_size = 16
        max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size
        # Exactly one full table per sequence, so the pool can be sampled
        # without replacement.
        num_blocks = batch * max_num_blocks_per_seq
        cache_shape = (num_blocks, head_num_k, block_size, head_size)
        block_tables = random.sample(range(num_blocks), batch * max_num_blocks_per_seq)
        block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
        k = torch.randn(size=cache_shape, dtype=dtype).mlu()
        v = torch.randn(size=cache_shape, dtype=dtype).mlu()
    else:
        k = torch.randn(total_seq_k, head_num_k, head_size, dtype=dtype, device="mlu")
        v = torch.randn(total_seq_k, head_num_k, head_size, dtype=dtype, device="mlu")
    tmo_out = torch.empty_like(q)
    # Optional log-sum-exp output buffer.
    out_lse = torch.randn(batch, head_num_q, max_seq_q, dtype=torch.float, device="mlu") if return_lse else None
    return (q, k, v, tmo_out, out_lse,
            torch.tensor(cu_seq_len_q, dtype=torch.int32, device="mlu"),
            torch.tensor(cu_seq_len_k, dtype=torch.int32, device="mlu"),
            alibi_slope, attn_bias, None, None,
            block_tables, max_seq_q, max_seq_k,
            softmax_scale, is_causal,
            -1, -1, "float", return_lse)
|
||||
|
||||
def gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, dtype):
    """Build randomized packed flash-attention tensors (q/k/v variant).

    Like gen_args, but supports a distinct value head size (head_size_v) and
    returns the raw pieces instead of a full op argument tuple:
    (q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias,
    block_tables).
    """
    batch = len(seq_q)
    max_seq_q = max(seq_q)
    max_seq_k = max(seq_k)
    # Cumulative prefix sums: [0, s0, s0+s1, ...], stored as int32 tensors.
    cu_seq_len_q = [0]
    cu_seq_len_k = [0]
    for sq, sk in zip(seq_q, seq_k):
        cu_seq_len_q.append(cu_seq_len_q[-1] + sq)
        cu_seq_len_k.append(cu_seq_len_k[-1] + sk)
    cu_seq_len_q = torch.tensor(cu_seq_len_q, dtype=torch.int32, device="mlu")
    cu_seq_len_k = torch.tensor(cu_seq_len_k, dtype=torch.int32, device="mlu")

    # Per-head ALiBi slopes / dense attention bias, only when requested.
    alibi_slope = None if not has_alibi else torch.zeros((head_num_q)).uniform_(0, 0.1).to(torch.float32).mlu()
    attn_bias = None if not has_mask else torch.randn((batch, head_num_q, max_seq_q, max_seq_k), dtype=dtype).mlu()
    total_seq_q = sum(seq_q)
    total_seq_k = sum(seq_k)
    q = torch.randn(total_seq_q, head_num_q, head_size, dtype=dtype, device="mlu")
    block_tables = None
    if use_block:
        block_size = 16
        max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size
        # Exactly one full table per sequence, so the pool can be sampled
        # without replacement.
        num_blocks = batch * max_num_blocks_per_seq
        block_tables = random.sample(range(num_blocks), batch * max_num_blocks_per_seq)
        block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
        k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=dtype).mlu()
        v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=dtype).mlu()
    else:
        k = torch.randn(total_seq_k, head_num_k, head_size, dtype=dtype, device="mlu")
        v = torch.randn(total_seq_k, head_num_k, head_size_v, dtype=dtype, device="mlu")
    return q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables
|
||||
|
||||
class TestFlashAttnOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """Pure-PyTorch reference implementation of ops.flash_attention.

    Computes per-batch-element attention with optional packed (varlen) layout,
    paged KV cache (block tables), alibi slopes, causal mask, sliding window
    and dense attention bias. Returns the attention output, or a tuple
    (output, lse) when return_lse is True.

    NOTE(review): `out`, `compute_dtype`, `k_cache_quant_scale` and
    `v_cache_quant_scale` are unpacked but never used by this reference.
    """
    q, k, v, out, cu_seq_lens_q, cu_seq_lens_kv, alibi_slope, attn_bias, max_seq_len_q, \
    max_seq_len_kv, softmax_scale, is_causal, window_size_left, window_size_right, \
    compute_dtype, return_lse, block_tables, k_cache_quant_scale, v_cache_quant_scale = args
    # Packed layout iff cumulative sequence lengths were supplied.
    is_pack = cu_seq_lens_q is not None
    has_block_table = block_tables is not None
    if has_block_table:
        # Paged KV only makes sense with varlen (packed) input.
        assert is_pack == True
    batch = len(cu_seq_lens_q) - 1 if is_pack else q.size(0)
    head_num_q = q.size(-2)
    # Paged K is (num_block, head_num, block_size, head_size): heads at dim -3.
    head_num_kv = k.size(-2) if block_tables is None else k.size(-3)
    head_size = q.size(-1)
    head_size_v = v.size(-1)
    # GQA/MQA: query heads must be a multiple of KV heads.
    assert head_num_q >= head_num_kv and head_num_q % head_num_kv == 0
    group = head_num_q // head_num_kv
    device = q.device
    repeat_dim = -3 if has_block_table else -2
    # Broadcast KV heads up to head_num_q so einsum below is head-aligned.
    k_bd = torch.repeat_interleave(k, group, dim=repeat_dim)
    v_bd = torch.repeat_interleave(v, group, dim=repeat_dim)
    out_list = []
    # Large-but-finite stand-in for -inf so softmax stays NaN-free.
    inf = 1e6
    lse = torch.zeros(batch, head_num_q, max_seq_len_q, dtype=torch.float)
    for i in range(batch):
        # Slice out this batch element's queries (packed or padded layout).
        q_i = q[cu_seq_lens_q[i]:cu_seq_lens_q[i+1], ...] if is_pack else q[i]
        if not has_block_table:
            k_i = k_bd[cu_seq_lens_kv[i]:cu_seq_lens_kv[i+1], ...] if is_pack else k[i]
            v_i = v_bd[cu_seq_lens_kv[i]:cu_seq_lens_kv[i+1], ...] if is_pack else v[i]
        else:
            # Gather this sequence's KV pages and flatten them back to
            # a contiguous (seq_k, head_q, head_size) view.
            block_table = block_tables[i]
            context_len = cu_seq_lens_kv[i+1] - cu_seq_lens_kv[i]
            block_size = k.size(-2) # (num_block, head_num, block_size, head_size)
            table_end = (context_len + block_size - 1) // block_size
            block_ids = block_table[0 : table_end]
            keys, values = k[block_ids], v[block_ids]
            keys = torch.repeat_interleave(keys, group, dim=1) #[num_block, head_q, block_size, head_size]
            keys = keys.transpose(1, 0).contiguous().view(head_num_q, -1, head_size) #[head_q, num_block * blcke_size, head_size]
            k_i = keys.transpose(1,0) #[num_block * blcke_size, head_q, head_size]
            k_i = k_i[0:context_len, ...] #[seq_k, head_q, head_size]
            values = torch.repeat_interleave(values, group, dim=1)
            values = values.transpose(1, 0).contiguous().view(head_num_q, -1, head_size_v)
            v_i = values.transpose(1,0)
            v_i = v_i[0:context_len, ...]
        # Scaled QK^T logits in float32: (head, seq_q, seq_k).
        qk = torch.einsum('qhd,khd->hqk', q_i, k_i).to(torch.float) * softmax_scale
        seq_q, seq_k = q_i.size(0), k_i.size(0)
        if alibi_slope is not None:
            slope = alibi_slope.reshape(1, head_num_q, 1, 1)
            slope_bias = torch.zeros(1, head_num_q, seq_q, seq_k).to(device=device)
            if is_causal:
                # Causal alibi: bias depends only on key distance from the end.
                relative_pos = torch.arange(-seq_k + 1, 1, dtype=torch.float32).to(device=device)
                slope_bias = relative_pos * slope
            else:
                # Non-causal: symmetric |distance| penalty, right-aligned
                # (queries offset by seq_k - seq_q).
                row_idx = torch.arange(seq_q, dtype=torch.long).reshape(-1, 1)
                col_idx = torch.arange(seq_k, dtype=torch.long)
                relative_pos = torch.abs(row_idx + seq_k - seq_q - col_idx).to(device=device)
                slope_bias = -slope * relative_pos.to(dtype=slope.dtype)
            qk += (slope_bias.squeeze(0))
        if is_causal:
            assert seq_q <= seq_k, "seq_q <= seq_k if causal=True"
            # Right-aligned causal mask: free prefix of width seq_k - seq_q,
            # then upper-triangular -inf block.
            zeros = torch.zeros(seq_q, seq_k-seq_q, dtype=torch.float, device="mlu")
            tri = torch.full((seq_q, seq_q), -inf, dtype=torch.float, device="mlu").triu(diagonal=1)
            mask = torch.cat([zeros, tri], dim=1) # (q, k-q) + (q, q) => (q, k)
            qk += mask
        if window_size_left != -1 or window_size_right != -1:
            # Sliding window: for each (right-aligned) query row, zero out the
            # allowed [left, right) span and leave -inf everywhere else.
            mask_w = torch.full((seq_q, seq_k), -inf, dtype=torch.float, device="mlu")
            for qi in range(seq_q):
                left = max(seq_k - seq_q + qi - window_size_left, 0) if window_size_left != -1 else 0
                right = min(max(seq_k - seq_q + qi + window_size_right + 1, 0), seq_k) if window_size_right != -1 else seq_k
                mask_w[qi, left:right] = 0
            qk += mask_w

        if attn_bias is not None:
            # Bias is padded to (head, max_seq_q, max_seq_k); crop to actual lengths.
            qk += attn_bias[i][:, :seq_q, :seq_k]
        if return_lse:
            # Log-sum-exp over keys, recorded before softmax normalization.
            lse[i][:, :seq_q] = torch.logsumexp(qk, dim=-1)
        attn = torch.softmax(qk, dim=-1, dtype=torch.float).to(q.dtype)
        qkv = torch.einsum('hqk,khd->qhd', attn, v_i)
        out_list.append(qkv)
    # Packed output: concatenate along the token dim.
    attn_out = torch.cat(out_list, dim=0)
    if is_pack == False:
        # Padded layout: restore (batch, seq, head, head_size_v).
        attn_out = attn_out.view(q.size(0), q.size(1), q.size(2), head_size_v)
    if return_lse:
        attn_out = (attn_out, lse)
    return attn_out
|
||||
|
||||
def test_flash_attention(self):
    """Sweep flash_attention configs and compare against the torch reference.

    Covers packed varlen input, GQA head ratios, mixed head sizes, alibi,
    dense mask, causal, paged KV (block tables) and sliding windows.
    """
    seq_len_list = [((38, 64, 128), (38, 64, 128)), ((30, 40, 50), (60, 90, 120))]
    head_num_list = [(32, 32), (32, 4)]
    use_block_list = [False, True]
    head_size_list = [(64, 64), (128, 512)]
    alibi_list = [False, True]
    mask_list = [False, True]
    causal_list = [False, True]
    dtype_list = [torch.half, torch.float]
    window_size_list = [(-1, -1), (10, -1), (8, 8)]
    mlu_name = torch.mlu.get_device_name()
    if "MLU3" in mlu_name:
        # Paged attention kernels are unavailable on MLU3xx devices.
        print("pagedattn is not implement on mlu370, skip it")
        use_block_list = [False]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    args = product(seq_len_list, head_num_list, head_size_list, alibi_list, mask_list, causal_list, use_block_list, window_size_list, dtype_list)
    for ((seq_q, seq_k), (head_num_q, head_num_k), (head_size, head_size_v), has_alibi, has_mask, is_causal, use_block,
            (window_size_left, window_size_right), dtype) in args:
        batch = len(seq_q)
        print("batch={}, seq_lens_q={}, seq_lens_k={}, head_num_q={}, head_num_k={}, head_size={}, head_size_v= {}, has_alibi={}, \
has_mask={}, is_causal={}, use_block={}, window_size_left={}, window_size_right={}, dtype={}, testing...".format(
            batch, seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, is_causal, use_block,
            window_size_left, window_size_right, dtype), flush=True)
        max_seq_q = max(seq_q)
        max_seq_k = max(seq_k)
        params = gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, dtype)
        q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables = params
        softmax_scale = 1 / math.sqrt(head_size)

        # Reference vs device kernel on identical inputs.
        torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
            alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
            is_causal, window_size_left, window_size_right, torch.float, False,
            block_tables if use_block else None, None, None)
        tmo_output = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
            alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
            is_causal, window_size_left, window_size_right, torch.float, False,
            block_tables if use_block else None, None, None)
        self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
            0.005, use_MSE=True, use_RAE=True)

        if use_block: # test block_tables is [batch, 1]
            # Degenerate paging: one block per sequence (block_size == max_seq_k),
            # re-run without sliding window.
            block_size = max_seq_k
            num_blocks = (int)(batch * ((max_seq_k + block_size -1 )// block_size)) #batch
            max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size #1
            block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
            block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
            k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=dtype).mlu()
            v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=dtype).mlu()
            torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                is_causal, -1, -1, torch.float, False,
                block_tables if use_block else None, None, None)
            tmo_output = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                is_causal, -1, -1, torch.float, False,
                block_tables if use_block else None, None, None)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                0.005, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_flash_attention_lse(self):
    """Check flash_attention's log-sum-exp output against the torch reference.

    Fixed alibi + mask + causal; sweeps sequence lengths, head sizes,
    paged KV and sliding windows, comparing both output and lse tensors.
    """
    seq_len_list =[((66, 77, 88), (77, 88, 99)), ((1024, 1024), (1024, 1024))]
    head_num_q, head_num_k = 16, 8
    head_size_list = [(64, 64), (16, 128)]
    is_causal = True
    has_alibi = True
    has_mask = True
    window_size_list = [(-1, -1), (10, -1)]
    dtype_list = [torch.half, torch.float]
    use_block_list = [False, True]
    mlu_name = torch.mlu.get_device_name()
    if "MLU3" in mlu_name:
        # Paged attention kernels are unavailable on MLU3xx devices.
        print("pagedattn is not implement on mlu370, skip it")
        use_block_list = [False]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    args = product(seq_len_list, head_size_list, use_block_list, window_size_list, dtype_list)
    for (seq_q, seq_k), (head_size, head_size_v), use_block, (window_size_left, window_size_right), dtype in args:
        batch = len(seq_q)
        print("batch={}, seq_lens_q={}, seq_lens_k={}, head_num_q={}, head_num_k={}, head_size={}, head_size_v={}, has_alibi={}, \
has_mask={}, is_causal={}, use_block={}, window_size_left={}, window_size_right={}, dtype={}, testing...".format(
            batch, seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, is_causal, use_block,
            window_size_left, window_size_right, dtype), flush=True)
        max_seq_q = max(seq_q)
        max_seq_k = max(seq_k)
        params = gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, dtype)
        q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables = params
        softmax_scale = 1 / math.sqrt(head_size)

        # return_lse=True: both implementations return (output, lse).
        torch_output, torch_output_lse = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
            alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
            is_causal, window_size_left, window_size_right, torch.float, True,
            block_tables if use_block else None, None, None)
        tmo_output, tmo_output_lse = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
            alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
            is_causal, window_size_left, window_size_right, torch.float, True,
            block_tables if use_block else None, None, None)
        self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
            0.004, use_MSE=True, use_RAE=True)
        self.assertTensorsEqual(torch_output_lse.cpu(), tmo_output_lse.cpu(), 0.0006, use_MSE=True, use_RAE=True)

        if use_block: #test block_table = [batch, 1]
            # Degenerate paging: one block per sequence (block_size == max_seq_k).
            block_size = max_seq_k
            num_blocks = (int)(batch * ((max_seq_k + block_size -1 )// block_size)) #batch
            max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size #1
            block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
            block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
            k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=dtype).mlu()
            v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=dtype).mlu()
            torch_output, torch_output_lse = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                is_causal, window_size_left, window_size_right, torch.float, True,
                block_tables if use_block else None, None, None)
            tmo_output, tmo_output_lse = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                is_causal, window_size_left, window_size_right, torch.float, True,
                block_tables if use_block else None, None, None)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                0.004, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_output_lse.cpu(), tmo_output_lse.cpu(), 0.0006, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_flash_attention_inplace(self):
    """Check flash_attention writing into a caller-provided output tensor.

    A pre-allocated `tmo_output` is passed as the 4th argument and compared
    against the torch reference; runs for packed KV and (when supported)
    paged KV, including the degenerate one-block-per-sequence layout.
    """
    seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v = (66, 77, 88), (77, 88, 99), 16, 8, 64, 64
    use_block_list = [False, True]
    softmax_scale = 1 / math.sqrt(head_size)
    is_causal = True
    has_alibi = True
    has_mask = True
    batch = len(seq_q)
    max_seq_q = max(seq_q)
    max_seq_k = max(seq_k)
    # NOTE: the old `cu_seq_len_q = [0]` / `cu_seq_len_k = [0]` locals were dead
    # code — they were unconditionally overwritten by the gen_params unpack.
    window_size_left, window_size_right = -1, -1
    mlu_name = torch.mlu.get_device_name()
    if "MLU3" in mlu_name:
        # Paged attention kernels are unavailable on MLU3xx devices.
        print("pagedattn is not implement on mlu370, skip it")
        use_block_list = [False]
    total_seq_q = sum(seq_q)
    for use_block in use_block_list:
        print("test_flash_attention_inplace: use_block: {}, testing....".format(use_block))
        params = gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, torch.half)
        q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables = params
        torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
            alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
            is_causal, window_size_left, window_size_right, torch.float, False,
            block_tables if use_block else None, None, None)
        # Kernel writes into this buffer instead of allocating its own.
        tmo_output = torch.empty((total_seq_q, head_num_q, head_size_v), dtype=torch.half, device="mlu")
        ops.flash_attention(q, k, v, tmo_output, cu_seq_len_q, cu_seq_len_k,
            alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
            is_causal, window_size_left, window_size_right, torch.float, False,
            block_tables if use_block else None, None, None)
        self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
            0.003, use_MSE=True, use_RAE=True)

        if use_block: #test block_table = [batch, 1]
            # Degenerate paging: one block per sequence (block_size == max_seq_k).
            block_size = max_seq_k
            num_blocks = batch * ((max_seq_k + block_size - 1) // block_size)  # == batch
            max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size  # == 1
            block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
            block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
            k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=torch.float16).mlu()
            v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=torch.float16).mlu()
            torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                is_causal, window_size_left, window_size_right, torch.float, False,
                block_tables if use_block else None, None, None)
            # Reuse the same output buffer; the kernel overwrites it in place.
            ops.flash_attention(q, k, v, tmo_output, cu_seq_len_q, cu_seq_len_k,
                alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                is_causal, window_size_left, window_size_right, torch.float, False,
                block_tables if use_block else None, None, None)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                0.003, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_inductor(self):
    """Opcheck flash_attention against torch.library for inductor support.

    Iterates every combination of paging / alibi / mask / causal / lse flags
    and runs base_opcheck on args produced by gen_args.
    """
    seq_q, seq_k, head_num_q, head_num_k, head_size, dtype = (66, 77, 88), (77, 88, 99), 16, 8, 64, torch.float16
    use_block_list = [False, True]
    mlu_name = torch.mlu.get_device_name()
    if "MLU3" in mlu_name:
        # Paged attention kernels are unavailable on MLU3xx devices.
        print("pagedattn is not implement on mlu370, skip it")
        use_block_list = [False]
    bool_pair = [False, True]
    alibi_list = list(bool_pair)
    mask_list = list(bool_pair)
    causal_list = list(bool_pair)
    return_lse_list = list(bool_pair)
    for combo in product(use_block_list, alibi_list, mask_list, causal_list, return_lse_list):
        use_block, has_alibi, has_mask, is_causal, return_lse = combo
        print(f"==== use_block: {use_block}, has_alibi: {has_alibi}, has_mask: {has_mask}, is_causal: {is_causal} return_lse: {return_lse}====")
        opcheck_args = gen_args(seq_q, seq_k, head_num_q, head_num_k, head_size, has_alibi, has_mask, is_causal, use_block, return_lse, dtype)
        self.base_opcheck(torch.ops.torch_mlu_ops.flash_attention, opcheck_args)
|
||||
|
||||
if __name__ == '__main__':
    # run_unittest (from common_utils) drives the suite and returns an exit code.
    exit(run_unittest(TestFlashAttnOp))
|
||||
94
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_attn_proj.py
Executable file
94
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_attn_proj.py
Executable file
@@ -0,0 +1,94 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
class TestFusedNormAttnProjOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """Torch reference for ops.fused_norm_attention_project.

    LayerNorm on the input followed by Q (and optionally K/V) linear
    projections; 'nhtc' layout reshapes outputs to (batch, head, seq, dim).
    Returns (q[, k, v][, layernorm_out]) depending on k_weight / norm_out.
    """
    (hidden, q_weight, q_bias, k_weight, k_bias, v_weight, v_bias,
     norm_weight, norm_bias, eps, out_layout, head_size, norm_out) = args
    ln = torch.nn.LayerNorm(hidden.size(-1))
    ln.eps = eps
    ln.weight = nn.Parameter(norm_weight)
    ln.bias = nn.Parameter(norm_bias)
    normed = ln(hidden)

    def project(weight, b):
        # Weight is stored (out_features, in_features); transpose for matmul.
        return torch.matmul(normed, weight.permute(1, 0)) + b

    q_out = project(q_weight, q_bias)
    has_kv = k_weight is not None
    if has_kv:
        k_out = project(k_weight, k_bias)
        v_out = project(v_weight, v_bias)

    if out_layout == 'nhtc':
        # Split hidden dim into heads and move heads ahead of the seq dim.
        batch, seq, _ = hidden.shape
        q_head = q_weight.size(0) // head_size
        q_out = q_out.reshape(batch, seq, q_head, head_size).transpose(1, 2)
        if has_kv:
            kv_head = k_weight.size(0) // head_size
            k_out = k_out.reshape(batch, seq, kv_head, head_size).transpose(1, 2)
            v_out = v_out.reshape(batch, seq, kv_head, head_size).transpose(1, 2)

    outs = (q_out, k_out, v_out) if has_kv else (q_out,)
    if norm_out is True:
        # Also expose the pre-projection normalized activations.
        outs += (normed,)
    return outs
|
||||
|
||||
def test_attn_proj(self):
    """Compare fused_norm_attention_project against the torch reference
    for all supported dtypes in the 'nthc' output layout.

    NOTE(review): `alpha`, `beta` and `residual` are created but unused here.
    """
    dtype_list = [torch.half, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        N, T, input_size, hidden_size, head_size, eps, alpha, beta = 4, 16, 512, 768, 64, 1e-5, 0.5, 0.3
        print("N: {}, T: {}, input_size: {}, hidden_size: {}, testing...".format(
            N, T, input_size, hidden_size), flush=True)
        input = torch.randn(N, T, input_size, dtype=dtype, device="mlu")
        # One fused QKV weight/bias, chunked into three equal projections below.
        weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu")
        bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu")
        norm_weight = torch.randn(input_size, dtype=dtype, device="mlu")
        norm_bias = torch.randn(input_size, dtype=dtype, device="mlu")
        residual = torch.randn(N, T, hidden_size, dtype=dtype, device="mlu")
        weights = torch.chunk(weight, 3)
        biass = torch.chunk(bias, 3)
        # test pre_attn_proj
        print("test pre_attn_proj...")
        out_torch = self.op_impl_base(input,
            weights[0], biass[0],
            weights[1], biass[1],
            weights[2], biass[2],
            norm_weight, norm_bias,
            eps, 'nthc', head_size, True)
        out_tmo = ops.fused_norm_attention_project(input,
            weights[0], biass[0],
            weights[1], biass[1],
            weights[2], biass[2],
            norm_weight, norm_bias,
            eps, 'nthc', head_size, True)
        # Outputs are tuples (q, k, v, layernorm_out); compare element-wise.
        for o1, o2 in list(zip(out_torch, out_tmo)):
            self.assertTensorsEqual(o1.cpu().float(), o2.cpu().float(),
                0.003, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_inductor(self):
|
||||
N, T, input_size, hidden_size, head_size, eps, alpha, beta, dtype = 4, 16, 512, 768, 64, 1e-5, 0.5, 0.3, torch.half
|
||||
input = torch.randn(N, T, input_size, dtype=dtype, device="mlu")
|
||||
weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu")
|
||||
bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu")
|
||||
norm_weight = torch.randn(input_size, dtype=dtype, device="mlu")
|
||||
norm_bias = torch.randn(input_size, dtype=dtype, device="mlu")
|
||||
weights = torch.chunk(weight, 3)
|
||||
biass = torch.chunk(bias, 3)
|
||||
args = (input, weights[0], biass[0], weights[1], biass[1], weights[2], biass[2],
|
||||
norm_weight, norm_bias, None, "nhtc", head_size, eps, alpha,
|
||||
beta, True)
|
||||
self.base_opcheck(torch.ops.torch_mlu_ops.attention_project, args)
|
||||
|
||||
args = (input, weights[0], biass[0], weights[1], biass[1], weights[2], biass[2],
|
||||
norm_weight, norm_bias, None, "nthc", head_size, eps, alpha,
|
||||
beta, True)
|
||||
self.base_opcheck(torch.ops.torch_mlu_ops.attention_project, args)
|
||||
|
||||
if __name__ == '__main__':
    # run_unittest (from common_utils) drives the suite and returns an exit code.
    exit(run_unittest(TestFusedNormAttnProjOp))
|
||||
89
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_ffn.py
Executable file
89
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_ffn.py
Executable file
@@ -0,0 +1,89 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
|
||||
class TestFusedNormResidualFFNoP(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """Torch reference for ops.fused_norm_residual_ffn.

    Optional LayerNorm, up projection + activation, optional gated branch,
    down projection, then optional residual add scaled by alpha/beta.
    `residual_is` selects the residual source: "input", "normed_input",
    or "none" (no residual).
    """
    input, up_fc_weight, up_fc_bias, down_proj_weight, down_proj_bias, gate_up_proj_weight, \
    gate_up_proj_bias, layernorm_weight, layernorm_bias, eps, act_mode, residual_is, \
    alpha, beta = args
    hidden_size = input.size(-1)
    act = act_mode_dict[act_mode]

    residual = input
    norm_out = input
    # Fix: only build the LayerNorm when norm weights are actually provided.
    # Previously the module and its Parameters were constructed (with None
    # weights) before the guard, doing useless work for the no-norm path.
    if layernorm_weight is not None:
        layernorm = torch.nn.LayerNorm(hidden_size)
        layernorm.weight = torch.nn.Parameter(layernorm_weight)
        layernorm.bias = torch.nn.Parameter(layernorm_bias)
        layernorm.eps = eps
        norm_out = layernorm(input)
    if residual_is == "normed_input":
        residual = norm_out
    # Linear weights are (out_features, in_features); transpose for matmul.
    up_fc_out = torch.matmul(norm_out, up_fc_weight.permute(1, 0)) + up_fc_bias
    # Activation computed in float32 for accuracy, then cast back.
    act_out = act(up_fc_out.float()).to(up_fc_out.dtype)
    if gate_up_proj_weight is not None:
        # Gated FFN: elementwise product of activated branch and gate branch.
        gate_up_proj_out = torch.matmul(norm_out, gate_up_proj_weight.permute(1, 0)) + gate_up_proj_bias
        down_proj_out = torch.matmul(act_out * gate_up_proj_out, down_proj_weight.permute(1, 0)) + down_proj_bias
    else:
        down_proj_out = torch.matmul(act_out, down_proj_weight.permute(1, 0)) + down_proj_bias
    if residual_is != 'none':
        out = beta * residual + alpha * down_proj_out
    else:
        out = down_proj_out
    return out
|
||||
|
||||
def test_ffn(self):
|
||||
dtype_list = [torch.half, torch.float]
|
||||
if torch_mlu.mlu.is_bf16_supported():
|
||||
dtype_list.append(torch.bfloat16)
|
||||
for dtype in dtype_list:
|
||||
batch, seq_len, hidden_size, inner_size, alpha, beta = 4, 16, 512, 512, 0.5, 0.3
|
||||
eps, act_mode, residual_is = 1e-5, "silu", "input"
|
||||
print(f"batch: {batch}, seq_len: {seq_len}, hidden_size: {hidden_size}, inner_size: {inner_size}, alpha: {alpha}, beta: {beta}, \
|
||||
eps: {eps}, act_mode: {act_mode}, residual_is: {residual_is}, dtype: {dtype} testing...", flush=True)
|
||||
input = torch.randn((batch, seq_len, hidden_size), dtype=dtype, device="mlu")
|
||||
layernorm_weight = torch.randn(hidden_size, dtype=dtype, device="mlu")
|
||||
layernorm_bias = torch.normal(0, 0.1, (hidden_size,), dtype=dtype, device="mlu")
|
||||
up_fc_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu")
|
||||
up_fc_bias = torch.normal(0, 0.1, (inner_size,), dtype=dtype, device="mlu")
|
||||
gated_fc_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu")
|
||||
gated_fc_bias = torch.normal(0, 0.1, (inner_size,), dtype=dtype, device="mlu")
|
||||
down_fc_weight = torch.randn((hidden_size, inner_size), dtype=dtype, device="mlu")
|
||||
down_fc_bias = torch.normal(0, 0.1, (hidden_size,), dtype=dtype, device="mlu")
|
||||
torch_output = self.op_impl_base(input,
|
||||
up_fc_weight, up_fc_bias,
|
||||
down_fc_weight, down_fc_bias,
|
||||
gated_fc_weight, gated_fc_bias,
|
||||
layernorm_weight, layernorm_bias,
|
||||
eps, act_mode, residual_is,
|
||||
alpha, beta)
|
||||
tmo_output = ops.fused_norm_residual_ffn(input,
|
||||
up_fc_weight, up_fc_bias,
|
||||
down_fc_weight, down_fc_bias,
|
||||
gated_fc_weight, gated_fc_bias,
|
||||
layernorm_weight, layernorm_bias,
|
||||
eps, act_mode, residual_is,
|
||||
alpha, beta)
|
||||
self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
|
||||
0.005, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_inductor(self):
|
||||
batch, seq_len, hidden_size, inner_size, act_mode, dtype = 1, 10, 128, 1024, 'silu', torch.half
|
||||
input = torch.randn((batch, seq_len, hidden_size), dtype=dtype, device="mlu")
|
||||
layernorm_weight = torch.randn((hidden_size,), dtype=dtype, device="mlu")
|
||||
layernorm_bias = torch.randn((hidden_size,), dtype=dtype, device="mlu")
|
||||
up_fc_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu")
|
||||
gate_up_proj_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu")
|
||||
down_proj_weight = torch.randn((hidden_size, inner_size), dtype=dtype, device="mlu")
|
||||
up_fc_bias = torch.randn((inner_size,), dtype=dtype, device="mlu")
|
||||
gate_up_proj_bias = torch.randn((inner_size,), dtype=dtype, device="mlu")
|
||||
down_proj_bias = torch.randn((hidden_size,), dtype=dtype, device="mlu")
|
||||
args = (input, up_fc_weight, up_fc_bias, down_proj_weight, down_proj_bias,
|
||||
gate_up_proj_weight, gate_up_proj_bias, layernorm_weight, layernorm_bias, act_mode, "normed_input", 1e-5, 1., 0.)
|
||||
self.base_opcheck(torch.ops.torch_mlu_ops.ffn, args)
|
||||
|
||||
if __name__ == '__main__':
    # run_unittest (from common_utils) drives the suite and returns an exit code.
    exit(run_unittest(TestFusedNormResidualFFNoP))
|
||||
213
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_layernorm.py
Executable file
213
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_layernorm.py
Executable file
@@ -0,0 +1,213 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from typing import Union, List, Tuple
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch import Size
|
||||
import os
|
||||
|
||||
# Module-wide test configuration: dtypes to sweep (bf16 only where supported)
# and the LayerNorm epsilon shared by every case below.
dtype_list = [torch.half, torch.float]
if torch_mlu.mlu.is_bf16_supported():
    dtype_list.append(torch.bfloat16)
eps = 0.001
|
||||
class TestFuseLayerNormOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """Torch reference for ops.fused_layer_norm.

    Optional bias add and residual add, then LayerNorm; optionally quantizes
    the normalized result to int8 with quant_scale. Returns either the
    output, (output, pre-norm sum) when store_output_before_norm, or writes
    into `out` when provided.
    """
    (x, residual, gamma, beta, bias, eps, store_output_before_norm,
     quant_scale, out, dynamic_quant) = args
    ln = torch.nn.LayerNorm(x.size(-1))
    ln.eps = eps
    ln.weight = Parameter(gamma)
    ln.bias = Parameter(beta)
    if bias is not None:
        x = x + bias
    # pre-norm sum is also the value returned when store_output_before_norm.
    pro_input = x if residual is None else x + residual
    result = ln(pro_input)
    if quant_scale is not None:
        # Symmetric int8 quantization of the normalized output.
        result = (result * quant_scale).round().clamp(-128, 127).to(torch.int8)
    if out is not None:
        # Caller-provided buffer: copy in place and hand it back.
        out.copy_(result)
        return out
    if store_output_before_norm:
        return (result, pro_input)
    return result
||||
|
||||
def test_layernorm(self):
    """Compare fused_layer_norm against the torch reference across dtypes,
    covering residual add, in-place output, and strided input/residual/output
    views."""
    C = 128
    input_shape = (8, 8, 6, C)
    print("test layernorm...")
    for dtype in dtype_list:
        # Fixed seed so every dtype sees identical random data.
        torch.manual_seed(1)
        input = torch.randn(input_shape, device="mlu", dtype=dtype)
        residual = torch.randn(input_shape, device="mlu", dtype=dtype)
        gamma = torch.randn(C, device="mlu", dtype=dtype)
        beta = torch.randn(C, device="mlu", dtype=dtype)
        # store_output_before_norm=True: returns (normed, pre-norm sum).
        tmo_out_0, tmo_out_1 = ops.fused_layer_norm(input, residual, gamma,
            beta, None, eps, True, None, None, False)
        torch_out_0, torch_out_1 = self.op_impl_base(input, residual, gamma, beta, None, eps, True, None, None, False)
        self.assertTensorsEqual(tmo_out_0.cpu().float(), torch_out_0.cpu().float(), 0.005, use_MSE=True)
        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out_1.cpu().float(), 0.005, use_MSE=True)
        # test output_inplace
        tmo_output = torch.empty(input_shape, device="mlu", dtype=dtype)
        torch_output = torch.empty_like(tmo_output)
        ops.fused_layer_norm(input, residual, gamma, beta, None, eps, False, None, tmo_output, False)
        self.op_impl_base(input, residual, gamma, beta, None, eps, False, None, torch_output, False)
        self.assertTensorsEqual(tmo_output.cpu().float(), torch_output.cpu().float(), 0.003, use_MSE=True)
        #test input_stride and output-continguous
        inputs_shape = (8, 8, 10, C)
        inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
        # Non-contiguous view of the allocation (slice along dim 2).
        input = inputs[:, :, 0:6, :]
        tmo_out = ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, None, False)
        torch_out = self.op_impl_base(input, None, gamma, beta, None, eps, False, None, None, False)
        self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True)
        # has res
        tmo_out, tmo_res = ops.fused_layer_norm(input, residual, gamma, beta, None, eps, True, None, None, False)
        torch_out, torch_res = self.op_impl_base(input, residual, gamma, beta, None, eps, True, None, None, False)
        self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True)
        self.assertTensorsEqual(tmo_res.cpu().float(), torch_res.cpu().float(), 0.003, use_MSE=True)
        #res has stride
        res_shape = (8, 8, 16, C)
        res = torch.randn(res_shape, device="mlu", dtype=dtype)
        residual = res[..., 0:6, :]
        tmo_out, tmo_res = ops.fused_layer_norm(input, residual, gamma, beta, None, eps, True, None, None, False)
        torch_out, torch_res = self.op_impl_base(input, residual, gamma, beta, None, eps, True, None, None, False)
        self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True)
        self.assertTensorsEqual(tmo_res.cpu().float(), torch_res.cpu().float(), 0.003, use_MSE=True)
        #output has different stride
        outputs = torch.randn(8, 8, 12, C, dtype=dtype, device='mlu')
        output = outputs[..., 0:6, :]
        ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, output, False)
        torch_out = self.op_impl_base(input, None, gamma, beta, None, eps, False, None, None, False)
        self.assertTensorsEqual(output.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True)
|
||||
def test_squant_layernorm(self):
    """Compare fused_layer_norm's int8 static-quant path (quant_scale set,
    store_output_before_norm=True) against the torch reference."""
    channels = 128
    shape = (8, 8, 6, channels)
    print("test squant_layernorm...")
    for dtype in dtype_list:
        # Fixed seed so every dtype sees identical random data; the randn
        # call order below is part of the test's determinism — keep it.
        torch.manual_seed(1)
        x = torch.randn(shape, device="mlu", dtype=dtype)
        res = torch.randn(shape, device="mlu", dtype=dtype)
        scale = torch.randn(channels, device="mlu", dtype=torch.float) * 30
        weight = torch.randn(channels, device="mlu", dtype=dtype)
        shift = torch.randn(channels, device="mlu", dtype=dtype)
        dev_quant, dev_prenorm = ops.fused_layer_norm(x, res, weight, shift, None, eps, True, scale, None, False)
        ref_quant, ref_prenorm = self.op_impl_base(x, res, weight, shift, None, eps, True, scale, None, False)
        self.assertTensorsEqual(dev_quant.cpu().float(), ref_quant.cpu().float(), 0.006, use_MSE=True)
        self.assertTensorsEqual(dev_prenorm.cpu().float(), ref_prenorm.cpu().float(), 0.006, use_MSE=True)
|
||||
def test_layernorm_stride(self):
    """Layernorm over non-contiguous (strided) inputs/outputs vs. the reference impl."""
    C = 128
    print("test layernorm stride input...")
    for dtype in dtype_list:
        gamma = torch.randn(C, device="mlu", dtype=dtype)
        beta = torch.randn(C, device="mlu", dtype=dtype)
        inputs_shape = (8, 8, 10, C)
        inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
        # Strided view: rows 0..5 of dim 2, strides inherited from the full buffer.
        input = inputs[:, :, 0:6, :]
        tmo_output = torch.empty(inputs.shape, device="mlu", dtype=dtype)
        # Output buffer shares the input view's (non-contiguous) strides.
        tmo_output = tmo_output.as_strided(input.shape, input.stride())
        torch_out_0 = torch.empty_like(tmo_output)
        self.op_impl_base(input, None, gamma, beta, None, eps, False, None, torch_out_0, False)
        ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, tmo_output, False)
        self.assertTensorsEqual(tmo_output.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True)
        # test assumption get wrong stride when dim = 1: a size-1 dim carrying
        # a 0 stride must not confuse the kernel's stride deduction.
        input = inputs[:, :, 8:9, :]
        input = input.as_strided(input.shape, (10240, 1280, 0, 1))
        torch_out_0 = torch.empty_like(input)
        self.op_impl_base(input, None, gamma, beta, None, eps, False, None, torch_out_0, False)
        # In-place: the fused op writes back into the strided input view.
        ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, input, False)
        self.assertTensorsEqual(input.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True)
        # Views sliced in the channel dimension of a wider (2*C) buffer.
        inputs_shape = (8, 8, 10, 2*C)
        inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
        input1 = inputs[:, :, 0:2, 0:C]
        torch_out_0 = torch.empty_like(input1)
        self.op_impl_base(input1, None, gamma, beta, None, eps, False, None, torch_out_0, False)
        ops.fused_layer_norm(input1, None, gamma, beta, None, eps, False, None, input1, False)
        self.assertTensorsEqual(input1.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True)
        input2 = inputs[:, :, :, 0:C]
        torch_out_0 = torch.empty_like(input2)
        self.op_impl_base(input2, None, gamma, beta, None, eps, False, None, torch_out_0, False)
        ops.fused_layer_norm(input2, None, gamma, beta, None, eps, False, None, input2, False)
        self.assertTensorsEqual(input2.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True)
        # test 3-dim input
        inputs_shape = (8, 8, 2*C)
        inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
        input1 = inputs[:, 0:2, 0:C]
        torch_out_0 = torch.empty_like(input1)
        self.op_impl_base(input1, None, gamma, beta, None, eps, False, None, torch_out_0, False)
        ops.fused_layer_norm(input1, None, gamma, beta, None, eps, False, None, input1, False)
        self.assertTensorsEqual(input1.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True)
        input2 = inputs[0:2, 0:2, 0:C]
        torch_out_0 = torch.empty_like(input2)
        self.op_impl_base(input2, None, gamma, beta, None, eps, False, None, torch_out_0, False)
        ops.fused_layer_norm(input2, None, gamma, beta, None, eps, False, None, input2, False)
        self.assertTensorsEqual(input2.cpu().float(), torch_out_0.cpu().float(), 0.0032, use_MSE=True)
|
||||
|
||||
# Error-proofing tests
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_prevent(self):
    """Invalid-argument checks: every fused_layer_norm call below must be
    rejected with the expected error message (via assertException).

    Fix: removed the unused local `residual` that was allocated but never
    passed to any call.
    """
    print("test_prevent....")
    func1 = ops.fused_layer_norm
    batch, seq_len, hidden_size = 5, 12, 512
    # 1-D input is rejected.
    input = torch.randn(hidden_size, dtype=torch.half, device='mlu')
    quant_scale = torch.randn(hidden_size, dtype=torch.half, device='mlu')
    gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
    beta = torch.randn(hidden_size, dtype=torch.half, device='mlu')
    self.assertException("input.dim() >= 2.",
        func1, input, None, gamma, beta, None, eps, False, None, None)
    # Last dim strided (every other element) is rejected.
    inputs = torch.randn(batch, seq_len, 2*hidden_size, dtype=torch.half, device='mlu')
    input = inputs[..., ::2]
    self.assertException("input last dim must be contiguous.",
        func1, input, None, gamma, beta, None, eps, False, None, None)

    # Missing beta in layernorm mode.
    input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
    self.assertException("layernorm mode need gamma and beta.",
        func1, input, None, gamma, None, None, eps, False, None, None)
    # Wrong-size gamma/beta.
    gamma = torch.randn(2 * hidden_size, dtype=torch.half, device='mlu')
    beta = torch.randn(2 * hidden_size, dtype=torch.half, device='mlu')
    self.assertException("layernorm mode, gamma and beta size must be hidden_size.",
        func1, input, None, gamma, beta, None, eps, False, None, None)

    # Quantized output is unsupported when the input is a strided view.
    gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
    beta = torch.randn(hidden_size, dtype=torch.half, device='mlu')
    inputs = torch.randn(batch, seq_len, 12, hidden_size, dtype=torch.half, device='mlu')
    input = inputs[..., 6:9, :]
    self.assertException("quant_out is not support when input has stride.",
        func1, input, None, gamma, beta, None, eps, False, quant_scale, None)
    # Unsupported input stride pattern (sliced on a middle dim, in-place).
    inputs = torch.randn(batch, 12, seq_len, hidden_size, dtype=torch.half, device='mlu')
    input = inputs[:, 6:9, ...]
    self.assertException("check the strides of input.",
        func1, input, None, gamma, beta, None, eps, False, None, input)
    # Output whose strides are incompatible with the input view's.
    input = inputs[..., 6:9, :]
    outputs = torch.randn(batch, 2*seq_len, 3, hidden_size, dtype=torch.half, device='mlu')
    output = outputs[:, :seq_len, ...]
    self.assertException("check the strides of output.",
        func1, input, None, gamma, beta, None, eps, False, None, output)
|
||||
|
||||
def test_inductor(self):
    """Run opcheck on the registered fused_layernorm custom op (inductor path)."""
    print("test_inductor....")
    batch, seq_len, hidden_size = 5, 12, 512
    # Note: tensor creation order matters for RNG reproducibility.
    x = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
    res = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
    q_scale = torch.randn(hidden_size, dtype=torch.float, device='mlu')
    weight = torch.randn(hidden_size, dtype=torch.half, device='mlu')
    shift = torch.randn(hidden_size, dtype=torch.half, device='mlu')
    extra_bias = torch.randn(hidden_size, dtype=torch.half, device='mlu')
    res_out = torch.empty_like(x)

    op = torch.ops.torch_mlu_ops.fused_layernorm
    quant_out = torch.zeros(x.shape, dtype=torch.int8, device='mlu')
    # Quantized output plus residual-out.
    self.base_opcheck(op, (x, quant_out, res, weight, shift, extra_bias,
                           q_scale, res_out, None, "layernorm", eps, True, False))
    # Quantized output only.
    self.base_opcheck(op, (x, quant_out, res, weight, shift, extra_bias,
                           q_scale, None, None, "layernorm", eps, False, False))
    # Plain half-precision output, no quantization.
    plain_out = torch.empty_like(x)
    self.base_opcheck(op, (x, plain_out, res, weight, shift, extra_bias,
                           None, None, None, "layernorm", eps, False, False))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Propagate the test-suite result as the process exit code.
    exit(run_unittest(TestFuseLayerNormOp))
|
||||
202
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_rmsnorm.py
Normal file
202
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_rmsnorm.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from typing import Union, List, Tuple
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch import Size
|
||||
import os
|
||||
|
||||
# Data types exercised by every test in this file; bfloat16 is appended only
# when the current MLU device supports it.
dtype_list = [torch.half, torch.float]
if torch_mlu.mlu.is_bf16_supported():
    dtype_list.append(torch.bfloat16)
# Variance-stabilization epsilon used by all norm calls in this file.
eps = 0.001
|
||||
|
||||
class TestFuseRmsNormOp(BtTestCase):
    """Tests for ops.fused_rms_norm against a pure-PyTorch reference."""

    def run_gen_case(self, dic):
        # Driver for generated cases: either forward the raw dict values to
        # launch() (dump mode) or rebuild every tensor argument from its
        # serialized description first.
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            x = create_tensor_from_dic(dic['x'])
            residual = None if dic['residual']['data'] is None else create_tensor_from_dic(dic['residual'])
            gamma = None if dic['gamma']['data'] is None else create_tensor_from_dic(dic['gamma'])
            beta = None if dic['beta']['data'] is None else create_tensor_from_dic(dic['beta'])
            bias = None if dic['bias']['data'] is None else create_tensor_from_dic(dic['bias'])
            eps = dic['eps']['data']
            store_output_before_norm = dic['store_output_before_norm']['data']
            quant_scale = None if dic['quant_scale']['data'] is None else create_tensor_from_dic(dic['quant_scale'])
            out = None if dic['out']['data'] is None else create_tensor_from_dic(dic['out'])
            dynamic_quant = dic['dynamic_quant']['data']
            self.launch(x, residual, gamma, beta, bias, eps, store_output_before_norm, quant_scale, out, dynamic_quant)

    def launch(self, *args):
        # Run the fused op and the reference on the same arguments and compare.
        # args layout: (x, residual, gamma, beta, bias, eps,
        #               store_output_before_norm, quant_scale, out, dynamic_quant).
        args = list(args)
        # Give the reference a fresh output buffer (and a snapshot of x when
        # out aliases x) so the fused op's in-place write does not corrupt
        # the reference's inputs.
        base_out = None if args[-2] is None else torch.empty_like(args[-2])
        base_input = args[0].clone() if args[-2] is not None and args[0] is args[-2] else args[0]
        tmo_out = ops.fused_rms_norm(*args)
        args[0] = base_input
        args[-2] = base_out
        torch_out = self.op_impl_base(*args)
        if tmo_out.__class__ in (list, tuple):
            for o1, o2 in zip(tmo_out, torch_out):
                self.assertTensorsEqual(o1.cpu().float(), o2.cpu().float(), 0.007, use_MSE=True)
        else:
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.007, use_MSE=True)

    def op_impl_base(self, *args):
        # Pure-PyTorch reference for fused_rms_norm; math is done in fp32.
        x, residual, gamma, beta, bias, eps, store_output_before_norm, quant_scale, out, dynamic_quant = args
        x = x + bias if bias is not None else x
        pro_input = x + residual if residual is not None else x
        store_input = pro_input  # pre-norm value, optionally returned
        pro_input = pro_input.to(torch.float32)
        variance = pro_input.pow(2).mean(-1, keepdim=True)
        pro_input = pro_input * torch.rsqrt(variance + eps)
        if gamma.dtype in [torch.float16, torch.bfloat16]:
            pro_input = pro_input.to(gamma.dtype)
        output = gamma * pro_input
        if quant_scale is not None:
            # Static per-channel int8 quantization of the normalized output.
            output = (output * quant_scale).round().clamp(-128, 127).to(torch.int8)
        if out is None:
            if store_output_before_norm:
                return (output, store_input)
            else:
                return output
        else:
            out.copy_(output)
            return out

    def test_rmsnorm(self):
        C = 128
        input_shape = (8, 8, 6, C)
        print("test rmsnorm...")
        for dtype in dtype_list:
            torch.manual_seed(1)
            input = torch.randn(input_shape, device="mlu", dtype=dtype)
            residual = torch.randn(input_shape, device="mlu", dtype=dtype)
            gamma = torch.randn(C, device="mlu", dtype=dtype)
            self.launch(input, residual, gamma, None, None, eps, True, None, None, False)
            # test inplace output
            output = torch.empty(input_shape, device="mlu", dtype=dtype)
            self.launch(input, residual, gamma, None, None, eps, False, None, output, False)
            inputs_shape = (8, 8, 10, C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input = inputs[:, :, 0:6, :]
            self.launch(input, None, gamma, None, None, eps, False, None, None, False)
            self.launch(input, residual, gamma, None, None, eps, True, None, None, False)
            # res has stride
            res_shape = (8, 8, 16, C)
            res = torch.randn(res_shape, device="mlu", dtype=dtype)
            residual = res[..., 0:6, :]
            self.launch(input, residual, gamma, None, None, eps, True, None, None, False)
            # output has different stride
            outputs = torch.randn(8, 8, 12, C, dtype=dtype, device='mlu')
            output = outputs[..., 0:6, :]
            self.launch(input, residual, gamma, None, None, eps, False, None, output, False)

    def test_squant_rmsnorm(self):
        C = 128
        input_shape = (8, 8, 6, C)
        print("test squant_rmsnorm...")
        for dtype in dtype_list:
            torch.manual_seed(1)
            input = torch.randn(input_shape, device="mlu", dtype=dtype)
            residual = torch.randn(input_shape, device="mlu", dtype=dtype)
            quant_scale = torch.randn(C, device="mlu", dtype=torch.float) * 30
            gamma = torch.randn(C, device="mlu", dtype=dtype)
            self.launch(input, residual, gamma, None, None, eps, True, quant_scale, None, False)
            # test one output
            self.launch(input, None, gamma, None, None, eps, False, quant_scale, None, False)

    def test_rmsnorm_stride(self):
        C = 128
        print("test rmsnorm stride input...")
        for dtype in dtype_list:
            gamma = torch.randn(C, device="mlu", dtype=dtype)
            inputs_shape = (8, 8, 10, C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input = inputs[:, :, 0:6, :]
            # Output buffer sharing the strided input view's strides.
            output = torch.empty(inputs.shape, device="mlu", dtype=dtype)
            output = output.as_strided(input.shape, input.stride())
            self.launch(input, None, gamma, None, None, eps, False, None, output, False)
            # test assumption get wrong stride when dim = 1 (0-stride size-1 dim)
            input = inputs[:, :, 0 :1, :]
            input = input.as_strided(input.shape, (10240, 1280, 0, 1))
            self.launch(input, None, gamma, None, None, eps, False, None, input, False)
            inputs_shape = (8, 8, 10, 2*C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input1 = inputs[:, :, 0:2, 0:C]
            self.launch(input1, None, gamma, None, None, eps, False, None, input1, False)
            input2 = inputs[:, :, :, 0:C]
            self.launch(input2, None, gamma, None, None, eps, False, None, input2, False)
            # test 3-dim input
            inputs_shape = (8, 8, 2*C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input1 = inputs[:, 0:2, 0:C]
            self.launch(input1, None, gamma, None, None, eps, False, None, input1, False)
            input2 = inputs[0:2, 0:2, 0:C]
            self.launch(input2, None, gamma, None, None, eps, False, None, input2, False)

    # Error-proofing tests
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        # Each call below must be rejected with the given error message.
        func1 = ops.fused_rms_norm
        batch, seq_len, hidden_size = 5, 12, 512
        input = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        quant_scale = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        self.assertException("input.dim() >= 2.",
            func1, input, None, gamma, None, None, eps, False, None, None)

        inputs = torch.randn(batch, seq_len, 2*hidden_size, dtype=torch.half, device='mlu')
        input = inputs[..., ::2]
        self.assertException("input last dim must be contiguous.",
            func1, input, None, gamma, None, None, eps, False, None, None)

        input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
        self.assertException("rmsnorm mode need gamma.",
            func1, input, None, None, None, None, eps, False, None, None)
        gamma = torch.randn(2 * hidden_size, dtype=torch.half, device='mlu')
        self.assertException("rmsnorm mode, gamma size must be hidden_size.",
            func1, input, None, gamma, None, None, eps, False, None, None)

        gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        inputs = torch.randn(batch, seq_len, 12, hidden_size, dtype=torch.half, device='mlu')
        input = inputs[..., 6:9, :]
        self.assertException("quant_out is not support when input has stride.",
            func1, input, None, gamma, None, None, eps, False, quant_scale, None)
        inputs = torch.randn(batch, 12, seq_len, hidden_size, dtype=torch.half, device='mlu')
        input = inputs[:, 6:9, ...]
        self.assertException("check the strides of input.",
            func1, input, None, gamma, None, None, eps, False, None, input)
        input = inputs[..., 6:9, :]
        outputs = torch.randn(batch, 2 * seq_len, 12, hidden_size, dtype=torch.half, device='mlu')
        output = outputs[:, 0:seq_len, 6:9, :]
        self.assertException("check the strides of output.",
            func1, input, None, gamma, None, None, eps, False, None, output)

    def test_inductor(self):
        # opcheck of the registered custom op in "rmsnorm" mode (the same
        # torch_mlu_ops.fused_layernorm schema backs both norm modes).
        batch, seq_len, hidden_size = 5, 12, 512
        input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
        residual = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
        quant_scale = torch.randn(hidden_size, dtype=torch.float, device='mlu')
        gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        beta = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        bias = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        residual_out = torch.empty_like(input)

        output = torch.zeros(input.shape, dtype=torch.int8, device='mlu')
        args = (input, output, residual, gamma, beta, bias, quant_scale, residual_out, None, "rmsnorm", eps, True, False)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args)
        args = (input, output, residual, gamma, beta, bias, quant_scale, None, None, "rmsnorm", eps, False, False)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args)

        output = torch.empty_like(input)
        args = (input, output, residual, gamma, beta, bias, None, None, None, "rmsnorm", eps, False, False)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args)
|
||||
|
||||
if __name__ == '__main__':
    # Propagate the suite result as the process exit code so CI detects
    # failures — mirrors test_fused_layernorm.py's `exit(run_unittest(...))`;
    # the previous bare call silently discarded the result.
    exit(run_unittest(TestFuseRmsNormOp))
|
||||
447
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_rope.py
Normal file
447
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_rope.py
Normal file
@@ -0,0 +1,447 @@
|
||||
import torch
|
||||
import unittest
|
||||
from torch.nn import functional as F
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
import random
|
||||
|
||||
def genSlotMapping(batch, block_size):
    """Pick one random slot per batch entry.

    Returns a list of length *batch* where element i is a uniform random
    index in [i * block_size, (i + 1) * block_size - 1], i.e. one slot
    inside batch i's own block.
    """
    return [
        random.randint(i * block_size, (i + 1) * block_size - 1)
        for i in range(batch)
    ]
|
||||
|
||||
def rotate(x: torch.Tensor):
    """Rotate-half transform used by RoPE: (a, b) -> (-b, a) along the last dim."""
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)
|
||||
|
||||
def quant(input: torch.Tensor):
    """Symmetric quantization to the int4 value range along the last dim.

    Returns a tuple of:
      - rounded codes as int8 (values in roughly [-7, 7]),
      - per-row scale (abs-max / 7) with the trailing keepdim axis removed,
      - the unrounded scaled values (fp32).
    """
    values = input.to(torch.float32)
    scale = values.abs().amax(dim=-1, keepdim=True) / 7
    scaled = values / scale
    return torch.round(scaled).to(torch.int8), scale.squeeze(-1), scaled
|
||||
|
||||
class TestFusedRopeOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """Pure-PyTorch reference for the fused RoPE + KV-cache-update op.

    Applies rotary position embedding to Q/K (in-place in qkv), layer-norms
    the rotated K, then scatters K/V into the high-precision cache and — in
    mixed-cache mode — additionally int4-quantizes K/V into the low-precision
    cache with per-group scales. Returns the updated tensors as a tuple.
    """
    qkv, k_cache_hp, v_cache_hp, sin_cache, cos_cache, position_id, gamma, beta, \
        k_cache_lp, v_cache_lp, cache_bs_id_hp, cache_seq_offsets_hp, cache_bs_id_lp, \
        cache_seq_offsets_lp, k_scale_hp, v_scale_hp, k_scale_lp, v_scale_lp, slot_mapping_hp, \
        slot_mapping_lp, eps = args
    mixed_cache = k_cache_lp is not None and v_cache_lp is not None
    # All math in fp32 for accuracy.
    qkv = qkv.to(torch.float32)
    sin_cache = sin_cache.to(torch.float32)
    cos_cache = cos_cache.to(torch.float32)
    gamma = gamma.to(torch.float32)
    beta = beta.to(torch.float32)
    # Scales come in as reciprocals; invert to get multiplicative factors.
    if k_scale_hp is not None:
        k_scale_hp = 1 / k_scale_hp
    if v_scale_hp is not None:
        v_scale_hp = 1 / v_scale_hp
    if not mixed_cache and k_scale_hp is None:
        # Unquantized high-precision cache: work on it in fp32.
        k_cache_hp = k_cache_hp.to(torch.float32)
        v_cache_hp = v_cache_hp.to(torch.float32)
    discrete_batch_hp = cache_bs_id_hp is not None
    discrete_batch_lp = cache_bs_id_lp is not None
    paged_cache_hp = slot_mapping_hp is not None
    paged_cache_lp = slot_mapping_lp is not None
    # Layout assumptions (from the indexing below): qkv is
    # (batch, seq=1, q_heads + 2*kv_heads, head_size) and caches are
    # (blocks_or_batch, kv_heads, block_or_seq, head_size) — TODO confirm.
    head_size = qkv.shape[-1]
    rope_dim = head_size
    batch_size = qkv.shape[0]
    head_qkv = qkv.shape[-2]
    kv_heads = k_cache_hp.shape[1]
    q_heads = head_qkv - 2 * kv_heads
    head_qk = q_heads + kv_heads
    group_num = 1
    group_size = head_size
    if mixed_cache:
        # Low-precision caches store two int4 values per int8 byte; unpack to
        # int8 tensors for the reference computation.
        group_num = k_scale_lp.shape[-1]
        group_size = int(head_size / group_num)
        k_cache_lp_shape = list(k_cache_lp.shape)
        k_cache_lp_shape[-1] *= 2
        k_cache_lp_int8 = UnpackInt4(k_cache_lp).reshape(k_cache_lp_shape)
        v_cache_lp_shape = list(v_cache_lp.shape)
        v_cache_lp_shape.append(2)
        v_cache_lp_int8 = UnpackInt4(v_cache_lp).reshape(v_cache_lp_shape)

    # Apply RoPE to the Q and K heads only (V is untouched).
    qk = qkv[:, :, 0:q_heads + kv_heads].clone()

    for i in range(batch_size):
        qk_i = qk[i]
        sin_cache_i = sin_cache[position_id[i]:position_id[i] + 1]
        cos_cache_i = cos_cache[position_id[i]:position_id[i] + 1]
        sin_cache_i = sin_cache_i[:1]
        cos_cache_i = cos_cache_i[:1]
        rot = rotate(qk_i)

        qk_i[:] = rot * sin_cache_i.unsqueeze(1) + qk_i * cos_cache_i.unsqueeze(1)

    q = qk[:, :, 0:q_heads]
    k = qk[:, :, q_heads:head_qk].contiguous().reshape(batch_size, kv_heads, head_size)

    # Write rotated Q back into qkv in place.
    qkv_q = qkv[:, :, 0:q_heads]
    qkv_q[...] = q

    # Layer-norm the rotated K over head_size.
    shape_k = k.shape
    k = torch.reshape(k, (-1, shape_k[-1]))
    k_norm = F.layer_norm(k, (head_size,), gamma, beta, eps)
    k_norm = torch.reshape(k_norm, shape_k)

    k_out = k_norm.reshape(batch_size, 1, kv_heads, head_size)
    v_out = qkv[:, :, head_qk:head_qkv].contiguous()
    k_out_hp = k_out.clone()
    v_out_hp = v_out.clone()
    k_out_lp = None
    v_out_lp = None
    if k_scale_hp is not None and v_scale_hp is not None:
        # Static int8 quantization for the high-precision cache path.
        k_scale_hp = k_scale_hp.reshape(kv_heads, head_size)
        v_scale_hp = v_scale_hp.reshape(kv_heads, head_size)
        k_out_hp = (k_out * k_scale_hp).round().clamp(-128, 127).to(torch.int8)
        v_out_hp = (v_out * v_scale_hp).round().clamp(-128, 127).to(torch.int8)

    if paged_cache_hp:
        # Paged layout: slot_mapping gives a flat slot; negative means skip.
        block_size = k_cache_hp.shape[2]
        for i in range(batch_size):
            if slot_mapping_hp[i] >= 0:
                block_id = torch.div(slot_mapping_hp[i], block_size, rounding_mode='floor')
                block_offset = slot_mapping_hp[i] % block_size
                k_cache_hp[block_id, :, block_offset, :] = k_out_hp[i]
                v_cache_hp[block_id, :, block_offset, :] = v_out_hp[i]
    else:
        # Contiguous layout: per-batch sequence offset (and optional remapped
        # batch id); negative offset/id means skip this entry.
        for i in range(batch_size):
            key_i = k_out_hp[i].transpose(1, 0)
            value_i = v_out_hp[i].transpose(1, 0)

            cache_bs_id_hp_i = cache_bs_id_hp[i] if discrete_batch_hp else i
            cache_seqlen_offset_hp_i = cache_seq_offsets_hp[i]
            if cache_seqlen_offset_hp_i < 0 or cache_bs_id_hp_i < 0:
                continue

            key_cache_hp_i = \
                k_cache_hp[cache_bs_id_hp_i, :, cache_seqlen_offset_hp_i:cache_seqlen_offset_hp_i + 1]
            key_cache_hp_i[...] = key_i[...]

            value_cache_hp_i = \
                v_cache_hp[cache_bs_id_hp_i, :, cache_seqlen_offset_hp_i:cache_seqlen_offset_hp_i + 1]
            value_cache_hp_i[...] = value_i[...]

    if mixed_cache:
        # Secondary low-precision (int4) cache: dynamic per-group quantization.
        # V is stored pair-packed along the sequence dim (two consecutive
        # positions interleaved), hence the /2 offsets and odd/even index.
        for i in range(batch_size):
            key_i = k_out[i].reshape(kv_heads, group_num, group_size)
            value_i = v_out[i].reshape(kv_heads, group_num, group_size)
            key_i_lp, key_scale_i_lp, _ = quant(key_i)
            value_i_lp, value_scale_i_lp, scaled_input = quant(value_i)

            if paged_cache_lp:
                block_size = k_cache_lp_int8.shape[2]
                if slot_mapping_lp[i] >= 0:
                    block_id = torch.div(slot_mapping_lp[i], block_size, rounding_mode='floor')
                    block_offset_k = slot_mapping_lp[i] % block_size
                    block_offset_v = torch.div(block_offset_k, 2, rounding_mode='floor')
                    odd_even_offset = block_offset_k % 2
                    k_cache_lp_int8[block_id, :, block_offset_k, :] = key_i_lp.reshape(kv_heads, head_size)
                    v_cache_lp_int8[block_id, :, block_offset_v, :, odd_even_offset] = value_i_lp.reshape(kv_heads, head_size)
                    k_scale_lp[block_id, :, block_offset_k, :] = key_scale_i_lp[...]
                    v_scale_lp[block_id, :, block_offset_k, :] = value_scale_i_lp[...]
            else:
                cache_bs_id_lp_i = cache_bs_id_lp[i] if discrete_batch_lp else i
                cache_seqlen_offset_lp_i = cache_seq_offsets_lp[i]
                v_seq_offset = torch.div(cache_seqlen_offset_lp_i, 2, rounding_mode='floor')
                odd_even_offset = cache_seqlen_offset_lp_i % 2
                if cache_seqlen_offset_lp_i < 0 or cache_bs_id_lp_i < 0:
                    continue
                key_cache_lp_i = \
                    k_cache_lp_int8[cache_bs_id_lp_i, :, cache_seqlen_offset_lp_i]
                key_cache_lp_i[...] = key_i_lp.reshape(kv_heads, head_size)
                key_scale_lp_i = \
                    k_scale_lp[cache_bs_id_lp_i, :, cache_seqlen_offset_lp_i]
                key_scale_lp_i[...] = key_scale_i_lp[...]

                value_cache_lp_i = \
                    v_cache_lp_int8[cache_bs_id_lp_i, :, v_seq_offset, :, odd_even_offset]
                value_cache_lp_i[...] = value_i_lp.reshape(kv_heads, head_size)
                value_scale_lp_i = \
                    v_scale_lp[cache_bs_id_lp_i, :, cache_seqlen_offset_lp_i]
                value_scale_lp_i[...] = value_scale_i_lp[...]
    out = (qkv, k_cache_hp, v_cache_hp)
    if mixed_cache:
        # Re-pack the int8 working copies back into int4 (pair-packed) form.
        k_cache_lp = PairlyPackInt8(k_cache_lp_int8.view(-1, head_size)).reshape(k_cache_lp.shape)
        v_cache_lp_int8 = v_cache_lp_int8.transpose(2, 3)
        s0, s1, s2, s3, s4 = v_cache_lp_int8.shape
        v_cache_lp = PairlyPackInt8(v_cache_lp_int8.reshape(-1, s3 * s4)).reshape(s0, s1, s2, s3 * s4 // 2).transpose(2, 3)
        out += (k_cache_lp, v_cache_lp)
    # NOTE(review): k_scale_lp is expected to be non-None only in mixed-cache
    # mode — confirm callers never pass scales without the lp caches.
    if k_scale_lp is not None:
        out += (k_scale_lp, v_scale_lp)
    return out
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "fused_rope not support MLU3XX device")
|
||||
def test_fused_rope(self):
|
||||
random_cases = 100
|
||||
random.seed(355)
|
||||
eps = 1e-6
|
||||
for _ in range(random_cases):
|
||||
if random_cases == 1:
|
||||
bs = 512
|
||||
seq_len = 1
|
||||
q_heads = 8
|
||||
kv_heads = 1
|
||||
head_size = 128
|
||||
rope_dim = 128
|
||||
dtype = torch.float16
|
||||
need_quant_kv = False
|
||||
mixed_cache = False
|
||||
max_decode_len_hp = 128
|
||||
max_decode_len_lp = 128
|
||||
discrete_batch_hp = True
|
||||
discrete_batch_lp = False
|
||||
paged_cache_hp = False
|
||||
paged_cache_lp = True
|
||||
block_size_hp = 16
|
||||
block_size_lp = 128
|
||||
num_blocks_hp = int((bs * max_decode_len_hp + block_size_hp - 1) / block_size_hp)
|
||||
num_blocks_lp = int((bs * max_decode_len_lp + block_size_lp - 1) / block_size_lp)
|
||||
group_num = 1
|
||||
else:
|
||||
bs = random.randint(1, 512)
|
||||
seq_len = 1
|
||||
q_heads = random.randint(1, 32)
|
||||
kv_heads = random.randint(1, 32)
|
||||
head_size_list = [32, 64, 96, 128]
|
||||
head_size = random.choice(head_size_list)
|
||||
rope_dim = head_size
|
||||
dtype_list = [torch.half, torch.bfloat16]
|
||||
bool_list = [True, False]
|
||||
block_size_list = [16, 32, 64, 128]
|
||||
dtype = random.choice(dtype_list)
|
||||
need_quant_kv = random.choice(bool_list)
|
||||
mixed_cache = random.choice(bool_list)
|
||||
max_decode_len_hp = random.randint(128, 1024)
|
||||
max_decode_len_lp = random.randint(128, 1024)
|
||||
discrete_batch_hp = random.choice(bool_list)
|
||||
discrete_batch_lp = random.choice(bool_list)
|
||||
paged_cache_hp = random.choice(bool_list)
|
||||
paged_cache_lp = random.choice(bool_list)
|
||||
block_size_hp = random.choice(block_size_list)
|
||||
num_blocks_hp = int((bs * max_decode_len_hp + block_size_hp - 1) / block_size_hp)
|
||||
block_size_lp = random.choice(block_size_list)
|
||||
num_blocks_lp = int((bs * max_decode_len_lp + block_size_lp - 1) / block_size_lp)
|
||||
group_num_list = [1, 2, 4, 8]
|
||||
group_num = random.choice(group_num_list)
|
||||
|
||||
if mixed_cache:
|
||||
need_quant_kv = True
|
||||
group_size = int(head_size / group_num)
|
||||
if max_decode_len_lp % 2 != 0:
|
||||
max_decode_len_lp = max_decode_len_lp - 1
|
||||
print("bs: {}, seq_len: {}, q_heads: {}, kv_heads: {}, head_size: {}, rope_dim: {}, "
|
||||
"dtype: {}, mixed_cache: {}, quant_kv: {}, paged_cache_hp: {}, paged_cache_lp: {}, "
|
||||
"discrete_batch_hp: {}, discrete_batch_lp: {}."
|
||||
.format(bs, seq_len, q_heads, kv_heads, head_size, rope_dim, dtype, mixed_cache,
|
||||
need_quant_kv, paged_cache_hp, paged_cache_lp, discrete_batch_hp, discrete_batch_lp))
|
||||
|
||||
if mixed_cache:
|
||||
print("max_decode_len_hp: {}, max_decode_len_lp: {}, num_blocks_hp: {}, num_blocks_lp: {}, "
|
||||
"block_size_hp: {}, block_size_lp: {}, group_num: {}, testing..."
|
||||
.format(max_decode_len_hp, max_decode_len_lp, num_blocks_hp, num_blocks_lp,
|
||||
block_size_hp, block_size_lp, group_num))
|
||||
|
||||
max_bs_hp = bs + 1 if discrete_batch_hp else bs
|
||||
max_bs_lp = bs + 1 if discrete_batch_lp else bs
|
||||
cache_size = 1 if need_quant_kv else 2
|
||||
cache_bytes_hp = num_blocks_hp * kv_heads * block_size_hp * head_size if paged_cache_hp else \
|
||||
max_bs_hp * kv_heads * max_decode_len_hp * head_size
|
||||
cache_bytes_hp = cache_bytes_hp * cache_size
|
||||
cache_bytes_lp = 0
|
||||
if mixed_cache:
|
||||
cache_bytes_lp = num_blocks_lp * kv_heads * block_size_lp * head_size if paged_cache_lp else \
|
||||
max_bs_lp * kv_heads * max_decode_len_lp * head_size
|
||||
|
||||
if cache_bytes_hp > 2**31 or cache_bytes_lp > 2**31:
|
||||
print("cache bytes can not be larger than int32max. ")
|
||||
continue
|
||||
|
||||
input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size)
|
||||
input = torch.randn(size=input_shape, dtype=dtype).mlu()
|
||||
|
||||
if paged_cache_hp:
|
||||
cache_shape_hp = (num_blocks_hp, kv_heads, block_size_hp, head_size)
|
||||
else:
|
||||
cache_shape_hp = (max_bs_hp, kv_heads, max_decode_len_hp, head_size)
|
||||
if need_quant_kv:
|
||||
k_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu')
|
||||
k_cache_hp = (k_cache_hp - 0.5) * 256
|
||||
k_cache_hp = k_cache_hp.to(torch.int8)
|
||||
v_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu')
|
||||
v_cache_hp = (k_cache_hp - 0.5) * 256
|
||||
v_cache_hp = k_cache_hp.to(torch.int8)
|
||||
k_scale_ops_hp = 1 / (torch.randn(size=(kv_heads, head_size), dtype=torch.float).abs().mlu() + 0.01)
|
||||
v_scale_ops_hp = 1 / (torch.randn(size=(kv_heads, head_size), dtype=torch.float).abs().mlu() + 0.01)
|
||||
else:
|
||||
k_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu')
|
||||
v_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu')
|
||||
k_scale_ops_hp = None
|
||||
v_scale_ops_hp = None
|
||||
|
||||
k_cache_lp = None
|
||||
v_cache_lp = None
|
||||
k_scale_lp = None
|
||||
v_scale_lp = None
|
||||
if mixed_cache:
|
||||
if paged_cache_lp:
|
||||
k_scale_lp = torch.randn(size=(num_blocks_lp, kv_heads, block_size_lp, group_num), dtype=torch.float).mlu()
|
||||
v_scale_lp = torch.randn(size=(num_blocks_lp, kv_heads, block_size_lp, group_num), dtype=torch.float).mlu()
|
||||
cache_raw = torch.randn((num_blocks_lp * kv_heads * block_size_lp, head_size), dtype=dtype, device='mlu')
|
||||
max_value = torch.amax(torch.abs(cache_raw))
|
||||
cache_raw = cache_raw * (7 / max_value)
|
||||
cache_raw = cache_raw.to(torch.int8)
|
||||
k_cache_lp = PairlyPackInt8(cache_raw).reshape(num_blocks_lp, kv_heads, block_size_lp, int(head_size / 2))
|
||||
cache_raw = torch.randn((int(num_blocks_lp * kv_heads * block_size_lp / 2), head_size * 2), dtype=dtype, device='mlu')
|
||||
max_value = torch.amax(torch.abs(cache_raw))
|
||||
cache_raw = cache_raw * (7 / max_value)
|
||||
cache_raw = cache_raw.to(torch.int8)
|
||||
v_cache_lp = PairlyPackInt8(cache_raw).reshape(num_blocks_lp, kv_heads, int(block_size_lp / 2), head_size)
|
||||
k_cache_lp_ref_shape = (num_blocks_lp, kv_heads, block_size_lp, head_size)
|
||||
v_cache_lp_ref_shape = (num_blocks_lp, kv_heads, int(block_size_lp / 2), head_size, 2)
|
||||
else:
|
||||
k_scale_lp = torch.randn(size=(max_bs_lp, kv_heads, max_decode_len_lp, group_num), dtype=torch.float).mlu()
|
||||
v_scale_lp = torch.randn(size=(max_bs_lp, kv_heads, max_decode_len_lp, group_num), dtype=torch.float).mlu()
|
||||
cache_raw = torch.randn((max_bs_lp * kv_heads * max_decode_len_lp, head_size), dtype=dtype, device='mlu')
|
||||
max_value = torch.amax(torch.abs(cache_raw))
|
||||
cache_raw = cache_raw * (7 / max_value)
|
||||
cache_raw = cache_raw.to(torch.int8)
|
||||
k_cache_lp = PairlyPackInt8(cache_raw).reshape(max_bs_lp, kv_heads, max_decode_len_lp, int(head_size / 2))
|
||||
cache_raw = torch.randn((int(max_bs_lp * kv_heads * max_decode_len_lp / 2), head_size * 2), dtype=dtype, device='mlu')
|
||||
max_value = torch.amax(torch.abs(cache_raw))
|
||||
cache_raw = cache_raw * (7 / max_value)
|
||||
cache_raw = cache_raw.to(torch.int8)
|
||||
v_cache_lp = PairlyPackInt8(cache_raw).reshape(max_bs_lp,kv_heads, int(max_decode_len_lp / 2), head_size)
|
||||
k_cache_lp_ref_shape = (max_bs_lp, kv_heads, max_decode_len_lp, head_size)
|
||||
v_cache_lp_ref_shape = (max_bs_lp, kv_heads, int(max_decode_len_lp / 2), head_size, 2)
|
||||
del cache_raw
|
||||
|
||||
cache_bs_id_hp = None
|
||||
cache_bs_id_lp = None
|
||||
cache_seq_offsets_hp = None
|
||||
cache_seq_offsets_lp = None
|
||||
slot_mapping_hp = None
|
||||
slot_mapping_lp = None
|
||||
if not paged_cache_hp:
|
||||
if discrete_batch_hp:
|
||||
cache_bs_id_hp = random.sample([*range(0, max_bs_hp)], bs)
|
||||
cache_bs_id_hp = torch.IntTensor(cache_bs_id_hp).mlu()
|
||||
cache_seq_offsets_hp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_hp - 2,
|
||||
dtype=torch.int32, device='mlu')
|
||||
else:
|
||||
slot_mapping_hp = random.sample([*range(-1, block_size_hp * num_blocks_hp)], bs)
|
||||
slot_mapping_hp = torch.IntTensor(slot_mapping_hp).mlu()
|
||||
|
||||
input_ref = input.clone()
|
||||
k_cache_hp_ref = k_cache_hp.clone()
|
||||
v_cache_hp_ref = v_cache_hp.clone()
|
||||
k_cache_lp_ref = None
|
||||
v_cache_lp_ref = None
|
||||
k_scale_lp_ref = None
|
||||
v_scale_lp_ref = None
|
||||
if mixed_cache:
|
||||
if not paged_cache_lp:
|
||||
if discrete_batch_lp:
|
||||
cache_bs_id_lp = random.sample([*range(0, max_bs_lp)], bs)
|
||||
cache_bs_id_lp = torch.IntTensor(cache_bs_id_lp).mlu()
|
||||
cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_lp - 2,
|
||||
dtype=torch.int32, device='mlu')
|
||||
else:
|
||||
slot_mapping_lp = genSlotMapping(bs, block_size_lp)
|
||||
slot_mapping_lp = torch.IntTensor(slot_mapping_lp).mlu()
|
||||
k_cache_lp_ref = k_cache_lp.clone()
|
||||
v_cache_lp_ref = v_cache_lp.clone()
|
||||
k_scale_lp_ref = k_scale_lp.clone()
|
||||
v_scale_lp_ref = v_scale_lp.clone()
|
||||
|
||||
cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
|
||||
sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
|
||||
gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
|
||||
beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()
|
||||
position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu')
|
||||
|
||||
notify_start = torch.mlu.Event(enable_timing=True)
|
||||
notify_end = torch.mlu.Event(enable_timing=True)
|
||||
notify_start.record()
|
||||
base_output = self.op_impl_base(input_ref, k_cache_hp_ref, v_cache_hp_ref, sin_table, \
|
||||
cos_table, position_id, gamma, beta, \
|
||||
k_cache_lp_ref, v_cache_lp_ref, cache_bs_id_hp, cache_seq_offsets_hp, \
|
||||
cache_bs_id_lp, cache_seq_offsets_lp, k_scale_ops_hp, v_scale_ops_hp, \
|
||||
k_scale_lp_ref, v_scale_lp_ref, slot_mapping_hp, slot_mapping_lp, eps)
|
||||
notify_end.record()
|
||||
notify_end.synchronize()
|
||||
time = notify_start.hardware_time(notify_end)
|
||||
print("baseline hw_time is: ", time, "us")
|
||||
del input_ref, k_cache_hp_ref, v_cache_hp_ref, k_cache_lp_ref, v_cache_lp_ref, k_scale_lp_ref, v_scale_lp_ref
|
||||
|
||||
notify_start.record()
|
||||
loop = 1
|
||||
for _ in range(loop):
|
||||
ops.fused_rope(input, k_cache_hp, v_cache_hp, sin_table, cos_table, position_id, \
|
||||
gamma, beta, k_cache_lp, v_cache_lp, cache_bs_id_hp, cache_seq_offsets_hp, \
|
||||
cache_bs_id_lp, cache_seq_offsets_lp, k_scale_ops_hp, v_scale_ops_hp, \
|
||||
k_scale_lp, v_scale_lp, slot_mapping_hp, slot_mapping_lp, eps)
|
||||
|
||||
notify_end.record()
|
||||
notify_end.synchronize()
|
||||
time = notify_start.hardware_time(notify_end) / loop
|
||||
print("hw time is: ", time, "us")
|
||||
|
||||
print("check input diff \n")
|
||||
self.assertTensorsEqual(input.cpu().float(), base_output[0].cpu().float(), 0.003, use_MSE=True)
|
||||
print("pass \n")
|
||||
print("check key cache hp diff \n")
|
||||
self.assertTensorsEqual(k_cache_hp.cpu().float(), base_output[1].cpu().float(), 0.003, use_MSE=True)
|
||||
print("pass \n")
|
||||
print("check value cache hp diff \n")
|
||||
self.assertTensorsEqual(v_cache_hp.cpu().float(), base_output[2].cpu().float(), 0.003, use_MSE=True)
|
||||
print("pass \n")
|
||||
if mixed_cache:
|
||||
k_cache_lp_int8 = UnpackInt4(k_cache_lp).reshape(k_cache_lp_ref_shape)
|
||||
v_cache_lp_int8 = UnpackInt4(v_cache_lp).reshape(v_cache_lp_ref_shape)
|
||||
k_cache_lp_ref_int8 = UnpackInt4(base_output[3]).reshape(k_cache_lp_ref_shape)
|
||||
v_cache_lp_ref_int8 = UnpackInt4(base_output[4]).reshape(v_cache_lp_ref_shape)
|
||||
|
||||
print("check key cache lp diff \n")
|
||||
self.assertTensorsEqual(k_cache_lp_int8.cpu().float(), k_cache_lp_ref_int8.cpu().float(), 1)
|
||||
print("pass \n")
|
||||
print("check key scale lp diff \n")
|
||||
self.assertTensorsEqual(k_scale_lp.cpu().float(), base_output[5].cpu().float(), 0.003, use_MSE=True)
|
||||
print("pass \n")
|
||||
print("check value cache lp diff \n")
|
||||
self.assertTensorsEqual(v_cache_lp_int8.cpu().float(), v_cache_lp_ref_int8.cpu().float(), 1)
|
||||
print("pass \n")
|
||||
print("check value scale lp diff \n")
|
||||
self.assertTensorsEqual(v_scale_lp.cpu().float(), base_output[6].cpu().float(), 0.003, use_MSE=True)
|
||||
print("pass \n")
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "fused_rope not support MLU3XX device")
def test_inductor(self):
    """Opcheck (inductor-compatibility) smoke test for the fused_rope custom op.

    Builds one bf16 configuration with a contiguous (non-paged) high-precision
    KV cache addressed via discrete batch ids, then runs torch.library opcheck
    through ``self.base_opcheck``.
    """
    # Single fixed shape: decode step (seq_len == 1), GQA with 8 query heads
    # sharing 1 KV head; rope_dim == head_size so the full head is rotated.
    bs, seq_len, q_heads, kv_heads, head_size, rope_dim, max_decode_len, dtype= 40, 1, 8, 1, 128, 128, 2048, torch.bfloat16
    # Packed QKV layout: q_heads + 2 * kv_heads along the head axis.
    input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size)
    input = torch.randn(size=input_shape, dtype=dtype).mlu()
    # Cache holds one more batch slot than bs so cache_bs_id is a strict
    # subset of the available slots.
    max_bs = bs + 1
    cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
    sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
    # gamma/beta: per-channel norm weights consumed by the fused op.
    gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
    beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()

    # One allocation split into K and V views (index 0 / index 1).
    cache = torch.randn((2, max_bs, kv_heads, max_decode_len, head_size), dtype=dtype, device='mlu')
    k_cache = cache[0]
    v_cache = cache[1]
    # Distinct random batch slots for each of the bs requests.
    cache_bs_id = random.sample([*range(0, max_bs)], bs)
    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
    # low=-1: -1 presumably marks "no cache write" for that request —
    # TODO confirm against the op's documented sentinel.
    cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2,
                                      dtype=torch.int32, device='mlu')
    position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu')
    # None entries disable the low-precision (mixed) cache, scales and
    # slot_mapping paths; 1e-5 is the norm epsilon.
    args = (input, k_cache, v_cache, None, None, sin_table, cos_table, position_id, gamma, beta, None,
            None, None, None, cache_bs_id, cache_seq_offsets, None, None, None, None, 1e-5)
    self.base_opcheck(torch.ops.torch_mlu_ops.fused_rope, args)
|
||||
|
||||
# Script entry point: run the fused_rope test class via the shared
# common_utils runner and propagate its status as the exit code.
if __name__ == "__main__":
    exit(run_unittest(TestFusedRopeOp))
|
||||
122
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_group_gemm.py
Executable file
122
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_group_gemm.py
Executable file
@@ -0,0 +1,122 @@
|
||||
import torch
|
||||
from torch_mlu import mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
from torch.nn import functional as F
|
||||
|
||||
def gen_params(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias=False):
    """Generate random inputs for a group_gemm (MoE expert GEMM) test case.

    Builds a routing of ``batch * seq * topk`` token slots across
    ``experts_num`` experts and the matching operand tensors on the MLU
    device.

    Args:
        batch, seq: token grid; ``bs = batch * seq`` input rows.
        k, n: GEMM reduce / output dimensions.
        experts_num: number of expert weight matrices.
        topk: experts selected per token.
        data_type: dtype for a/b/c/bias.
        idx_mode: if True, return the gather index and leave ``a`` un-expanded
            (the op gathers internally); if False, pre-expand ``a`` here and
            return None for the index.
        has_bias: whether to generate a per-expert bias.

    Returns:
        (a, b, token_count, gather_idx_or_None, c, alpha, beta,
         a_scale=None, b_scale=None, data_type, bias_or_None)
    """
    bs = batch * seq
    token_topk = bs * topk
    # Random expert assignment for every (token, topk-slot) pair.
    expert_id = torch.randint(experts_num, (token_topk,), device="mlu")
    # Sort by expert so rows for the same expert are contiguous.
    sorted_expert_id, indices = expert_id.sort()
    # indices // topk maps each sorted slot back to its source token row.
    gather_idx = indices // topk
    gather_idx = gather_idx.to(torch.int32)
    # Rows owned by each expert (m_list for the grouped GEMM).
    token_count = torch.bincount(sorted_expert_id, minlength=experts_num).to(torch.int32)

    a = torch.randn(bs, k, device="mlu", dtype=data_type)
    if not idx_mode:
        # Pre-expanded mode: materialize the per-slot input rows up front.
        a = a[gather_idx]
    # b is (experts, n, k): each expert GEMM multiplies by b[i]^T.
    b = torch.randn(experts_num, n, k, device="mlu", dtype=data_type)
    c = torch.randn(token_topk, n, device="mlu", dtype=data_type)
    alpha = torch.randn(experts_num, device="mlu", dtype=torch.float32)
    beta = torch.randn(experts_num, device="mlu", dtype=torch.float32)
    # Quant scales unused in this high-precision path.
    a_scale = None
    b_scale = None
    bias = None
    if has_bias:
        bias = torch.randn(experts_num, n, device="mlu", dtype=data_type)

    gather_idx_ = gather_idx if idx_mode else None
    return a, b, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, data_type, bias
|
||||
|
||||
|
||||
class TestGroupGemmOp(BtTestCase):
    """Tests for ops.group_gemm against a pure-PyTorch per-expert reference."""

    def run_gen_case(self, dic):
        """Run one externally generated case described by a parameter dict.

        When ``dump_data`` is set the dict already holds live tensors in
        launch-argument order; otherwise each entry is a serialized tensor
        description reconstructed via ``create_tensor_from_dic``.
        """
        dump_data = dic.pop('dump_data')
        if dump_data:
            # Relies on dict insertion order matching launch()'s signature.
            self.launch(*dic.values())
        else:
            a = create_tensor_from_dic(dic['a'])
            b = create_tensor_from_dic(dic['b'])
            # m_list / expand_idx / max_m are plain Python data, not tensors.
            m_list = dic['m_list']['data']
            expand_idx = dic['expand_idx']['data']
            c = None if dic['c']['data'] is None else create_tensor_from_dic(dic['c'])
            alpha = None if dic['alpha']['data'] is None else create_tensor_from_dic(dic['alpha'])
            beta = None if dic['beta']['data'] is None else create_tensor_from_dic(dic['beta'])
            max_m = dic['max_m']['data']
            bias = None if dic['bias']['data'] is None else create_tensor_from_dic(dic['bias'])
            self.launch(a, b, m_list, expand_idx, c, alpha, beta, max_m, bias)

    def launch(self, *args):
        """Compare ops.group_gemm against the reference on the valid rows.

        Only the first ``total_m`` rows are compared: rows past the sum of
        m_list are never written by either implementation.
        """
        total_m = args[2].sum().item()  # args[2] is m_list
        torch_out = self.op_impl_base(*args)
        tmo_out = ops.group_gemm(*args)
        self.assertTensorsEqual(tmo_out.cpu().float()[0:total_m], torch_out.cpu().float()[0:total_m], 0.006, use_MSE=True)

    def op_impl_base(self, *args):
        """Pure-PyTorch reference: per-expert ``alpha*(a@b^T) + bias + beta*c``.

        ``m_list`` gives each expert's row count; ``expand_idx`` (optional)
        gathers input rows first. Rows past ``sum(m_list)`` in the returned
        tensor are uninitialized by design (callers compare only [:total_m],
        or total rows equal total_m).
        """
        a, b, m_list, expand_idx, c, alpha, beta, max_m, bias = args
        a = a.reshape(-1, a.size(-1))
        if expand_idx is not None:
            a = a[expand_idx]
        total_m = m_list.sum().item()
        # Split the packed rows into one chunk per expert.
        a_list = a[:total_m].split(tuple(m_list))

        c_list = []
        if c is not None:
            c = c.reshape(-1, c.size(-1))
            c_list = c[:total_m].split(tuple(m_list))

        output_list = []
        for i in range(b.size(0)): # alpha*(a*b) + bias + beta*c
            if (a_list[i].size(0) > 0):
                # b[i] is (n, k); permute to (k, n) for row-major matmul.
                gemm_out = torch.matmul(a_list[i], b[i].permute(1, 0))
                if alpha is not None:
                    gemm_out *= alpha[i]
                if bias is not None:
                    gemm_out += bias[i]
                if beta is not None and c_list != []:
                    gemm_out += c_list[i] * beta[i]
                output_list.append(gemm_out)
        real_res = torch.cat(output_list, dim=0)
        # Allocate the full output and fill only the valid prefix.
        output = torch.empty(a.shape[0], b.shape[1], device=real_res.device).to(real_res.dtype)
        output[:total_m] = real_res
        return output

    def test_group_gemm(self):
        """Sweep shapes/dtypes/routing modes and compare against the reference."""
        bs_list = [1, 3]
        seq_list = [5, 8]
        k_list = [512, 1024]
        n_list = [512, 768, 2048]
        expert_list = [8, 32]
        topk_list = [2, 5]
        dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
        # idx (internal gather) mode is skipped on MLU370 hardware.
        idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True]
        has_bias_list = [True, False]

        args = product(bs_list, seq_list, k_list, n_list, expert_list, topk_list, dtype_list, idx_list, has_bias_list)
        for batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias in args:
            print(f"bs: {batch}, seq_len: {seq}, k: {k}, n: {n}, experts_num: {experts_num}, topk: {topk}, \
dtype: {data_type}, idx_mode: {idx_mode}, has_bias: {has_bias} testing...", flush=True)
            param = gen_params(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias)
            # param[:7] = a, b, token_count, gather_idx, c, alpha, beta;
            # max_m = batch * seq; param[10] = bias (scales/dtype skipped).
            torch_out = self.op_impl_base(*param[:7], batch * seq, param[10])
            tmo_out = ops.group_gemm(*param[:7], batch * seq, param[10])
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)

    def test_inductor(self):
        """Opcheck coverage for the group_gemm custom op registration."""
        batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5
        dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
        idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True]
        has_bias_list = [True, False]
        args = product( dtype_list, idx_list, has_bias_list)
        for data_type, idx_mode, has_bias in args:
            # NOTE(review): `args` is rebound inside the loop it drives; this
            # works (the for-loop holds the iterator), but a distinct name
            # would be clearer.
            args = gen_params(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias)
            args = list(args)
            # Rearrange gen_params output into the raw torch op's argument
            # order: bias replaces the dtype slot, dtype becomes None.
            args[-2] = args[-1] # bias
            args[-1] = None #dtype
            args.extend([None, None, batch * seq]) #b_offset, max_m
            self.base_opcheck(torch.ops.torch_mlu_ops.group_gemm, args)
|
||||
|
||||
# Script entry point: run the group_gemm test class via the shared
# common_utils runner and propagate its status as the exit code.
if __name__ == '__main__':
    exit(run_unittest(TestGroupGemmOp))
|
||||
193
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_matmul.py
Executable file
193
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_matmul.py
Executable file
@@ -0,0 +1,193 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
import numpy as np
|
||||
|
||||
class TestMatMulOp(BtTestCase):
    """Tests for ops.matmul (fused GEMM + bias + residual + activation),
    in both floating-point and int8 smooth-quant modes."""

    def op_impl_base(self, *args):
        """Reference: ``act(alpha * (a @ b) + bias + beta * c)``.

        ``d`` and the fast_act/approximate flags are accepted for signature
        parity with ops.matmul but unused here. When scales are given the
        operands are divided by them before the matmul — assumes QuantByTensor
        returns a reciprocal-style scale (quant = value * scale); TODO confirm
        against common_utils.
        """
        a, b, bias, c, act_mode, alpha, beta, fast_act, approximate, d, \
        a_scale, b_scale, trans_a, trans_b = args
        if a_scale is not None:
            a = a / a_scale
        if b_scale is not None:
            b = b / b_scale
        if trans_a:
            a = a.transpose(0, 1)
        if trans_b:
            b = b.transpose(0, 1)
        mul_out = alpha * torch.matmul(a, b)
        if bias is not None:
            mul_out += bias
        if c is not None:
            mul_out += beta * c
        # act_mode_dict comes from common_utils; 'none' simply falls through.
        if act_mode in act_mode_dict.keys():
            active = act_mode_dict[act_mode]
            # Activation in fp32 for accuracy, then cast back.
            mul_out = active(mul_out.float()).to(a.dtype)
        return mul_out

    def test_matmul(self):
        """Floating-point sweep, including strided (non-contiguous) operands.

        Inputs are carved from a larger tensor (``input0[:, 1, :]``) so the
        op is exercised on non-contiguous memory as well as .contiguous()
        copies of the same data.
        """
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [128]
        has_res_list = [ False, True]
        has_bias_list = [True, False]
        trans_a_list = [False, True]
        trans_b_list = [False, True]
        act_mode_list = ['none', 'relu', 'gelu', 'silu']
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        alpha = 0.625
        beta = 1.0
        args = product( mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, act_mode_list, dtype_list, trans_a_list, trans_b_list)
        for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b in args:
            # Fixed seed per iteration for reproducible operands.
            torch.manual_seed(1)
            print("m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b), flush=True)
            # beta = 0 disables the residual term when has_res is False.
            if has_res :
                beta = 1.0
            else :
                beta = 0.

            # Middle dim of 4/3 exists only to make the selected slices strided.
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            input = input0[:, 1, :]
            weight = weight0[:, 0, :]
            bias = torch.randn((mat_n), dtype=dtype, device='mlu')
            residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')

            # Bias is pre-scaled by alpha so the reference formula
            # alpha*(a@b) + (alpha*bias) matches the op's convention.
            output = self.op_impl_base(input,
                                       weight,
                                       alpha * bias if has_bias else None,
                                       residual if has_res else None,
                                       act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b)
            tmo_output = ops.matmul(input,
                                    weight,
                                    alpha * bias if has_bias else None,
                                    residual if has_res else None,
                                    act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b)
            tmo_output_contiguous = ops.matmul(input.contiguous(), weight.contiguous(),
                                               alpha * bias if has_bias else None,
                                               residual if has_res else None,
                                               act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b)

            if act_mode == 'gelu':
                # Extra run with approximate=True (high-precision gelu variant).
                tmo_output_high = ops.matmul(input.contiguous(), weight.contiguous(),
                                             alpha * bias if has_bias else None,
                                             residual if has_res else None,
                                             act_mode, alpha, beta, False, True, None, 1.0, 1.0, trans_a, trans_b)
                self.assertTensorsEqual(tmo_output_high.cpu().float(), output.cpu().float(),
                                        0.004, use_MSE=True, use_RAE=True)

            self.assertTensorsEqual(tmo_output_contiguous.cpu().float(), output.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)

    # @unittest.skip("not test")
    def test_matmul_int8(self):
        """Int8 smooth-quant sweep: operands quantized per tensor, result
        dequantized by the op into ``dtype``."""
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [128]
        has_res_list = [True, False]
        has_bias_list = [True, False]
        trans_a_list = [True, False]
        trans_b_list = [True, False]
        act_mode_list = ['none', 'relu', 'silu', 'gelu']
        dtype_list = [torch.half, torch.float]
        alpha = 0.625
        beta = 1.0
        args = product( mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, act_mode_list, dtype_list, trans_a_list, trans_b_list)

        for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b in args:
            print("int8 test: m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b), flush=True)
            torch.manual_seed(1)
            if has_res :
                beta = 1.0
            else :
                beta = 0.
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            input = input0[:, 1, :]
            weight = weight0[:, 0, :]
            # QuantByTensor (common_utils): 8-bit per-tensor quantization,
            # returning (quantized_tensor, scale).
            input8, a_scale = QuantByTensor(input, 8)
            weight8, b_scale = QuantByTensor(weight, 8)
            bias = torch.randn((mat_n), dtype=dtype, device='mlu')
            residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')

            output = self.op_impl_base(input8, weight8,
                                       alpha * bias if has_bias else None,
                                       residual if has_res else None,
                                       act_mode, alpha, beta, False, False, None, a_scale, b_scale, trans_a, trans_b)
            # The op additionally takes the output dtype for dequantization.
            tmo_output = ops.matmul(input8, weight8,
                                    alpha * bias if has_bias else None,
                                    residual if has_res else None,
                                    act_mode, alpha, beta, False, False, dtype, a_scale, b_scale, trans_a, trans_b)
            tmo_output_contiguous = ops.matmul(input8.contiguous(), weight8.contiguous(),
                                               alpha * bias if has_bias else None,
                                               residual if has_res else None,
                                               act_mode, alpha, beta, False, False, dtype, a_scale, b_scale, trans_a, trans_b)

            self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(tmo_output_contiguous.cpu().float(), output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)

            if act_mode == 'gelu':
                # approximate=True gelu variant, int8 path.
                tmo_output_high = ops.matmul(input8.contiguous(), weight8.contiguous(),
                                             alpha * bias if has_bias else None,
                                             residual if has_res else None,
                                             act_mode, alpha, beta, False, True, dtype, a_scale, b_scale, trans_a, trans_b)
                self.assertTensorsEqual(tmo_output_high.cpu().float(), output.cpu().float(),
                                        0.003, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Opcheck coverage for the raw torch_mlu_ops.matmul op in both the
        floating-point and int8 (string-dtype) argument layouts."""
        mat_m, mat_n, mat_k, alpha, beta, act_mode, fast_act, approximate = 32, 256, 128, 0.8, 0.3, 'silu', True, True
        trans_a_list = [True, False]
        trans_b_list = [True, False]
        dtype_list = [torch.half, torch.float]
        args = product(trans_a_list, trans_b_list, dtype_list)
        for trans_a, trans_b, dtype in args:
            print("trans_a: {}, trans_b: {}, dtype: {}".format(trans_a, trans_b, dtype))
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            a = input0[:, 1, :]
            b = weight0[:, 0, :]
            bias = torch.randn((mat_n), dtype=dtype, device='mlu')
            c = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            # Raw-op argument order differs from ops.matmul's wrapper.
            args = (a, b, None, bias, c, None, act_mode, alpha, beta, fast_act, approximate, 1.0, 1.0, trans_a, trans_b)
            self.base_opcheck(torch.ops.torch_mlu_ops.matmul, args)

            a8, a_scale = QuantByTensor(a, 8)
            b8, b_scale = QuantByTensor(b, 8)
            # The raw op takes the output dtype as a string in int8 mode.
            str_dtype = "half"
            if dtype == torch.float:
                str_dtype = "float"
            elif dtype == torch.bfloat16:
                str_dtype = "bfloat16"
            args = (a8, b8, None, bias, c, str_dtype, act_mode, alpha, beta, fast_act, approximate, a_scale, b_scale, trans_a, trans_b)
            self.base_opcheck(torch.ops.torch_mlu_ops.matmul, args)
|
||||
|
||||
|
||||
|
||||
# Script entry point: run the matmul test class via the shared
# common_utils runner and propagate its status as the exit code.
if __name__ == '__main__':
    exit(run_unittest(TestMatMulOp))
|
||||
946
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe.py
Executable file
946
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe.py
Executable file
@@ -0,0 +1,946 @@
|
||||
import torch
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from typing import Union, List, Tuple
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import functional as F
|
||||
from torch_mlu import mlu
|
||||
|
||||
class TestFusedMOEOp(BtTestCase):
|
||||
def run_gen_case(self, dic):
    """Run one externally generated fused_moe case from a parameter dict.

    When ``dump_data`` is set the dict holds live tensors in launch-argument
    order; otherwise entries are serialized descriptions rebuilt with
    ``create_tensor_from_dic`` (smooth/scale tensors get bounded uniform
    values to keep the quantized reference numerically stable).
    """
    dump_data = dic.pop('dump_data')
    if dump_data:
        # Relies on dict insertion order matching launch()'s signature.
        self.launch(*dic.values())
    else:
        hidden_states = create_tensor_from_dic(dic['hidden_states'])
        gating_output = create_tensor_from_dic(dic['gating_output'])
        w1 = create_tensor_from_dic(dic['w1'])
        w2 = create_tensor_from_dic(dic['w2'])
        # Shrink full-precision weights to keep activations in range;
        # int8 weights are already bounded.
        if dic['w1']['dtype'] is not torch.int8:
            w1 *= 0.1
            w2 *= 0.1
        bias1 = None if dic['bias1']['data'] is None else create_tensor_from_dic(dic['bias1'])
        bias2 = None if dic['bias2']['data'] is None else create_tensor_from_dic(dic['bias2'])
        residual = None if dic['residual']['data'] is None else create_tensor_from_dic(dic['residual'])
        input_smooth = None if dic['input_smooth']['data'] is None else create_tensor_from_dic(dic['input_smooth'], is_uniform=True, low=0.01, high=0.05)
        act_smooth = None if dic['act_smooth']['data'] is None else create_tensor_from_dic(dic['act_smooth'], is_uniform=True, low=0.01, high=0.05)
        w1_scale = None if dic['w1_scale']['data'] is None else create_tensor_from_dic(dic['w1_scale'], is_uniform=True, low=-0.05, high=0.05)
        w2_scale = None if dic['w2_scale']['data'] is None else create_tensor_from_dic(dic['w2_scale'], is_uniform=True, low=-0.05, high=0.05)
        # Scalar / flag parameters.
        topk = dic['topk']['data']
        renormalize = dic['renormalize']['data']
        gated = dic['gated']['data']
        act_mode = dic['act_mode']['data']
        start_expert_id = dic['start_expert_id']['data']
        block_n = dic['block_n']['data']
        # NOTE(review): cncl_comm is read from the dict but launch() is
        # passed a literal 0 below — confirm whether that is intentional.
        cncl_comm = dic['cncl_comm']['data']
        w1_quant_flag = dic['w1_quant_flag']['data']
        w2_quant_flag = dic['w2_quant_flag']['data']
        self.launch(hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth,
                    act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode,
                    start_expert_id, block_n, 0, w1_quant_flag, w2_quant_flag)
|
||||
|
||||
def launch(self, *args):
    """Run ops.fused_moe and compare it against the PyTorch reference.

    Args are forwarded unchanged to both implementations; see
    ``op_impl_base`` for the expected argument order.
    """
    # Reference (pure PyTorch) result.
    base = self.op_impl_base(*args)
    # BUGFIX: the library is imported as `ops` (import torch_mlu_ops as ops);
    # the previous `tmo.fused_moe(*args)` raised NameError because `tmo`
    # is never defined in this file.
    tmo_res = ops.fused_moe(*args)
    # MoE accumulates many GEMMs/activations, hence the looser 3% tolerance.
    self.assertTensorsEqual(tmo_res.cpu().float(), base.cpu().float(), 0.03, use_MSE=True)
|
||||
|
||||
def op_impl_base(self, *args):
    """Pure-PyTorch reference for fused_moe.

    Routes tokens to top-k experts, optionally smooth-quantizes the inputs
    and intermediate activations, runs per-expert GEMM1 -> activation ->
    GEMM2, and combines the expert outputs weighted by the router logits.
    Supports expert-parallel slices ([start_expert_id, start_expert_id +
    expert_size)) and w4/w8 mixed weight quantization via *_quant_flag.

    NOTE(review): ``smooth_quant_matmul``, ``smooth_quant_matmul_w4w8_mixed``,
    ``QuantByRow`` and ``act_mode_dict`` are presumably provided by
    common_utils' star import — verify they resolve at runtime.
    """
    hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth, \
        act_smooth, w1_scale, w2_scale, topk, renormalize, \
        gated, act_mode, start_expert_id, block_n, cncl_comm, w1_quant_flag, w2_quant_flag = args
    if w2.dim() == 4:
        # Flatten a blocked w2 layout back to (experts, hidden, inner).
        w2 = w2.transpose(1, 2).reshape(-1, w2.size(1), w2.size(-1))
    expert_num = gating_output.size(-1)
    # Local (this-rank) expert count: from the weight tensor, or from the
    # scale tensor when weights are packed (quant_flag mode).
    expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1)

    def router(hidden_states, router_logit):
        """Softmax + top-k routing; returns per-expert token chunks,
        optional per-token quant scales, reduce weights and the index
        mapping packed rows back to (token, topk-slot) order."""
        router_logit = torch.softmax(router_logit.view(-1, router_logit.size(-1)), dim=1)
        topk_logit, expert_id = torch.topk(router_logit, k=topk, dim=1)
        if renormalize:
            # Normalize the selected logits so each token's weights sum to 1.
            topk_logit = topk_logit / topk_logit.sum(-1).unsqueeze(1)
        # Sort slots by expert so each expert's tokens are contiguous.
        sorted_expert_id, indices = expert_id.int().flatten().sort()
        nk = indices.size(0)
        token_cout = torch.bincount(sorted_expert_id, minlength=expert_num).cpu()
        # Map each sorted slot back to its source token row.
        expand_idx = indices.int() // topk
        # combine_idx[j] = position of flat slot j in the sorted packing.
        combine_idx = torch.zeros((nk,), dtype=torch.int, device="mlu")
        combine_idx.scatter_(0, indices, torch.arange(nk, dtype=torch.int, device="mlu"))
        input_expand = hidden_states[expand_idx]
        input_list = input_expand.split(tuple(token_cout))

        input_scale_list = []
        input_split_result = []
        if input_smooth is not None:
            # do smooth quant on input: scale local experts' rows by their
            # smooth vector, then row-quantize everything to int8.
            idx = 0  # index into input_smooth, one entry per local expert
            for i in range(expert_num):
                if i >= start_expert_id and i < start_expert_id + expert_size:
                    if (input_list[i].size(0) > 0):
                        temp = input_list[i] * input_smooth[idx]
                        input_split_result.append(temp)
                    idx += 1
                else:
                    input_split_result.append(input_list[i])
            quant_input, input_scale = QuantByRow(torch.cat(input_split_result, dim=0), 8)
            input_list = quant_input.split(token_cout.tolist())
            input_scale_list = input_scale.split(token_cout.tolist())
        return input_list, input_scale_list, topk_logit.flatten().view(nk , 1), combine_idx

    dtype = hidden_states.dtype
    # inner_size: GEMM1 output width per branch; gated layouts pack the
    # activation branch and the gate branch side by side (hence // 2).
    if w1_quant_flag is None:
        inner_size = w1.size(1) // 2 if gated else w1.size(1)
    else:
        inner_size = w1_scale.size(2) // 2 if gated else w1_scale.size(2)
    hidden_size = w2.size(1) if w2_quant_flag is None else w2_scale.size(2)
    input = hidden_states.view(-1, hidden_states.size(-1))
    gating_output = gating_output.view(-1, gating_output.size(-1))
    input_list, input_scale_list, reduce_weight, combine_idx = router(input, gating_output)
    output_list = []
    idx = 0  # index into the local expert weight slice
    need_quant = len(input_scale_list) != 0
    if need_quant and w1_scale.dim() == 3:
        # Bring scales to (expert, group, width) order.
        w1_scale = w1_scale.transpose(0, 1).contiguous()
        w2_scale = w2_scale.transpose(0, 1).contiguous()
    if w1_quant_flag is not None:
        # Mixed w4/w8 packing: build cumulative byte offsets into the packed
        # w1 buffer per expert. The // 4 * (quant_wise // 2) factor follows
        # the packer's layout — TODO confirm against the w4w8 pack format.
        w1_quant_group = w1_scale.size(1)
        quant_wise = hidden_size // w1_quant_group
        w1_quant_flag = torch.tensor(w1_quant_flag).view(-1, w1_quant_group)
        w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated)
        w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0)
    if w2_quant_flag is not None:
        w2_quant_group = w2_scale.size(1)
        quant_wise = inner_size // w2_quant_group
        w2_quant_flag = torch.tensor(w2_quant_flag).view(-1, w2_quant_group)
        w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size
        w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0)
    for i in range(expert_num): # start_expert_id : start_expert_id + expert_size
        if i >= start_expert_id and i < start_expert_id + expert_size:
            if (input_list[i].size(0) > 0):
                # GEMM1: int8 smooth-quant, packed w4/w8, or plain matmul.
                if need_quant:
                    if w1_quant_flag is None:
                        gemm1_out = smooth_quant_matmul(input_list[i], input_scale_list[i],
                                                        w1[idx], w1_scale[idx], dtype)
                    else:
                        gemm1_out = smooth_quant_matmul_w4w8_mixed(input_list[i], input_scale_list[i],
                                                                   w1[w1_offset_cu[idx]:w1_offset_cu[idx+1]], w1_scale[idx], dtype,
                                                                   quant_flag = w1_quant_flag[idx])
                else:
                    gemm1_out = torch.matmul(input_list[i], w1[idx].permute(1, 0))
                # Activation on the first half; second half is the gate.
                act_in = gemm1_out[:, :inner_size].float()
                gate = gemm1_out[:, inner_size:]
                act = act_mode_dict[act_mode]
                gemm1_out = act(act_in).to(dtype=dtype) if gated == False else \
                    act(act_in).to(dtype=dtype) * gate
                # GEMM2, re-quantizing the activation when in quant mode.
                if need_quant:
                    quant_gemm1_out, gemm1_out_scale = QuantByRow(gemm1_out * act_smooth[idx], 8)
                    if w2_quant_flag is None:
                        gemm2_out = smooth_quant_matmul(quant_gemm1_out, gemm1_out_scale, w2[idx], w2_scale[idx], dtype)
                    else:
                        gemm2_out = smooth_quant_matmul_w4w8_mixed(quant_gemm1_out, gemm1_out_scale,
                                                                   w2[w2_offset_cu[idx]:w2_offset_cu[idx+1]], w2_scale[idx], dtype,
                                                                   quant_flag = w2_quant_flag[idx])
                else:
                    gemm2_out = torch.matmul(gemm1_out, w2[idx].permute(1, 0))
                output_list.append(gemm2_out)
            idx += 1
        else:
            # Non-local experts contribute zeros so the packed output keeps
            # one row per routed slot for the combine gather below.
            output_list.append(torch.zeros_like(input_list[i]))
    # Un-permute to (token, topk) order, weight by router logits, reduce.
    output = torch.cat(output_list, dim=0)[combine_idx].float() * reduce_weight
    output = output.reshape(-1, topk, hidden_size).sum(dim=1).to(dtype=dtype)
    if residual is not None:
        output = output + residual.view(input.shape)
    return output.view(hidden_states.shape)
|
||||
|
||||
def test_fused_moe(self):
    """Unquantized fused_moe: compare ops.fused_moe against the torch
    reference (self.op_impl_base) and the module-level python helper
    fused_moe, for both (N, T, C) and flattened (N*T, C) inputs, over
    several (start_expert_id, expert_size) expert-parallel slices.
    """
    print("test_fused_moe")
    batch, seq, hidden_size, inner_size = 3, 5, 512, 768
    dtype_list = [torch.half]
    # bf16 is exercised only on hardware that supports it.
    if mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', dtype
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        # gated=True doubles the first projection's rows: [up | gate].
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
        bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
        bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)

        # Each (start_expert_id, expert_size) pair selects a contiguous
        # expert slice, emulating expert parallelism.
        start_expert_id_list = [0, 1, 3, 4, 5, 6]
        expert_size_list = [8, 4, 3, 2, 3, 2]
        for i in range(len(start_expert_id_list)):
            start_expert_id = start_expert_id_list[i]
            expert_size = expert_size_list[i]
            print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
start_expert_id: {start_expert_id}, expert_size: {expert_size}, topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
            # Pure-torch golden reference.
            torch_out = self.op_impl_base(hidden_states,
                                          router_logit,
                                          weight1[start_expert_id:start_expert_id+expert_size],
                                          weight2[start_expert_id:start_expert_id+expert_size],
                                          bias1, bias2, residual,
                                          None, None, None, None,
                                          topk, renormalize, gated, act_mode,
                                          start_expert_id,
                                          0, 0, None, None)
            # (N, T, C)
            tmo_out_1 = ops.fused_moe(hidden_states,
                                      router_logit,
                                      weight1[start_expert_id:start_expert_id+expert_size],
                                      weight2[start_expert_id:start_expert_id+expert_size],
                                      bias1, bias2, residual,
                                      None, None, None, None,
                                      topk, renormalize, gated, act_mode,
                                      start_expert_id)
            # (N*T, C)
            tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                                      router_logit.view(-1, expert_num),
                                      weight1[start_expert_id:start_expert_id+expert_size],
                                      weight2[start_expert_id:start_expert_id+expert_size],
                                      bias1, bias2,
                                      residual.view(-1, hidden_size) if residual is not None else None,
                                      None, None, None, None,
                                      topk, renormalize, gated, act_mode,
                                      start_expert_id).reshape(batch, seq, hidden_size)
            # Python reference implemented in this test file.
            tmo_out_3 = fused_moe(hidden_states,
                                  router_logit,
                                  weight1[start_expert_id:start_expert_id+expert_size],
                                  weight2[start_expert_id:start_expert_id+expert_size],
                                  bias1, bias2, residual,
                                  None, None, None, None,
                                  topk, renormalize, gated, act_mode,
                                  start_expert_id)

            self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
            self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
            self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_3.cpu().float(), 0.006, use_MSE=True)
|
||||
|
||||
def test_pertoken_quant_fused_moe_tp(self):
    """Run the per-token smooth-quant fused-MoE path once on a fixed,
    non-gated gelu configuration covering the full expert range.
    """
    print("test_pertoken_quant_fused_moe_tp")
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    expert_num, topk = 32, 5
    gated, renormalize, act_mode = False, True, 'gelu'
    # Fall back to fp16 when the device has no bf16 support.
    data_type = torch.bfloat16 if mlu.is_bf16_supported() else torch.float16
    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
    self.__run_sq_case(batch, seq, hidden_size, inner_size, expert_num,
                       topk, gated, renormalize, data_type, act_mode)
|
||||
|
||||
def test_moe_tp2_mixed_ep4_no_quant(self):
    """Unquantized fused_moe under tp=2 x ep=4 mixed parallelism: weights
    are sharded over tensor-parallel (inner dim) and expert-parallel
    (expert dim) ranks; partial shard outputs are accumulated and must
    match the similarly-sharded python helper and torch reference.
    """
    print("test_moe_tp2_mixed_ep4_no_quant")
    tp_num = 2
    ep_num = 4
    expert_num = 32
    expert_size = expert_num // ep_num
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16
    if not mlu.is_bf16_supported():
        data_type = torch.float16
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    residual = None
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
    # tp shards w1 along its output rows and w2 along its input columns.
    w1 = weight1.reshape(ep_num * expert_size, tp_num, (inner_size * (1 + gated)) // tp_num, hidden_size)
    w2 = weight2.reshape(ep_num * expert_size, hidden_size, tp_num, inner_size // tp_num)

    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
    tmo_out1 = torch.zeros_like(hidden_states)
    tmo_out2 = torch.zeros_like(hidden_states)
    torch_out = torch.zeros_like(hidden_states)
    for tp_idx in range(tp_num):
        w1_curr_tp = w1[:, tp_idx, ...]
        w2_curr_tp = w2[:, :, tp_idx, :]
        for ep_idx in range(ep_num):
            start_expert_id = ep_idx * expert_size
            # Pick this ep rank's contiguous expert slice of the tp shard.
            w1_curr_tp_and_ep = w1_curr_tp.reshape((ep_num, expert_size)+w1_curr_tp.shape[1:])[ep_idx].contiguous()
            w2_curr_tp_and_ep = w2_curr_tp.reshape((ep_num, expert_size)+w2_curr_tp.shape[1:])[ep_idx].contiguous()
            tmo_out1 += ops.fused_moe(hidden_states,
                                      router_logit,
                                      w1_curr_tp_and_ep,
                                      w2_curr_tp_and_ep,
                                      bias1, bias2, residual,
                                      None, None, None, None,
                                      topk, renormalize, gated, act_mode,
                                      start_expert_id)
            tmo_out2 += fused_moe(hidden_states,
                                  router_logit,
                                  w1_curr_tp_and_ep,
                                  w2_curr_tp_and_ep,
                                  bias1, bias2, residual,
                                  None, None, None, None,
                                  topk, renormalize, gated, act_mode,
                                  start_expert_id)
            torch_out += self.op_impl_base(hidden_states,
                                           router_logit,
                                           w1_curr_tp_and_ep,
                                           w2_curr_tp_and_ep,
                                           bias1, bias2, residual,
                                           None, None, None, None,
                                           topk, renormalize, gated, act_mode,
                                           start_expert_id,
                                           0, 0, None, None)
    tmo_out2 = tmo_out2.reshape(batch, seq, hidden_size)

    self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
    self.assertTensorsEqual(tmo_out1.cpu().float(), tmo_out2.cpu().float(), 0.006, use_MSE=True)
|
||||
|
||||
|
||||
def test_moe_tp2_mixed_ep4_quant(self):
    """Smooth-quant fused_moe under tp=2 x ep=4 mixed parallelism:
    smoothing factors are folded into the weights, weights are int8
    row-quantized, everything is sharded over tp/ep, and the accumulated
    shard outputs of ops.fused_moe, the python helper and the torch
    reference must agree.
    """
    print("test_moe_tp2_mixed_ep4_quant")
    tp_num = 2
    ep_num = 4
    expert_num = 32
    expert_size = expert_num // ep_num
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16
    if not mlu.is_bf16_supported():
        data_type = torch.float16
    scale_s = 0.1 # avoid the occurrence of inf
    eps = 0.01 # Avoid the occurrence of nan
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    residual = None
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)

    # Positive smoothing factors; they are divided out of the weights
    # below, so quantization error is what the kernel must reproduce.
    input_smooth = torch.randn(expert_num, hidden_size, device='mlu', dtype=torch.float32).abs() + eps
    act_smooth = torch.randn(expert_num, (1+gated)*inner_size, device='mlu', dtype=torch.float32).abs() + eps
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1_shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2_shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1.shape), quant_w2.view(weight2.shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    torch_out = torch.zeros_like(hidden_states)
    # NOTE(review): this outer loop repeats the whole check tp_num times
    # with identical results (all accumulators are re-zeroed inside and
    # the reshapes are idempotent); it looks like leftover scaffolding.
    # The exact nesting could not be recovered from the flattened source —
    # confirm against the upstream file.
    for _ in range(tp_num):
        w1_scale_tp = w1_scale.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num)
        act_smooth_tp = act_smooth.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num)

        # Re-lay everything out as (ep, expert, [tp, ...]) for sharding.
        w1 = quant_w1.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num, hidden_size)
        w2 = quant_w2.reshape(ep_num, expert_size, hidden_size, tp_num, inner_size*(1+gated)//tp_num)
        input_smooth = input_smooth.reshape(ep_num, expert_size, hidden_size)
        act_smooth = act_smooth.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num)
        w1_scale = w1_scale.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num)
        w2_scale = w2_scale.reshape(ep_num, expert_size, hidden_size)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
        tmo_out = torch.zeros_like(hidden_states)
        tmo_out1 = torch.zeros_like(hidden_states)
        torch_out = torch.zeros_like(hidden_states)
        for tp_idx in range(tp_num):
            w1_curr_tp = w1[:, :, tp_idx, ...]
            w2_curr_tp = w2[:, :, :, tp_idx, :]
            act_smooth_tp = act_smooth[:, :, tp_idx, :]
            w1_scale_tp = w1_scale[:, :, tp_idx, :]
            for ep_idx in range(ep_num):
                start_expert_id = ep_idx * expert_size
                w1_curr_tp_and_ep = w1_curr_tp[ep_idx].contiguous()
                w2_curr_tp_and_ep = w2_curr_tp[ep_idx].contiguous()
                input_smooth_curr_ep = input_smooth[ep_idx].contiguous()
                act_smooth_curr_ep = act_smooth_tp[ep_idx].contiguous()
                w1_scale_curr_ep = w1_scale_tp[ep_idx].contiguous()
                w2_scale_curr_ep = w2_scale[ep_idx].contiguous()
                tmo_out += ops.fused_moe(hidden_states, # [batch, seq, hidden_size]
                                         router_logit, # [batch, seq, expert_num]
                                         w1_curr_tp_and_ep, # [expert_size, inner_size*(1+gated)//tp_num, hidden_size]
                                         w2_curr_tp_and_ep, # [expert_size, hidden_size, inner_size*(1+gated)//tp_num]
                                         bias1, bias2, residual,
                                         input_smooth_curr_ep, # [expert_size, hidden_size]
                                         act_smooth_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num]
                                         w1_scale_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num]
                                         w2_scale_curr_ep, # [expert_size, hidden_size]
                                         topk, renormalize, gated, act_mode,
                                         start_expert_id)
                tmo_out1 += fused_moe(hidden_states,
                                      router_logit,
                                      w1_curr_tp_and_ep,
                                      w2_curr_tp_and_ep,
                                      bias1, bias2, residual,
                                      input_smooth_curr_ep,
                                      act_smooth_curr_ep,
                                      w1_scale_curr_ep,
                                      w2_scale_curr_ep,
                                      topk, renormalize, gated, act_mode,
                                      start_expert_id)
                torch_out += self.op_impl_base(hidden_states,
                                               router_logit,
                                               w1_curr_tp_and_ep,
                                               w2_curr_tp_and_ep,
                                               bias1, bias2, residual,
                                               input_smooth_curr_ep,
                                               act_smooth_curr_ep,
                                               w1_scale_curr_ep,
                                               w2_scale_curr_ep,
                                               topk, renormalize, gated, act_mode,
                                               start_expert_id,
                                               0, 0, None, None)
        self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True)
        self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True)
|
||||
|
||||
def test_smq_fused_moe_random_tp(self):
|
||||
print("test_smq_fused_moe_random_tp")
|
||||
import random
|
||||
random.seed(0)
|
||||
act_mode = 'gelu'
|
||||
case_list = set()
|
||||
for i in range(10):
|
||||
batch = random.randint(1, 10)
|
||||
seq = random.randint(1, 10)
|
||||
hidden_size = random.randrange(512, 2048, 2)
|
||||
inner_size = random.randrange(512, 2048, 2)
|
||||
expert_num = random.randint(1, 40)
|
||||
topk = random.randint(1,expert_num)
|
||||
gated = random.choice([True, False])
|
||||
renormalize = random.choice([True, False])
|
||||
data_type = random.choice([torch.bfloat16, torch.float16])
|
||||
if not mlu.is_bf16_supported():
|
||||
data_type = torch.float16
|
||||
case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode)
|
||||
if case in case_list:
|
||||
continue
|
||||
case_list.add(case)
|
||||
print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
|
||||
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
|
||||
self.__run_sq_case(*case)
|
||||
|
||||
def __run_sq_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode='gelu', start_expert_id=0, expert_size=-1):
    """Build one random smooth-quant MoE problem, run the torch reference
    and both fused_moe entry points ((N, T, C) and flattened (N*T, C)),
    and assert all outputs match.

    expert_size == -1 means "use all experts"; start_expert_id/expert_size
    otherwise select a contiguous expert slice (expert parallelism).
    """
    if expert_size == -1:
        expert_size = expert_num
    scale_s = 0.01 # avoid the occurrence of inf
    eps = 0.1 # Avoid the occurrence of nan
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
    residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps
    bias1, bias2 = None, None
    # Fold smoothing factors into the weights, then int8 row-quantize.
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    torch_out = self.op_impl_base(hidden_states,
                                  router_logit,
                                  quant_w1[start_expert_id:start_expert_id+expert_size],
                                  quant_w2[start_expert_id:start_expert_id+expert_size],
                                  bias1, bias2, residual,
                                  input_smooth[start_expert_id:start_expert_id+expert_size],
                                  act_smooth[start_expert_id:start_expert_id+expert_size],
                                  w1_scale[start_expert_id:start_expert_id+expert_size],
                                  w2_scale[start_expert_id:start_expert_id+expert_size],
                                  topk, renormalize, gated, act_mode,
                                  start_expert_id,
                                  0, 0, None, None)
    # (N, T, C)
    tmo_out_1 = ops.fused_moe(hidden_states,
                              router_logit,
                              quant_w1[start_expert_id:start_expert_id+expert_size],
                              quant_w2[start_expert_id:start_expert_id+expert_size],
                              bias1, bias2, residual,
                              input_smooth[start_expert_id:start_expert_id+expert_size],
                              act_smooth[start_expert_id:start_expert_id+expert_size],
                              w1_scale[start_expert_id:start_expert_id+expert_size],
                              w2_scale[start_expert_id:start_expert_id+expert_size],
                              topk, renormalize, gated, act_mode,
                              start_expert_id)
    # # (N*T, C)
    tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                              router_logit.view(-1, expert_num),
                              quant_w1[start_expert_id:start_expert_id+expert_size],
                              quant_w2[start_expert_id:start_expert_id+expert_size],
                              bias1, bias2,
                              residual.view(-1, hidden_size) if residual is not None else None,
                              input_smooth[start_expert_id:start_expert_id+expert_size],
                              act_smooth[start_expert_id:start_expert_id+expert_size],
                              w1_scale[start_expert_id:start_expert_id+expert_size],
                              w2_scale[start_expert_id:start_expert_id+expert_size],
                              topk, renormalize, gated, act_mode,
                              start_expert_id).view(batch, seq, hidden_size)
    # Python reference implemented in this test file.
    tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size),
                          router_logit.view(-1, expert_num),
                          quant_w1[start_expert_id:start_expert_id+expert_size],
                          quant_w2[start_expert_id:start_expert_id+expert_size],
                          bias1, bias2,
                          residual.view(-1, hidden_size) if residual is not None else None,
                          input_smooth[start_expert_id:start_expert_id+expert_size],
                          act_smooth[start_expert_id:start_expert_id+expert_size],
                          w1_scale[start_expert_id:start_expert_id+expert_size],
                          w2_scale[start_expert_id:start_expert_id+expert_size],
                          topk, renormalize, gated, act_mode,
                          start_expert_id).view(batch, seq, hidden_size)

    self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
    # The 3-D and flattened entry points must be bit-identical (tolerance 0).
    self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True)
    self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
|
||||
|
||||
def test_fused_moe_with_4D_w2(self):
    """fused_moe accepting w2 in a blocked 4-D layout.

    The reference and the (N, T, C) call use the plain 3-D w2; the
    flattened (N*T, C) call feeds w2 reshaped to
    (expert_num // block_e, hidden_size, block_e, inner_size) where
    block_e experts share one 4096-wide block (block_e * inner_size ==
    4096, matching the other *_4D_w2 tests).  All outputs must agree.
    """
    print("test_fused_moe_with_4D_w2")
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    dtype_list = [torch.half]
    # bf16 is exercised only on hardware that supports it.
    if mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, True, True, 'silu', dtype
        # The block dims were previously hard-coded as (8, 4); derive them
        # from expert_num/inner_size so the test survives shape changes.
        assert 4096 % inner_size == 0
        block_e = 4096 // inner_size
        assert expert_num % block_e == 0
        hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
        router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
        residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
        bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
        weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
        bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)

        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
        # Pure-torch golden reference.
        torch_out = self.op_impl_base(hidden_states, router_logit,
                                      weight1, weight2,
                                      bias1, bias2, residual, None, None, None, None,
                                      topk, renormalize, gated, act_mode, 0, 0, 0, None, None)
        # (N, T, C) with the plain 3-D w2.
        tmo_out_1 = ops.fused_moe(hidden_states, router_logit,
                                  weight1, weight2,
                                  bias1, bias2, residual, None, None, None, None,
                                  topk, renormalize, gated, act_mode, 0)
        # (N*T, C) with w2 handed over in the blocked 4-D layout.
        tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                                  router_logit.view(-1, expert_num),
                                  weight1,
                                  weight2.view(expert_num // block_e, block_e, hidden_size, inner_size).transpose(1,2).contiguous(),
                                  bias1, bias2,
                                  residual.view(-1, hidden_size) if residual is not None else None,
                                  None, None, None, None,
                                  topk, renormalize, gated, act_mode, 0).view(batch, seq, hidden_size)

        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
        self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
|
||||
|
||||
def test_moe_tp2_mixed_ep4_with_4D_w2(self):
    """Unquantized fused_moe under tp=2 x ep=4 mixed parallelism where
    each shard's w2 is passed in the blocked 4-D layout
    (experts // block_e, hidden, block_e, shard_inner) with
    block_e * inner_size == 4096.  Accumulated shard outputs must match
    the torch reference fed the same layout.
    """
    print("test_moe_tp2_mixed_ep4_with_4D_w2")
    tp_num, ep_num = 2, 4
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 2048
    expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16
    assert inner_size % tp_num == 0
    assert expert_num % ep_num == 0
    expert_size = expert_num // ep_num
    assert 4096 % inner_size == 0
    block_e = 4096 // inner_size
    assert expert_size % block_e == 0
    if not mlu.is_bf16_supported():
        data_type = torch.float16
    hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
    router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
    weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
    weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
    residual = None
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
    # tp shards w1 along its output rows and w2 along its input columns.
    w1 = weight1.reshape(expert_num, tp_num, -1, hidden_size)
    w2 = weight2.reshape(expert_num, hidden_size, tp_num, -1)

    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
    tmo_out = torch.zeros_like(hidden_states)
    torch_out = torch.zeros_like(hidden_states)
    for tp_idx in range(tp_num):
        # Per-tp inner width (inner_size // tp_num).
        new_inner_size = w2.shape[-1]
        w1_curr_tp = w1[:, tp_idx, ...]
        w2_curr_tp = w2[:, :, tp_idx, :]
        for ep_idx in range(ep_num):
            start_expert_id = ep_idx * expert_size
            w1_curr_tp_and_ep = w1_curr_tp.reshape((ep_num, expert_size)+w1_curr_tp.shape[1:])[ep_idx].contiguous()
            w2_curr_tp_and_ep = w2_curr_tp.reshape((ep_num, expert_size)+w2_curr_tp.shape[1:])[ep_idx].contiguous()
            torch_out += self.op_impl_base(hidden_states, router_logit, w1_curr_tp_and_ep,
                                           w2_curr_tp_and_ep.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                                           bias1, bias2, residual, None, None, None, None,
                                           topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None)
            tmo_out += ops.fused_moe(hidden_states, router_logit, w1_curr_tp_and_ep,
                                     w2_curr_tp_and_ep.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                                     bias1, bias2, residual, None, None, None, None,
                                     topk, renormalize, gated, act_mode, start_expert_id)
    self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
|
||||
|
||||
def test_sq_fused_moe_with_4D_w2(self):
    """Smooth-quant fused_moe with the quantized w2 supplied in the
    blocked 4-D layout (experts // block_e, hidden, block_e, inner)
    for the flattened entry point; all outputs must match the plain
    3-D layout and the torch reference.
    """
    print("test_sq_fused_moe_with_4D_w2")
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16
    if not mlu.is_bf16_supported():
        data_type = torch.float16
    # block_e experts share one 4096-wide w2 block.
    assert 4096 % inner_size == 0
    block_e = 4096 // inner_size
    assert expert_num % block_e == 0
    # NOTE(review): scale_s is declared but unused here (inputs are drawn
    # with torch.normal(0, 0.1) instead) — presumably copied from the
    # sibling cases.
    scale_s = 0.01 # avoid the occurrence of inf
    eps = 0.1 # Avoid the occurrence of nan
    hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
    router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
    weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
    weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
    residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    bias1, bias2 = None, None
    input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
    # Fold smoothing factors into the weights, then int8 row-quantize.
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    torch_out = self.op_impl_base(hidden_states, router_logit, quant_w1, quant_w2,
                                  bias1, bias2, residual, input_smooth, act_smooth,
                                  w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0, 0, 0, None, None)
    # (N, T, C)
    tmo_out_1 = ops.fused_moe(hidden_states, router_logit, quant_w1, quant_w2,
                              bias1, bias2, residual, input_smooth, act_smooth,
                              w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0)
    # # (N*T, C)
    tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                              router_logit.view(-1, expert_num),
                              quant_w1,
                              quant_w2.view(-1,block_e,hidden_size, inner_size).transpose(1,2).contiguous(),
                              bias1, bias2,
                              residual.view(-1, hidden_size) if residual is not None else None,
                              input_smooth, act_smooth,
                              w1_scale, w2_scale, topk, renormalize,
                              gated,
                              act_mode,
                              0).view(batch, seq, hidden_size)

    self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
    # Both entry points must be bit-identical (tolerance 0).
    self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True)
|
||||
|
||||
def test_sq_fused_moe_random_tp_quant_grouped(self):
    """Fuzz the grouped-quantization smooth-quant fused-MoE path.

    Draws parameter tuples from a fixed seed until 100 unique
    configurations (including the 4- or 8-bit quantization width) have
    been exercised.
    """
    print("test_sq_fused_moe_random_tp_quant_grouped")
    import random
    random.seed(0)
    act_mode = 'gelu'
    seen_cases = set()
    while len(seen_cases) < 100:
        # The draw order below pins the PRNG stream for seed 0 — keep it.
        batch = random.randint(1, 10)
        seq = random.randint(1, 10)
        hidden_size = random.randrange(512, 2048, 128)
        inner_size = random.randrange(512, 2048, 512)
        expert_num = random.randint(1, 40)
        topk = random.randint(1, expert_num)
        gated = random.choice([True, False])
        renormalize = random.choice([True, False])
        quant_bit = random.choice([4, 8])
        data_type = random.choice([torch.bfloat16, torch.float16])
        # Draw is still consumed above even when bf16 is unavailable,
        # so the remaining stream stays identical across devices.
        if not mlu.is_bf16_supported():
            data_type = torch.float16
        case = (batch, seq, hidden_size, inner_size, expert_num, topk,
                gated, renormalize, data_type, quant_bit, act_mode)
        if case in seen_cases:
            continue
        seen_cases.add(case)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True)
        self.__run_quant_grouped_case(*case)
|
||||
|
||||
def __run_quant_grouped_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode='gelu', start_expert_id=0, expert_size=-1, quant_group_size=128):
    """Grouped-quantization smooth-quant MoE case.

    Weights are quantized per row in column groups of quant_group_size,
    with quant_bit = 4 (int4 pairs packed into int8 via PairlyPackInt8)
    or 8.  Scales are laid out as (group, expert, rows).  Results from
    ops.fused_moe and the python helper are checked against the torch
    reference.  expert_size == -1 means "all experts".
    """
    def get_quant_group(n, quant_group_size):
        # At least one group even when n < quant_group_size.
        quant_group = n // quant_group_size
        return quant_group if quant_group >= 1 else 1

    if expert_size == -1:
        expert_size = expert_num
    scale_s = 0.01 # avoid the occurrence of inf
    eps = 0.1 # Avoid the occurrence of nan
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
    residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps
    bias1, bias2 = None, None
    # Fold smoothing factors into the weights before quantizing.
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    w1_quant_group = get_quant_group(hidden_size, quant_group_size)
    w2_quant_group = get_quant_group(inner_size, quant_group_size)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), quant_bit, w1_quant_group)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), quant_bit, w2_quant_group)
    quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
    if quant_bit == 4:
        # Two int4 values per int8 byte.
        quant_w1, quant_w2 = PairlyPackInt8(quant_w1), PairlyPackInt8(quant_w2)
    # Scales: (expert, rows, group) -> (group, expert, rows).
    w1_scale = w1_scale.view(expert_num, -1, w1_quant_group).permute(2, 0, 1).contiguous()
    w2_scale = w2_scale.view(expert_num, -1, w2_quant_group).permute(2, 0, 1).contiguous()
    # split scale and transpose
    def extract_scale(scale, start_expert, expert_size):
        # Expert-slice along the middle axis of the (group, expert, rows) layout.
        return scale[:, start_expert:start_expert+expert_size, :].contiguous()

    torch_out = self.op_impl_base(hidden_states,
                                  router_logit,
                                  quant_w1[start_expert_id:start_expert_id+expert_size],
                                  quant_w2[start_expert_id:start_expert_id+expert_size],
                                  bias1, bias2, residual,
                                  input_smooth[start_expert_id:start_expert_id+expert_size],
                                  act_smooth[start_expert_id:start_expert_id+expert_size],
                                  extract_scale(w1_scale, start_expert_id, expert_size),
                                  extract_scale(w2_scale, start_expert_id, expert_size),
                                  topk, renormalize, gated, act_mode,
                                  start_expert_id,
                                  0, 0, None, None)
    # (N, T, C)
    tmo_out_1 = ops.fused_moe(hidden_states,
                              router_logit,
                              quant_w1[start_expert_id:start_expert_id+expert_size],
                              quant_w2[start_expert_id:start_expert_id+expert_size],
                              bias1, bias2, residual,
                              input_smooth[start_expert_id:start_expert_id+expert_size],
                              act_smooth[start_expert_id:start_expert_id+expert_size],
                              extract_scale(w1_scale, start_expert_id, expert_size),
                              extract_scale(w2_scale, start_expert_id, expert_size),
                              topk, renormalize, gated, act_mode,
                              start_expert_id)
    # Python reference on the flattened (N*T, C) input.
    tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size),
                          router_logit.view(-1, expert_num),
                          quant_w1[start_expert_id:start_expert_id+expert_size],
                          quant_w2[start_expert_id:start_expert_id+expert_size],
                          bias1, bias2,
                          residual.view(-1, hidden_size) if residual is not None else None,
                          input_smooth[start_expert_id:start_expert_id+expert_size],
                          act_smooth[start_expert_id:start_expert_id+expert_size],
                          extract_scale(w1_scale, start_expert_id, expert_size),
                          extract_scale(w2_scale, start_expert_id, expert_size),
                          topk, renormalize, gated, act_mode,
                          start_expert_id).view(batch, seq, hidden_size)

    self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
    self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
|
||||
|
||||
def __run_w4w8_mixed_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode='gelu', start_expert_id=0, expert_size=-1):
    """Run one mixed int4/int8 group-quantized fused-MoE case and compare the
    Python reference (self.op_impl_base) against both bindings of the kernel
    (ops.fused_moe and the module-level fused_moe wrapper).

    expert_size == -1 means "this rank owns every expert" (no expert parallel).
    quant_wise is the per-group size along the reduce dimension of each matmul.
    """
    if expert_size == -1:
        expert_size = expert_num
    # Number of quantization groups along the reduce dim of the two matmuls.
    w1_quant_group = hidden_size // quant_wise
    w2_quant_group = inner_size // quant_wise
    # Per-(expert, group) bit-width flag: randint(1, 3) yields 1 or 2, so each
    # entry is 4 or 8 — a random mix of int4 and int8 groups.
    w1_quant_flag = torch.randint(1, 3, (expert_num, w1_quant_group), dtype=torch.int32) * 4
    w2_quant_flag = torch.randint(1, 3, (expert_num, w2_quant_group), dtype=torch.int32) * 4
    # Packed element counts: flag/4 is 1 (int4) or 2 (int8); presumably an int4
    # group packs quant_wise/2 bytes per output column and int8 twice that —
    # TODO confirm against the kernel's packing layout.
    w1_count = (w1_quant_flag.sum().item() // 4) * (quant_wise // 2) * inner_size*(1+gated)
    w2_count = (w2_quant_flag.sum().item() // 4) * (quant_wise // 2) * hidden_size
    w1 = torch.randint(-128, 127, (w1_count,), device="mlu", dtype=torch.int32).to(torch.int8)
    w2 = torch.randint(-128, 127, (w2_count,), device="mlu", dtype=torch.int32).to(torch.int8)
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    # Smooth-quant scaling vectors; kept positive and small.
    input_smooth = torch.empty(expert_num, hidden_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
    act_smooth = torch.empty(expert_num, inner_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
    bias1, bias2 = None, None
    # Scales are laid out (group, expert, out_col).
    w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
    w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
    # Cumulative packed offsets per expert, padded with a leading 0 so that
    # offset_cu[e]..offset_cu[e+1] addresses expert e's slice of w1/w2.
    w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated)
    w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0)
    w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size
    w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0)

    # split scale and transpose
    def extract_scale(scale, start_expert, expert_size):
        # Slice the expert axis (dim 1) of a (group, expert, col) scale tensor.
        return scale[:, start_expert:start_expert+expert_size, :].contiguous()

    params = [hidden_states, router_logit,
              w1[w1_offset_cu[start_expert_id]:w1_offset_cu[start_expert_id+expert_size]],
              w2[w2_offset_cu[start_expert_id]:w2_offset_cu[start_expert_id+expert_size]],
              bias1, bias2, residual,
              input_smooth[start_expert_id:start_expert_id+expert_size],
              act_smooth[start_expert_id:start_expert_id+expert_size],
              extract_scale(w1_scale, start_expert_id, expert_size),
              extract_scale(w2_scale, start_expert_id, expert_size),
              topk, renormalize, gated, act_mode, start_expert_id, 0, 0,
              w1_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist(),
              w2_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist()]
    torch_out = self.op_impl_base(*params)
    # (N, T, C)
    tmo_out_1 = ops.fused_moe(*params)
    tmo_out_2 = fused_moe(*params)
    # Mixed-precision quant: loose 3% MSE tolerance against the reference.
    self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True)
    self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True)
|
||||
|
||||
def test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed(self):
    """Randomized sweep over 100 unique mixed w4/w8 fused-MoE configurations.

    Fix: the bf16 capability check used a bare `mlu` name that this file never
    imports; use the imported `torch_mlu.mlu` module (consistent with the other
    tests in this suite, e.g. test_functional's torch_mlu.mlu.is_bf16_supported()).
    """
    print("test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed")
    import random
    random.seed(0)  # deterministic case sampling
    act_mode = 'gelu'
    case_list = set()  # dedupe: keep sampling until 100 distinct configs ran
    while (len(case_list) < 100):
        batch = random.randint(1, 10)
        seq = random.randint(1, 10)
        hidden_size = random.randrange(1024, 3072, 512)
        inner_size = random.randrange(1024, 3072, 512)
        expert_num = random.randint(1, 40)
        topk = random.randint(1, expert_num)
        gated = random.choice([True, False])
        renormalize = random.choice([True, False])
        quant_wise = random.choice([128, 256, 512])
        data_type = random.choice([torch.bfloat16, torch.float16])
        # Fall back to fp16 on devices without bf16 support.
        if not torch_mlu.mlu.is_bf16_supported():
            data_type = torch.float16
        case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode)
        if case in case_list:
            continue
        case_list.add(case)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_wise: {quant_wise}, act_mode: {act_mode} testing...", flush=True)
        # start_expert_id=0, expert_size=-1 -> run with all experts on one rank.
        self.__run_w4w8_mixed_case(*case, 0, -1)
|
||||
|
||||
def test_sq_fused_moe_single_quant_group(self):
    """One fixed-shape smoke test of the grouped-quant fused-MoE path."""
    print("test_sq_fused_moe_single_quant_group")
    import random
    random.seed(0)  # keep any randomness inside the runner deterministic
    act_mode = 'gelu'
    batch, seq = 9, 10
    hidden_size, inner_size = 1664, 512
    expert_num, topk = 15, 2
    gated = False
    renormalize = False
    quant_bit = 4
    data_type = torch.float16
    case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated,
            renormalize, data_type, quant_bit, act_mode)
    print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True)
    self.__run_quant_grouped_case(*case)
|
||||
|
||||
def test_inductor(self):
    """Opcheck the fused_moe custom op for inductor, first unquantized then
    smooth-quantized, on one fixed configuration."""
    batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
    expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', torch.float16
    start_expert_id, expert_size = 0, 8

    # Dense (unquantized) operands.  Note: tensor creation order matters for
    # RNG reproducibility and is kept identical to the reference flow.
    hidden_states = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type)
    router_logit = torch.randn(batch * seq, expert_num, device="mlu", dtype=torch.float32)
    residual = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type)
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)

    expert_slice = slice(start_expert_id, start_expert_id + expert_size)
    dense_args = (
        hidden_states, router_logit,
        weight1[expert_slice], weight2[expert_slice],
        None, None, residual, None, None, None, None, None, None,
        topk, renormalize, gated, act_mode, 0, 0, 0,
    )
    self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, dense_args)

    # Smooth-quant variant: fold smoothing vectors into the weights, then
    # row-quantize to int8.
    eps = 1e-5
    input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1 = quant_w1.view(weight1_shape)
    quant_w2 = quant_w2.view(weight2_shape)
    w1_scale = w1_scale.view(expert_num, -1)
    w2_scale = w2_scale.view(expert_num, -1)

    quant_args = (
        hidden_states, router_logit,
        quant_w1, quant_w2, None, None, residual,
        input_smooth, act_smooth, w1_scale, w2_scale,
        None, None, topk, renormalize, gated, act_mode, 0, 0, 0,
    )
    self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, quant_args)
|
||||
|
||||
# Script entry: run this file's test class and propagate the unittest result
# as the process exit code.
if __name__ == '__main__':
    exit(run_unittest(TestFusedMOEOp))
|
||||
242
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_add_bias_activation.py
Executable file
242
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_add_bias_activation.py
Executable file
@@ -0,0 +1,242 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from common_utils import *
|
||||
import random
|
||||
from itertools import product
|
||||
import time
|
||||
import os
|
||||
|
||||
def gen_data(num_expert,
             total_tokens,
             inner_size,
             output_stride,
             dtype,
             is_gated,
             has_bias,
             is_ep):
    """Build random inputs for a moe_active test case.

    Returns (input, bias, token_count, cusum_token_count, output,
    start_expert_id, expert_size).  When is_ep is False the whole expert range
    [0, num_expert) is used; otherwise a random sub-range is chosen.

    Fix: the original called `output.as_strided(...)` and discarded the result
    — as_strided is not in-place, so `output_stride` was never applied and the
    returned `output` was always contiguous.  Allocate the tensor with the
    requested row stride directly instead (identical behavior for the existing
    callers, which all pass output_stride == inner_size).
    """
    # Gated activations consume two interleaved channel halves.
    ci = inner_size * (1 + is_gated)
    input = torch.randn(total_tokens, ci, dtype=dtype, device='mlu')
    cusum_token_count, token_count = generate_token_count(num_expert, total_tokens)
    # Output buffer with caller-specified row stride (storage sized to match).
    output = torch.empty_strided((total_tokens, inner_size), (output_stride, 1),
                                 dtype=dtype, device='mlu')
    start_expert_id = random.randint(0, num_expert - 1) if is_ep else 0
    expert_size = random.randint(1, num_expert - start_expert_id) if is_ep else num_expert
    bias = torch.randn(num_expert, ci, dtype=dtype, device='mlu') if has_bias else None
    return input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size
|
||||
|
||||
class TestMoeActiveKernel(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """Pure-PyTorch reference for moe_active: optional per-expert bias add,
    then (optionally gated) gelu/silu activation.

    args: (input, act_mode, is_gated, output, bias, cusum_token_count,
           start_expert_id, expert_size).  Returns the activated tensor, or
    copies it into `output` when one is supplied.
    """
    input, act_mode, is_gated, output, bias, cusum_token_count, start_expert_id, expert_size = args
    # NOTE(review): any act_mode other than 'gelu' falls through to silu here,
    # although the kernel also accepts 'quick_gelu'/'swish' — reference only
    # covers the modes the tests exercise.
    act_fun = torch.nn.functional.gelu if act_mode == 'gelu' else torch.nn.functional.silu
    total_token_num = input.size(0)
    # Gated layout: first half of channels is activated, second half gates it.
    inner_size = input.size(1) // 2 if is_gated else input.size(1)
    input_ = input.clone()
    if bias is not None:
        # Tokens actually owned by experts [start_expert_id, start+expert_size).
        deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id]
        input_ = input_[:deal_token_num, :]
        # Per-expert token counts from the cumulative sums.
        token_count = cusum_token_count[1:] - cusum_token_count[:-1]
        token_count_ = token_count[start_expert_id:start_expert_id+expert_size].tolist()
        input_list = list(input_.split(token_count_))
        # Add each expert's bias row to its token slice.
        for i in range(expert_size):
            input_list[i] += bias[i]
        input_ = torch.cat(input_list, dim=0)
        # Expert-parallel case: pad back to the full token count so shapes
        # match the kernel output buffer.
        if cusum_token_count.size(0) - 1 != expert_size:
            pad = torch.zeros(total_token_num-deal_token_num, input_.size(-1)).to(input_.dtype).mlu()
            input_ = torch.cat((input_, pad), dim=0)
    acted = act_fun(input_[:, :inner_size])
    acted = acted * input_[:, inner_size:] if is_gated else acted
    if output is None:
        return acted
    else:
        return output.copy_(acted)
|
||||
|
||||
# 功能测试
|
||||
# 功能测试
def test_functional(self):
    """Sweep a small grid of shapes/dtypes/flags and compare ops.moe_active
    against the Python reference on the expert-owned token range."""
    num_expert_list = [5, 32]
    total_tokens_list = [64, 1024]
    inner_size_list = [1024, 8192]
    dtype_list = [torch.half, torch.float32]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    is_gated_list = [True, False]
    is_ep_list = [True, False]
    # change True when test
    # NOTE(review): False is listed twice, biasing the grid toward no-bias
    # cases — presumably intentional while the bias path is under test.
    has_bias_list = [False, False, True]
    act_mode_list = ['silu', 'gelu']
    args = product(num_expert_list, total_tokens_list, inner_size_list, dtype_list,
                   is_gated_list, is_ep_list, has_bias_list, act_mode_list)

    for num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode in args:
        print("===============================================================================")
        print(f"num_expert: {num_expert}, total_tokens: {total_tokens}")
        print(f"inner_size: {inner_size}, dtype: {dtype}, is_gated: {is_gated}")
        print(f"is_ep: {is_ep}, has_bias: {has_bias}, act_mode: {act_mode}")

        input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
            gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
        # Kernel expects only the owned experts' bias rows.
        bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
        base_output = torch.empty_like(output)
        ops.moe_active(input,
                       act_mode,
                       is_gated,
                       output,
                       bias_real,
                       cusum_token_count.mlu() if has_bias or is_ep else None,
                       start_expert_id,
                       expert_size)
        # Only tokens routed to the owned expert range are defined output.
        deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id]
        output = output[:deal_token_num, :]
        self.op_impl_base(input,
                          act_mode,
                          is_gated,
                          base_output,
                          bias_real,
                          cusum_token_count.mlu() if has_bias or is_ep else None,
                          start_expert_id,
                          expert_size)
        base_output = base_output[:deal_token_num, :]
        self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_inplace(self):
    """Write the activation result back into the (strided) input buffer and
    check it still matches the reference computed from a saved copy."""
    num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \
        32, 80, 1024, torch.half, True, False, False, 'gelu'
    input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
        gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
    # Snapshot before the kernel clobbers the buffer.
    input_bak = input.clone()
    # Alias the output onto the input storage with a row stride of 2*inner_size
    # (every other inner_size-chunk), exercising the in-place strided path.
    output = input.as_strided((total_tokens, inner_size), (inner_size * 2, 1))
    base_output = torch.empty_like(output)
    ops.moe_active(
        input, act_mode, is_gated, output, bias, None, start_expert_id, expert_size)
    self.op_impl_base(input_bak, act_mode, is_gated, base_output, bias, None, start_expert_id, expert_size)
    self.assertTensorsEqual(
        base_output.cpu().float(), output.cpu().float(), 0.003, use_MSE=True, use_RAE=True)
|
||||
|
||||
# 随机遍历测试
|
||||
# 随机遍历测试
def test_random(self):
    """500 randomly-shaped moe_active cases compared against the reference."""
    dtype_list = [torch.half, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)

    for i in range(500):
        num_expert = random.randint(1, 64)
        total_tokens = random.randint(1, 32768)
        inner_size = random.randint(1, 8192)
        dtype = random.sample(dtype_list, 1)[0]
        is_gated = random.sample([True, False], 1)[0]
        is_ep = random.sample([True, False], 1)[0]
        # change True when test
        has_bias = random.sample([False, True], 1)[0]
        act_mode = random.sample(['gelu', 'silu'], 1)[0]
        print("===============================================================================")
        print(f"[{i}]: num_expert: {num_expert}, total_tokens: {total_tokens}")
        print(f"  inner_size: {inner_size}, dtype: {dtype}, is_gated: {is_gated}")
        print(f"  is_ep: {is_ep}, has_bias: {has_bias}, act_mode: {act_mode}")

        input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
            gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
        # Only the owned experts' bias rows are handed to the kernel.
        bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
        base_output = torch.empty_like(output)
        ops.moe_active(input,
                       act_mode,
                       is_gated,
                       output,
                       bias_real,
                       cusum_token_count.mlu() if has_bias or is_ep else None,
                       start_expert_id,
                       expert_size)
        # Compare only tokens routed to the owned expert range.
        deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id]
        output = output[:deal_token_num, :]
        self.op_impl_base(input,
                          act_mode,
                          is_gated,
                          base_output,
                          bias_real,
                          cusum_token_count.mlu() if has_bias or is_ep else None,
                          start_expert_id,
                          expert_size)
        base_output = base_output[:deal_token_num, :]
        self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_single(self):
    """One fixed large-shape gated-gelu case: kernel vs. reference."""
    num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \
        32, 64, 8192, torch.float, True, False, False, 'gelu'

    input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
        gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
    bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
    # The cumulative-count tensor is only needed when bias or expert-parallel
    # slicing is in play; compute the argument once for both calls.
    token_counts_arg = cusum_token_count.mlu() if has_bias or is_ep else None
    base_output = torch.empty_like(output)

    ops.moe_active(input, act_mode, is_gated, output, bias_real,
                   token_counts_arg, start_expert_id, expert_size)
    deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id]
    output = output[:deal_token_num, :]

    self.op_impl_base(input, act_mode, is_gated, base_output, bias_real,
                      token_counts_arg, start_expert_id, expert_size)
    base_output = base_output[:deal_token_num, :]
    self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(),
                            0.003, use_MSE=True, use_RAE=True)
|
||||
|
||||
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_prevent(self):
    """Negative tests: each invalid argument must raise with the exact
    kernel-side check message."""
    func = ops.moe_active
    num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \
        32, 64, 8192, torch.float, True, False, False, 'gelu'

    input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
        gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
    bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
    # Unknown activation mode is rejected.
    act_mode = 'abc'
    self.assertException("act_mode must be 'silu', 'gelu', 'quick_gelu' or 'swish'.", func,
                         input, act_mode, is_gated, output, bias_real,
                         cusum_token_count.mlu() if has_bias or is_ep else None,
                         start_expert_id, expert_size)
    act_mode = 'gelu'
    # A 1-D input is rejected.
    self.assertException("input.dim() >= 2", func,
                         input.reshape(-1), act_mode, is_gated, output, bias_real,
                         cusum_token_count.mlu() if has_bias or is_ep else None,
                         start_expert_id, expert_size)
    # Gated mode requires an even channel count (two halves).
    is_gated = True
    self.assertException("in_channel % 2 == 0 if is_gated is true", func,
                         torch.randn(total_tokens, inner_size * 2 - 1, dtype=dtype, device='mlu'),
                         act_mode, is_gated, output, bias_real,
                         cusum_token_count.mlu() if has_bias or is_ep else None,
                         start_expert_id, expert_size)
|
||||
|
||||
def test_inductor(self):
    """Opcheck the registered `active` op (inductor integration) on one
    fixed biased, gated configuration."""
    num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \
        32, 64, 8192, torch.float, True, False, True, 'gelu'

    input, bias, _, cusum_token_count, output, start_expert_id, expert_size = \
        gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
    # Only the owned experts' bias rows are passed through.
    owned_bias = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
    counts_arg = cusum_token_count.mlu() if has_bias or is_ep else None
    op_args = (input, output, owned_bias, counts_arg,
               act_mode, is_gated, start_expert_id, expert_size)
    self.base_opcheck(torch.ops.torch_mlu_ops.active, op_args)
|
||||
|
||||
# Script entry: seed both RNGs for reproducible case generation, then run the
# test class and propagate the unittest result as the exit code.
if __name__ == '__main__':
    random.seed(0)
    torch.manual_seed(0)
    exit(run_unittest(TestMoeActiveKernel))
|
||||
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
import random
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
class TestMoeCastGating(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """Reference gating: cast the input to float32 and matmul it against the
    transposed (expert_num, hidden) weight, giving (tokens, expert_num)."""
    hidden, gate_weight = args
    hidden = hidden.to(torch.float)
    return torch.matmul(hidden, gate_weight.permute(1, 0))
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "moe_cast_gating not support MLU3XX device")
def test_moe_cast_gating_random(self):
    """1000 random shapes: kernel output vs. float32 matmul reference.

    NOTE(review): this file does not import torch_mlu directly — it presumably
    reaches it through `from common_utils import *`; verify, or add an explicit
    import.
    """
    for _ in range(1000):
        total_seq = random.randint(1, 32768)
        hidden_size = random.randint(1, 16384)
        expert_num = random.randint(1, 128)
        input_dtype_list = [torch.half]
        if torch_mlu.mlu.is_bf16_supported():
            input_dtype_list.append(torch.bfloat16)
        input_dtype = random.choice(input_dtype_list)
        # Gating weight is always fp32 (kernel contract, see test_prevent).
        weight_dtype = torch.float
        print("total_seqlen={}, hidden_size={}, expert_num={}, input_dtype={}, testing...".format(
            total_seq, hidden_size, expert_num, input_dtype))
        input = torch.randn(total_seq, hidden_size, dtype=input_dtype, device="mlu")
        weight = torch.randn(expert_num, hidden_size, dtype=weight_dtype, device="mlu")

        tmo_out = ops.moe_cast_gating(input, weight)
        torch_out = self.op_impl_base(input, weight)

        self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 1e-4,
                                use_MSE=True, use_RAE=True)
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "moe_cast_gating not support MLU3XX device")
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_prevent(self):
    """Negative tests: each contract violation must raise with the exact
    kernel-side check message."""
    func = ops.moe_cast_gating
    input = torch.randn(1024, 8192, dtype=torch.half, device="mlu")
    # Overlapping strides -> non-contiguous input is rejected.
    input = input.as_strided(input.shape, (100, 1))
    weight = torch.randn(128, 8192, dtype=torch.float, device="mlu")
    self.assertException("input must be contiguous", func, input, weight)
    input = input.contiguous()
    weight = weight.as_strided(weight.shape, (100, 1))
    self.assertException("weight must be contiguous", func, input, weight)
    weight = weight.contiguous()
    # Weight must be 2-D.
    weight = weight.reshape(1, 128, 8192)
    self.assertException("weight.dim() == 2", func, input, weight)
    # Reduce dims must agree.
    weight = torch.randn(128, 2048, dtype=torch.float, device="mlu")
    self.assertException("input.size(-1) == weight.size(-1)", func, input, weight)
    # Weight dtype must be fp32; input must be fp16/bf16.
    weight = torch.randn(128, 8192, dtype=torch.half, device="mlu")
    self.assertException("weight type need be torch::kFloat32", func, input, weight)
    weight = weight.to(torch.float)
    input = input.to(torch.float)
    self.assertException("input type need be torch::kFloat16 or torch::kBFloat16", func, input, weight)
    # Size limits: hidden_size <= 16384, expert_num <= 128.
    input = torch.randn(1024, 16388, dtype=torch.half, device="mlu")
    weight = torch.randn(128, 16388, dtype=torch.float, device="mlu")
    self.assertException("hidden_size > 0 && hidden_size <= 16384", func, input, weight)
    input = torch.randn(1024, 16384, dtype=torch.half, device="mlu")
    weight = torch.randn(129, 16384, dtype=torch.float, device="mlu")
    self.assertException("expert_num > 0 && expert_num <= 128", func, input, weight)
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "moe_cast_gating not support MLU3XX device")
def test_inductor(self):
    """Opcheck the moe_cast_gating custom op on a single fixed shape."""
    m, hidden_size, expert_num, input_dtype = 1024, 4096, 32, torch.half
    gating_input = torch.randn(m, hidden_size, dtype=input_dtype, device="mlu")
    gating_weight = torch.randn(expert_num, hidden_size, dtype=torch.float, device="mlu")
    self.base_opcheck(torch.ops.torch_mlu_ops.moe_cast_gating,
                      (gating_input, gating_weight))
|
||||
|
||||
# Script entry: skip entirely on MLU3XX devices (op unsupported there),
# otherwise run the test class and propagate the unittest result.
if __name__ == "__main__":
    if "MLU3" not in torch.mlu.get_device_name():
        exit(run_unittest(TestMoeCastGating))
|
||||
242
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_combine_result.py
Normal file
242
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_combine_result.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import torch
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
import random
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
def generate_token_count(num_expert,
                         total_token_count):
    """Generate a random cumulative per-expert token-count tensor.

    Returns an int32 tensor of shape (num_expert + 1,) that starts at 0, is
    non-decreasing, is clamped to total_token_count, and whose last entry is
    exactly total_token_count (any rounding slack is absorbed by the final
    expert).

    Fix: the original bound the intermediate total to the name `sum`,
    shadowing the builtin; renamed to `count_sum`.
    """
    # Random raw counts, then rescale so they approximately sum to the target.
    token_count = torch.randint(low=1, high=1024, size=(num_expert, ), \
                                dtype=torch.int32).to(dtype=torch.float32)
    count_sum = torch.sum(token_count, dim=-1) * 1.0
    token_count *= total_token_count / count_sum.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    # Clamp to the target, and force the last entry to hit it exactly.
    cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
    cusum_token_count[-1] = total_token_count
    return cusum_token_count
|
||||
|
||||
|
||||
def gen_case(num_tokens,
             topk,
             hidden_size,
             num_expert,
             expert_size,
             has_bias,
             has_residual,
             dtype,
             device):
    """Build one moe_combine_result case.

    Returns (input, reduce_weight, gather_ids, residual, bias,
    cusum_token_count); bias/residual/cusum_token_count are None when not
    requested.  RNG call order is fixed, so seeded runs are reproducible.
    """
    expanded = torch.randn((num_tokens * topk, hidden_size), dtype=dtype, device=device)
    combine_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device=device)
    # Random permutation mapping combined slots back to expanded rows.
    permutation = torch.randperm(num_tokens * topk, dtype=torch.int32, device=device)

    expert_bias = torch.randn((num_expert, hidden_size), dtype=dtype, device=device) if has_bias else None
    skip_connection = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device) if has_residual else None

    # Per-expert cumulative counts are needed only for the bias path or when
    # this rank owns a strict subset of the experts.
    token_cusum = None
    if has_bias or expert_size < num_expert:
        token_cusum = generate_token_count(num_expert, num_tokens * topk).to(device=device)

    return expanded, combine_weight, permutation, skip_connection, expert_bias, token_cusum
|
||||
|
||||
class TestCombineResult(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """CPU float32 reference for moe_combine_result: gather expanded rows by
    gather_ids, weight and reduce the topk copies per token, mask out tokens
    owned by other ranks, then add the residual.

    args: (input, reduce_weight, gather_ids, residual, cusum_token_count,
           start_expert_id, expert_size, bias).  Result is cast back to the
    input dtype.
    """
    input, reduce_weight, gather_ids, residual, cusum_token_count, start_expert_id, \
        expert_size, bias = args
    input = input.to(dtype=torch.float32).cpu()
    reduce_weight = reduce_weight.cpu()
    gather_ids = gather_ids.cpu()
    # NOTE(review): the bias argument is unconditionally discarded here, which
    # makes the `if bias is not None` branches below dead code — consistent
    # with the callers passing has_bias=False / None, but confirm before
    # enabling biased cases.
    bias = None

    dtype = input.dtype
    num_expert = expert_size
    hidden_size = input.shape[1]

    if bias is not None:
        bias = bias.to(dtype=torch.float32).cpu()
        num_expert = bias.shape[0]
    if cusum_token_count is not None:
        num_expert = cusum_token_count.shape[0] - 1
        cusum_token_count = cusum_token_count.cpu()

    # Expert-parallel case: input holds only this rank's token range, so
    # global gather ids are rebased by the range's start offset.
    if cusum_token_count is not None and expert_size < num_expert:
        gathered_input = input[gather_ids - cusum_token_count[start_expert_id].item()]
    else:
        gathered_input = input[gather_ids]

    if bias is not None and cusum_token_count is not None:
        # Add each owned expert's bias row to its token slice.
        for i in range(start_expert_id, start_expert_id + expert_size):
            gathered_input[cusum_token_count[i] : cusum_token_count[i+1]] += bias[i]
    # (num_tokens, topk, hidden)
    gathered_input = gathered_input.reshape(*reduce_weight.shape, hidden_size)

    if cusum_token_count is not None:
        # Zero the combine weight of slots routed to experts outside
        # [start_expert_id, start_expert_id + expert_size).
        filtered_ids = (gather_ids >= cusum_token_count[start_expert_id]) * \
                       (gather_ids < cusum_token_count[start_expert_id + expert_size])
        filtered_ids = filtered_ids.to(dtype=torch.float32)
        reduce_weight = reduce_weight * filtered_ids.reshape(reduce_weight.shape)

    # Weighted sum over the topk axis.
    gathered_input *= reduce_weight.reshape(*reduce_weight.shape, -1)
    output = torch.sum(gathered_input, dim=1, keepdim=False)
    if residual is not None:
        residual = residual.to(dtype=torch.float32).cpu()
        output += residual
    return output.to(dtype=dtype)
|
||||
|
||||
def test_random_case(self):
    """200 randomized moe_combine_result cases vs. the CPU reference."""
    torch.manual_seed(444)
    test_cases = 200
    num_tokens_list = torch.randint(low=1, high=2048, size=(test_cases, ), dtype=torch.int32)
    topk_list = torch.randint(low=1, high=33, size=(test_cases, ), dtype=torch.int32)
    hidden_size_list = torch.randint(low=256, high=8193, size=(test_cases, ), dtype=torch.int32)
    num_expert_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32)
    # Keep the case parameters mutually consistent:
    # num_expert >= topk, expert_size <= num_expert,
    # 0 <= start_expert_id <= num_expert - expert_size.
    num_expert_list = torch.maximum(topk_list, num_expert_list)
    expert_size_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32)
    expert_size_list = torch.minimum(expert_size_list, num_expert_list)
    start_expert_id_list = torch.randint(low=0, high=129, size=(test_cases, ), dtype=torch.int32)
    start_expert_id_list = torch.minimum(start_expert_id_list, num_expert_list - expert_size_list)
    start_expert_id_list = torch.maximum(start_expert_id_list, torch.zeros(test_cases, dtype=torch.int32))
    has_bias_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    has_residual_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    dtype_list = torch.randint(low=0, high=10, size=(test_cases, ), dtype=torch.int32)
    dtype_map = [torch.half, torch.bfloat16, torch.float]

    device = 'mlu'
    mlu_name = torch.mlu.get_device_name()
    is_mlu370 = "MLU3" in mlu_name
    max_num_tokens = 128 * 1024
    for i in range(test_cases):
        num_tokens = num_tokens_list[i].item()
        topk = topk_list[i].item()
        hidden_size = hidden_size_list[i].item()
        num_expert = num_expert_list[i].item()
        expert_size = expert_size_list[i].item()
        start_expert_id = start_expert_id_list[i].item()
        # NOTE(review): bias deliberately disabled (reference discards it too).
        has_bias = False # has_bias_list[i].item()
        has_residual = has_residual_list[i].item()

        # Cap the expanded row count at max_num_tokens.
        topk = min(topk, (int)(max_num_tokens / num_tokens))

        # Skewed dtype pick: 40% half, 50% bfloat16, 10% float.
        if dtype_list[i].item() < 4:
            dtype = dtype_map[0]
        elif dtype_list[i].item() < 9:
            dtype = dtype_map[1]
        else:
            dtype = dtype_map[2]

        # bf16 unsupported on MLU3XX.
        if is_mlu370 and dtype is torch.bfloat16:
            continue

        inputs = gen_case(num_tokens,
                          topk,
                          hidden_size,
                          num_expert,
                          expert_size,
                          has_bias,
                          has_residual,
                          dtype,
                          device)
        input = inputs[0]
        reduce_weight = inputs[1]
        gather_ids = inputs[2]
        residual = inputs[3]
        bias = inputs[4]
        cusum_token_count = inputs[5]

        print("num_tokens={}, topk={}, hidden_size={}, num_expert={}, expert_size={}, "
              "start_expert_id={}, has_bias={}, has_residual={}, dtype={}, testing...".format(
              num_tokens, topk, hidden_size, num_expert, expert_size, start_expert_id, \
              has_bias, has_residual, dtype))

        golden_output = self.op_impl_base(input, reduce_weight, gather_ids, residual,
                                          cusum_token_count, start_expert_id, expert_size, None)

        output = ops.moe_combine_result(input, reduce_weight, gather_ids, residual,
                                        cusum_token_count, start_expert_id, expert_size)

        self.assertTensorsEqual(output.cpu().float(), golden_output.cpu().float(), 0.003,
                                "golden_output must equal output", True, True, True, True)
|
||||
|
||||
|
||||
def test_perf_case(self):
    """Timing sweep over representative shapes; also re-checks correctness
    against the CPU reference on each configuration."""
    num_tokens_list = [1, 72, 512]
    hidden_size_list = [2048, 4096, 5120, 8192]
    # [num_expert, topk, start_expert_id, expert_size]
    expert_options_list = [[1, 1, 0, 1], [8, 2, 0, 8], [32, 5, 0, 32], [32, 5, 24, 8]]
    has_residual_list = [True, False]
    dtype_list = [torch.half, torch.bfloat16]

    device = 'mlu'
    mlu_name = torch.mlu.get_device_name()
    is_mlu370 = "MLU3" in mlu_name
    args = product(num_tokens_list, hidden_size_list, expert_options_list,\
                   has_residual_list, dtype_list)
    for num_tokens, hidden_size, expert_options, has_residual, dtype in args:
        num_expert = expert_options[0]
        topk = expert_options[1]
        start_expert_id = expert_options[2]
        expert_size = expert_options[3]
        has_bias = False

        # bf16 unsupported on MLU3XX.
        if is_mlu370 and dtype is torch.bfloat16:
            continue

        # Re-seed per case so each configuration is reproducible in isolation.
        torch.manual_seed(444)
        inputs = gen_case(num_tokens,
                          topk,
                          hidden_size,
                          num_expert,
                          expert_size,
                          has_bias,
                          has_residual,
                          dtype,
                          device)
        input = inputs[0]
        reduce_weight = inputs[1]
        gather_ids = inputs[2]
        residual = inputs[3]
        bias = inputs[4]
        cusum_token_count = inputs[5]

        golden_output = self.op_impl_base(input, reduce_weight, gather_ids, residual,
                                          cusum_token_count, start_expert_id, expert_size, None)

        # Average device time over `loop` launches using MLU events.
        notify_start = torch.mlu.Event(enable_timing=True)
        notify_end = torch.mlu.Event(enable_timing=True)
        notify_start.record()
        loop = 10
        for _ in range(loop):
            output = ops.moe_combine_result(input, reduce_weight, gather_ids, residual,
                                            cusum_token_count, start_expert_id, expert_size)
        notify_end.record()
        notify_end.synchronize()
        time = notify_start.hardware_time(notify_end) / loop

        print("num_tokens={}, topk={}, hidden_size={}, num_expert={}, expert_size={}, "
              "start_expert_id={}, has_bias={}, has_residual={}, dtype={}, time={:.1f}".format(
              num_tokens, topk, hidden_size, num_expert, expert_size, start_expert_id, \
              has_bias, has_residual, dtype, time))

        self.assertTensorsEqual(output.cpu().float(), golden_output.cpu().float(), 0.003,
                                "golden_output must equal output", True, True, True, True)
|
||||
|
||||
def test_inductor(self):
    """Opcheck: run moe_combine_result through torch.ops for inductor compatibility."""
    dtype = torch.float16
    num_tokens, hidden_size = 1, 2048
    has_bias, has_residual = False, True
    num_expert, topk, start_expert_id, expert_size = 8, 2, 0, 8
    case = gen_case(num_tokens, topk, hidden_size, num_expert, expert_size,
                    has_bias, has_residual, dtype, 'mlu')
    input, reduce_weight, gather_ids, residual, bias, cusum_token_count = case
    op_args = (input, reduce_weight, gather_ids, residual, cusum_token_count,
               start_expert_id, expert_size, bias)
    self.base_opcheck(torch.ops.torch_mlu_ops.moe_combine_result, op_args)
|
||||
|
||||
if __name__ == '__main__':
    # Run the suite through the shared unittest driver and propagate its status code.
    status = run_unittest(TestCombineResult)
    exit(status)
|
||||
132
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_expand_input.py
Normal file
132
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_expand_input.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import torch
|
||||
import torch_mlu_ops
|
||||
import unittest
|
||||
import torch_mlu_ops
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from common_utils import *
|
||||
import random
|
||||
from itertools import product
|
||||
from typing import Tuple, Optional
|
||||
import os
|
||||
|
||||
class TestExpandInput(BtTestCase):
    """Tests for torch_mlu_ops.moe_expand_input: gather rows of a 2-D `input`
    tensor by `gather_idx`, optionally restricted to the token range owned by
    experts [start_expert_id, start_expert_id + expert_size)."""

    def run_gen_case(self, dic):
        """Replay one dumped test case described by dict `dic`.

        `dic['dump_data']` selects between passing the remaining values straight
        through and first materializing the input tensor from its serialized form.
        """
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            input = create_tensor_from_dic(dic['input'])
            gather_idx = dic['gather_idx']['data']
            cusum_token_count = dic['cusum_token_count']['data']
            start_expert_id = dic['start_expert_id']['data']
            expert_size = dic['expert_size']['data']
            self.launch(input, gather_idx, cusum_token_count, start_expert_id, expert_size)

    def launch(self, *args):
        """Run the op and the reference on `args` and compare a valid prefix.

        args = (input, gather_idx, cusum_token_count, start_expert_id, expert_size).
        Only the first `real_token_count` rows are compared.
        """
        cusum_token_count = args[2]
        start_expert_id = args[3]
        expert_size = args[4]
        # NOTE(review): this uses cusum[start + size - 1] while op_impl_base slices
        # up to cusum[start + size]; the comparison therefore covers a slightly
        # smaller prefix than the full valid range -- confirm whether the last
        # expert's tokens should be included here as well.
        real_token_count = None if cusum_token_count is None else \
            cusum_token_count[start_expert_id + expert_size - 1] - cusum_token_count[start_expert_id]
        base_out = self.op_impl_base(*args)
        tmo_out = torch_mlu_ops.moe_expand_input(*args)
        self.assertTensorsEqual(base_out[:real_token_count].cpu().float(),
                                tmo_out[:real_token_count].cpu().float(),
                                0.00, use_MSE=True, use_RAE=True)

    def op_impl_base(self, *args):
        """Reference implementation: plain PyTorch row gather.

        When `cusum_token_count` is None the whole `gather_idx` is used; otherwise
        only the slice owned by experts [start_expert_id, start_expert_id + expert_size).
        """
        input, gather_idx, cusum_token_count, start_expert_id, expert_size = args
        if cusum_token_count is None:
            return input[gather_idx]
        idx = gather_idx[cusum_token_count[start_expert_id]:
                         cusum_token_count[start_expert_id + expert_size]]
        return input[idx]

    def get_tensor(self, token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype):
        """Generate random (input, gather_idx, cusum_token_count, real_token_count).

        When all experts are selected, cusum_token_count is None and every
        expanded token counts as valid.
        """
        input = torch.randn(token_num, hidden_size, device='mlu').to(dtype)
        gather_idx = torch.randint(low=0, high=token_num, size=(token_num * topk,),
                                   dtype=torch.int32, device='mlu')
        cusum_token_count, _ = generate_token_count(expert_num, token_num * topk)
        cusum_token_count = cusum_token_count.to('mlu')
        use_all_experts = expert_num == expert_size
        if use_all_experts:
            cusum_token_count = None
            real_token_count = token_num * topk
        else:
            real_token_count = cusum_token_count[start_expert_id + expert_size - 1] \
                - cusum_token_count[start_expert_id]
        return input, gather_idx, cusum_token_count, real_token_count

    def test_kernel_random(self):
        """Fuzz the kernel against the reference on 100 random shapes/dtypes."""
        # The supported dtype set is loop-invariant; build it once up front.
        dtype_list = [torch.half, torch.float, torch.int8]
        # NOTE(review): `torch_mlu` is not imported by this file directly;
        # presumably it is re-exported by common_utils' star import -- verify.
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for i in range(100):
            token_num = random.randint(1, 2048)
            hidden_size = random.randint(1, 4096)
            expert_num = random.randint(1, 32)
            topk = random.randint(1, expert_num)
            start_expert_id = random.randint(0, expert_num - 1)
            expert_size = random.randint(1, expert_num - start_expert_id)
            dtype = random.sample(dtype_list, 1)[0]
            print("===============================================================================")
            print(f"[{i}]: token_num: {token_num}, hidden_size: {hidden_size}, expert_num: {expert_num}")
            print(f"     topk: {topk}, start_expert_id: {start_expert_id}, expert_size: {expert_size}, dtype: {dtype}")
            input, gather_idx, cusum_token_count, real_token_count = self.get_tensor(
                token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype)
            # Full-gather path (no expert restriction).
            base_expand_hidden_states = self.op_impl_base(input, gather_idx, None, 0, 0)
            expand_hidden_states = torch_mlu_ops.moe_expand_input(input, gather_idx)
            self.assertTensorsEqual(base_expand_hidden_states.cpu().float(),
                                    expand_hidden_states.cpu().float(),
                                    0.00, use_MSE=True, use_RAE=True)
            del base_expand_hidden_states, expand_hidden_states
            # Restricted path: only tokens owned by the selected expert range.
            base_expand_hidden_states = self.op_impl_base(
                input, gather_idx, cusum_token_count, start_expert_id, expert_size)[:real_token_count]
            expand_hidden_states = torch_mlu_ops.moe_expand_input(
                input, gather_idx, cusum_token_count, start_expert_id, expert_size)[:real_token_count]
            self.assertTensorsEqual(base_expand_hidden_states.cpu().float(),
                                    expand_hidden_states.cpu().float(),
                                    0.00, use_MSE=True, use_RAE=True)
            del base_expand_hidden_states, expand_hidden_states

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        """Check that invalid arguments are rejected with the expected messages."""
        func = torch_mlu_ops.moe_expand_input
        token_num, hidden_size, topk, expert_num, start_expert_id, expert_size, dtype = 5, 5, 6, 32, 0, 16, torch.half
        input, gather_idx, cusum_token_count, _ = self.get_tensor(token_num, hidden_size, expert_num,
                                                                  topk, start_expert_id, expert_size, dtype)
        input = torch.randn(token_num, 1, hidden_size).to(dtype).to('mlu')
        self.assertException("input dim must be equal to 2.", func, input, gather_idx)
        input = torch.randn(token_num, hidden_size).to(dtype).to('mlu')
        gather_idx = torch.randint(0, token_num, (token_num * topk, 1)).to(torch.int32).to('mlu')
        self.assertException("gather_idx dim must be equal to 1.", func, input, gather_idx)
        gather_idx = gather_idx.reshape(token_num * topk)
        self.assertException("input tensor must on mlu.", func, input.cpu(), gather_idx)
        self.assertException("gather_idx must on mlu.", func, input, gather_idx.cpu())
        self.assertException("data type of gather_idx must be int32.", func, input, gather_idx.to(torch.int64))
        input = torch.randn(token_num, hidden_size).to(dtype).to('mlu')
        gather_idx = torch.randint(0, token_num, (8,)).to(torch.int32).to('mlu')
        self.assertException("expand_token_num % token_num == 0.", func, input, gather_idx)
        gather_idx = torch.randint(1, token_num, (token_num * topk,)).to(torch.int32).to('mlu')
        self.assertException("cusum_token_count must on mlu.", func, input, gather_idx, cusum_token_count.cpu(),
                             start_expert_id, expert_size)
        self.assertException("data type of cusum_token_count must be int32.", func, input, gather_idx,
                             cusum_token_count.to(torch.int64), start_expert_id, expert_size)
        start_expert_id = -1
        self.assertException("start_expert_id >=0 && start_expert_id < expert_num.",
                             func, input, gather_idx, cusum_token_count, start_expert_id, expert_size)
        start_expert_id = expert_num
        self.assertException("start_expert_id >=0 && start_expert_id < expert_num.",
                             func, input, gather_idx, cusum_token_count, start_expert_id, expert_size)
        start_expert_id = 16
        expert_size = 17
        self.assertException("start_expert_id + expert_size <= expert_num.",
                             func, input, gather_idx, cusum_token_count, start_expert_id, expert_size)

    def test_inductor(self):
        """Opcheck through torch.ops for inductor compatibility."""
        dtype = torch.float16
        token_num, hidden_size, expert_num, topk, start_expert_id, expert_size = 64, 128, 16, 8, 3, 10
        input, gather_idx, cusum_token_count, _ = self.get_tensor(token_num, hidden_size, expert_num,
                                                                  topk, start_expert_id, expert_size, dtype)
        args = (input, gather_idx, cusum_token_count, start_expert_id, expert_size)
        self.base_opcheck(torch.ops.torch_mlu_ops.moe_expand_input, args)
|
||||
|
||||
if __name__ == "__main__":
    # Fixed seeds keep the fuzz cases reproducible across runs.
    random.seed(0)
    torch.manual_seed(0)
    status = run_unittest(TestExpandInput)
    exit(status)
|
||||
62
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_gen_idx.py
Normal file
62
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_gen_idx.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
import torch_mlu_ops
|
||||
import unittest
|
||||
import torch_mlu_ops
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from common_utils import *
|
||||
import random
|
||||
from itertools import product
|
||||
from typing import Tuple, Optional
|
||||
import os
|
||||
|
||||
def gen_args(index: int = 0):
    """Draw a random (expert_id, expert_num) pair and log the chosen shape.

    expert_id is an int32 tensor of shape (token_num, topk) on the MLU device.
    """
    token_num, expert_num = random.randint(1, 32768), random.randint(1, 256)
    topk = random.randint(1, expert_num)
    shape = (token_num, topk)
    expert_id = torch.randint(low=0, high=expert_num, size=shape).to(torch.int32).to('mlu')
    print("===============================================================================")
    print(f"[{index}]: token_num: {token_num}, expert_num: {expert_num}, topk: {topk}")
    return expert_id, expert_num
|
||||
|
||||
class TestGenIdx(BtTestCase):
    """Tests for torch_mlu_ops.moe_gen_idx against a pure-PyTorch reference."""

    def op_impl_base(self, *args):
        """Reference: sort the flattened expert ids and derive the index tensors.

        Returns (expand_idx, combine_idx, token_count, cusum_token_count).
        """
        expert_id, expert_num = args
        token_num, topk = expert_id.size(0), expert_id.size(1)
        total = token_num * topk
        sorted_ids, order = expert_id.int().flatten().sort()
        expand_idx = order.int() // topk
        combine_idx = torch.zeros((total,), dtype=torch.int, device="mlu")
        combine_idx.scatter_(0, order, torch.arange(total, dtype=torch.int, device="mlu"))
        token_count = torch.bincount(sorted_ids, minlength=expert_num)
        cusum = torch.cat((torch.tensor([0]).to('mlu'),
                           torch.cumsum(token_count, dim=0))).to(torch.int32)
        return (expand_idx, combine_idx, token_count, cusum)

    def test_kernel_random(self):
        """Compare kernel and reference outputs over 1500 random shapes."""
        for i in range(1500):
            expert_id, expert_num = gen_args(i)
            ref_outputs = self.op_impl_base(expert_id, expert_num)
            dev_outputs = torch_mlu_ops.moe_gen_idx(expert_id, expert_num)
            # Same four tensors in the same order: expand idx, combine idx,
            # per-expert token count, cumulative token count.
            for ref_t, dev_t in zip(ref_outputs, dev_outputs):
                self.assertTensorsEqual(ref_t.cpu(), dev_t.cpu(), 0.00, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Opcheck through torch.ops for inductor compatibility."""
        self.base_opcheck(torch.ops.torch_mlu_ops.moe_gen_idx, gen_args())

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        """Invalid-argument checks: each call must fail with the expected message."""
        func = torch_mlu_ops.moe_gen_idx
        token_num, expert_num, topk = 128, 32, 16
        bad_rank = torch.randint(low=0, high=expert_num, size=(token_num, topk, 1)).to(torch.int32).to('mlu')
        self.assertException("expert_id dim must be equal to 2.", func, bad_rank, expert_num)
        flat = bad_rank.reshape(token_num, topk)
        self.assertException("data type of expert_id must be int32.", func, flat.to(torch.int64), expert_num)
|
||||
|
||||
if __name__ == "__main__":
    # Deterministic seeds so the random cases are reproducible.
    random.seed(0)
    torch.manual_seed(0)
    status = run_unittest(TestGenIdx)
    exit(status)
|
||||
200
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_softmax_topk.py
Normal file
200
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_softmax_topk.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from common_utils import *
|
||||
import random
|
||||
from itertools import product
|
||||
import time
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
class TestSoftmaxTopkOp(BtTestCase):
    """Tests for moe_softmax_topk: softmax over the expert dimension, optional
    grouped expert selection, top-k pick and optional weight normalization."""

    def op_impl_base(self, *args):
        """Pure-PyTorch reference implementation.

        args = (input, topk, normalize, num_expert_group, topk_group,
                origin_mask, normed_by).
        With num_expert_group <= 1 the top-k runs on the (optionally masked)
        softmax; otherwise experts outside the best `topk_group` groups are
        zeroed first. Returns (reduce_weight, expert_id).
        """
        input, topk, normalize, num_expert_group, topk_group, origin_mask, normed_by = args
        softmax = torch.softmax(input.float(), dim=-1)
        if num_expert_group <= 1:
            if origin_mask is not None:
                softmax = softmax * origin_mask
            candidates = softmax
        else:
            # Keep only the experts belonging to the best `topk_group` groups,
            # ranked by each group's max softmax value.
            group_size = softmax.shape[-1] // num_expert_group
            grouped_shape = softmax.shape[:-1] + (num_expert_group, group_size)
            group_data = softmax.view(grouped_shape)
            group_max_value = group_data.max(dim=-1).values
            group_idx = torch.topk(group_max_value, k=topk_group, dim=-1)[1]
            keep = torch.zeros(softmax.shape[:-1] + (num_expert_group,),
                               dtype=torch.bool, device=group_idx.device)
            keep.scatter_(-1, group_idx, True)
            keep = keep.unsqueeze(-1).expand(grouped_shape)
            candidates = group_data.masked_fill(~keep, 0.0).reshape(softmax.shape)
        reduce_weight, expert_id = torch.topk(candidates, k=topk, dim=-1)
        if normalize:
            # This tail was previously duplicated in both branches; factored here.
            if normed_by == "topk_logit":
                reduce_weight = reduce_weight / reduce_weight.sum(dim=-1, keepdim=True)
            if normed_by == "softmax_logit":
                reduce_weight = reduce_weight / softmax.sum(dim=-1, keepdim=True)
        return reduce_weight, expert_id

    # Interface test
    def test_interface(self):
        """Exercise the Python wrapper with its short (no mask/normed_by) call form."""
        num_token, num_expert, topk, num_expert_group, topk_group, normalize = 1024, 160, 6, 10, 5, False
        input = torch.randn(num_token, num_expert, dtype=torch.half, device='mlu')
        mask = None
        normed_by = "topk_logit"
        base_reduce_weight, base_expert_id = self.op_impl_base(input, topk, normalize, num_expert_group, topk_group, mask, normed_by)
        # NOTE(review): the wrapper is called as (input, topk, normalize,
        # num_expert_group, topk_group) while torch.ops.* takes (input, topk,
        # num_expert_group, topk_group, normalize, ...) -- confirm the wrapper's
        # argument order matches this call.
        reduce_weight, expert_id = torch_mlu_ops.moe_softmax_topk(input, topk, normalize, num_expert_group, topk_group)
        base_expert_id, _ = base_expert_id.sort()
        expert_id, _ = expert_id.sort()
        self.assertTensorsEqual(base_reduce_weight.cpu().float(), reduce_weight.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
        self.assertTensorsEqual(base_expert_id.cpu().float(), expert_id.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)

    # Random sweep test
    def test_kernel_random(self):
        """Fuzz the kernel against the reference on 1000 random configurations."""
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for i in range(1000):
            num_batch = random.randint(1, 64)
            num_mask = random.randint(1, 1024)
            num_expert = random.randint(1, 512)
            factors = [f for f in range(1, num_expert + 1) if num_expert % f == 0]
            num_expert_group = random.choice(factors)
            topk_group = random.randint(1, num_expert_group)
            # Integer division here: num_expert_group divides num_expert, and
            # random.randint rejects a non-int bound (float bounds deprecated in
            # Python 3.10 and removed in 3.12).
            topk = random.randint(1, num_expert // num_expert_group * topk_group)
            normalize = random.sample([True, False], 1)[0]
            dtype = random.sample(dtype_list, 1)[0]
            group_invalid = random.sample([True, False], 1)[0]
            if group_invalid:
                num_expert_group = -1
            normed_by = random.choice(["topk_logit", "softmax_logit"])
            print("===============================================================================")
            print(f"[{i}]: num_batch: {num_batch}, num_mask: {num_mask}, num_expert: {num_expert}, num_expert_group: {num_expert_group}")
            print(f"     topk_group: {topk_group}, topk: {topk}, normalize: {normalize}, dtype: {dtype}, normed_by: {normed_by}")
            input = torch.randn(num_batch, num_mask, num_expert, dtype=dtype).mlu()
            mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype=dtype).mlu()
            if num_expert_group > 1:
                mask = None
            base_reduce_weight, base_expert_id = self.op_impl_base(input, topk, normalize, num_expert_group, topk_group, mask, normed_by)
            reduce_weight, expert_id = torch.ops.torch_mlu_ops.moe_softmax_topk(input,
                                                                                topk,
                                                                                num_expert_group,
                                                                                topk_group,
                                                                                normalize,
                                                                                mask,
                                                                                normed_by)
            # Softmax values that are numerically very close may come out of topk
            # in a different order (same weights, swapped ids), e.g.:
            #   base_reduce_weight[N, 10:12] = [0.012, 0.011]
            #   reduce_weight[N, 10:12]      = [0.011, 0.012]
            #   base_expert_id[N, 10:12]     = [7, 8]
            #   expert_id[N, 10:12]          = [8, 7]
            # That ordering difference is not an error, but it would make the raw
            # comparison fail, so sort both id tensors before comparing.
            base_expert_id, _ = base_expert_id.sort()
            expert_id, _ = expert_id.sort()
            self.assertTensorsEqual(base_reduce_weight.cpu().float(), reduce_weight.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(base_expert_id.cpu().float(), expert_id.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)

    # Invalid-argument tests
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        """Each invalid call must fail with the expected message."""
        func = torch.ops.torch_mlu_ops.moe_softmax_topk
        num_batch, num_mask, num_expert, topk, num_expert_group, topk_group, normalize, normed_by = 2, 1024, 160, 0, 1, 1, True, "abc_logit"
        input = torch.randn(num_batch, num_mask, num_expert, dtype=torch.half, device='mlu')
        input_permute = input.permute(0, 2, 1)
        mask = torch.randint(0, 2, (num_mask, num_expert), dtype=torch.half, device='mlu')
        mask_permute = mask.permute(1, 0)
        self.assertException("input must be contiguous.",
                             func, input_permute, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        self.assertException("mask must be contiguous.",
                             func, input, topk, num_expert_group, topk_group, normalize, mask_permute, normed_by)
        self.assertException("normed_by must be 'topk_logit' or 'softmax_logit'",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        input = torch.randn(num_expert, dtype=torch.half, device='mlu')
        normed_by = "softmax_logit"
        self.assertException("input.dim() >= 2",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        input = torch.randn(num_batch, num_mask, num_expert, dtype=torch.half, device='mlu')
        self.assertException("topk > 0 && topk <= num_expert",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        topk = 5
        self.assertException("the dim of mask should be the same as input",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = torch.randint(0, 2, (1, num_mask, num_expert - 1), dtype=torch.half, device='mlu')
        self.assertException("the last dim of mask should be the same as the last dim of input",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = torch.randint(0, 2, (1, num_mask - 1, num_expert), dtype=torch.half, device='mlu')
        self.assertException("the penultimate dim of mask should be the same as the penultimate dim of input",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = torch.randint(0, 2, (2, num_mask, num_expert), dtype=torch.half, device='mlu')
        self.assertException("the product of all but the lower two dimensions of mask is 1",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype=torch.float, device='mlu')
        self.assertException("the dtype of mask should be the same as input",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype=torch.half, device='mlu')
        num_expert_group = 8
        self.assertException("if num_expert_group > 1, mask should be None",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = None
        topk = 160
        self.assertException("topk <= (num_expert / num_expert_group) * topk_group",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)

        num_expert_group = 11
        self.assertException("num_expert % num_expert_group == 0",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)

        num_expert_group = 8
        topk_group = 9
        self.assertException("topk_group > 0 && topk_group <= num_expert_group",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)

    # Single fixed case
    def test_single(self):
        """One fixed 4-D case with a broadcastable mask and softmax_logit normalization."""
        num_batch, num_mask, num_expert, topk, num_expert_group, topk_group, normalize, normed_by = 32, 16, 128, 34, 1, 1, True, "softmax_logit"
        input = torch.randn(2, num_batch, num_mask, num_expert, dtype=torch.float32, device='mlu')
        mask = torch.randint(0, 2, (1, 1, num_mask, num_expert), dtype=torch.float32, device='mlu')
        base_reduce_weight, base_expert_id = self.op_impl_base(input, topk, normalize, num_expert_group, topk_group, mask, normed_by)
        reduce_weight, expert_id = torch.ops.torch_mlu_ops.moe_softmax_topk(input,
                                                                            topk,
                                                                            num_expert_group,
                                                                            topk_group,
                                                                            normalize,
                                                                            mask,
                                                                            normed_by)
        base_expert_id, _ = base_expert_id.sort()
        expert_id, _ = expert_id.sort()
        self.assertTensorsEqual(base_reduce_weight.cpu().float(), reduce_weight.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
        self.assertTensorsEqual(base_expert_id.cpu().float(), expert_id.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Opcheck through torch.ops for inductor compatibility."""
        num_batch, num_mask, num_expert, topk, num_expert_group, topk_group, normalize, normed_by = 32, 16, 128, 34, 4, 3, True, "softmax_logit"
        input = torch.randn(num_batch, num_mask, num_expert, dtype=torch.half, device='mlu')
        mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype=torch.half, device='mlu')
        if num_expert_group > 1:  # grouped selection forbids a mask
            mask = None
        args = (input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        self.base_opcheck(torch.ops.torch_mlu_ops.moe_softmax_topk, args)
|
||||
|
||||
if __name__ == '__main__':
    # Deterministic seeds keep the random sweep reproducible.
    random.seed(0)
    torch.manual_seed(0)
    status = run_unittest(TestSoftmaxTopkOp)
    exit(status)
|
||||
194
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_offline_quant_to_linear_cache.py
Executable file
194
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_offline_quant_to_linear_cache.py
Executable file
@@ -0,0 +1,194 @@
|
||||
import torch
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
import random
|
||||
from common_utils import *
|
||||
|
||||
def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype):
    """Assemble the positional argument list for offline_quant_to_linear_cache.

    Rearranges the shared cache-generator output: the two quant scales are
    spliced in after the cache tensors, and a literal False is inserted at
    position 8 of the final list.
    """
    raw = generate_cache_func_args(batch_size, num_heads, head_size,
                                   cache_memory_len, packed, dtype, True, True)
    key_scale, value_scale = raw[10][0], raw[10][1]
    # Equivalent to taking raw[:10] and inserting key_scale at 4,
    # value_scale at 5 and False at 8.
    return raw[:4] + [key_scale, value_scale] + raw[4:6] + [False] + raw[6:10]
|
||||
|
||||
class TestOfflineQuantToLinearCache(BtTestCase):
    """Tests for ops.offline_quant_to_linear_cache: quantize key/value context
    tensors to int8 using precomputed scales and scatter them into linear
    (batch, head, seq, head_size) caches."""

    def op_impl_base(self, *args):
        """Reference implementation, processed batch by batch.

        args = (key, value, key_cache, value_cache, key_cache_quant_scale,
                value_cache_quant_scale, context_lengths, max_context_len,
                quant_mode, packed, context_seq_offset, cache_bs_id,
                cache_seqlen_offset).
        quant_mode 0 uses per-channel scales, otherwise per-token scales.
        Batches with a negative cache_bs_id or cache_seqlen_offset entry are
        skipped (their cache slots stay untouched).
        """
        def quant2Int8(input_fp: torch.Tensor,
                       scale_fp: torch.tensor,
                       quant_mode: torch.int64):
            # Quantize a (head, seq, head_size) float tensor to int8: divide by
            # the broadcast scale, round, clip to the int8 range.
            head_num = input_fp.size(0)
            seq = input_fp.size(1)
            head_size = input_fp.size(2)
            input_fp32 = input_fp.to(torch.float32)
            if quant_mode == 0:  # per_channel
                scale = scale_fp.reshape((head_num, 1, head_size))
            else:  # per_token
                scale = scale_fp.reshape((head_num, seq, 1))
            scaled_context = input_fp32 / scale
            rounded = torch.round(scaled_context)
            clipped = torch.clip(rounded, -128, 127)
            return clipped.to(torch.int8)

        key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, \
            context_lengths, max_context_len, quant_mode, packed, context_seq_offset, cache_bs_id, \
            cache_seqlen_offset = args
        # In packed layout, context_lengths is a cumulative-sum vector of length batch+1.
        batch_size = context_lengths.size(0) - 1 if packed else context_lengths.size(0)
        for i in range(batch_size):
            if packed:
                key_i = key[context_lengths[i]:context_lengths[i + 1]].transpose(1, 0)
                value_i = value[context_lengths[i]:context_lengths[i + 1]].transpose(1, 0) if value is not None else None
                context_len_i = context_lengths[i + 1] - context_lengths[i]
                context_seq_offset_i = 0
            else:
                key_i = key[i].transpose(1, 0)
                value_i = value[i].transpose(1, 0) if value is not None else None
                context_len_i = context_lengths[i]
                context_seq_offset_i = context_seq_offset[i] if context_seq_offset is not None else 0
            cache_bs_id_i = cache_bs_id[i] if cache_bs_id is not None else i
            cache_seq_begin = cache_seqlen_offset[i] if cache_seqlen_offset is not None else 0
            if cache_bs_id_i < 0 or cache_seq_begin < 0:
                continue  # batch marked invalid
            cache_seq_end = cache_seq_begin + context_len_i

            # quant key to int8 and write it into the linear cache
            key_i = key_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
            if quant_mode == 0:
                key_cache_scale_i = key_cache_quant_scale
            else:
                key_cache_scale_i = key_cache_quant_scale[:, cache_seq_begin:cache_seq_end]
            quant_key_i = quant2Int8(key_i, key_cache_scale_i, quant_mode)
            key_cache_i = key_cache[cache_bs_id_i, :, cache_seq_begin:cache_seq_end]
            key_cache_i[...] = quant_key_i

            # quant value to int8 (only when a value cache is in play)
            if value_cache is not None and value is not None and value_cache_quant_scale is not None:
                value_i = value_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
                if quant_mode == 0:
                    value_cache_scale_i = value_cache_quant_scale
                else:
                    value_cache_scale_i = value_cache_quant_scale[:, cache_seq_begin:cache_seq_end]
                quant_value_i = quant2Int8(value_i, value_cache_scale_i, quant_mode)
                value_cache_i = value_cache[cache_bs_id_i, :, cache_seq_begin:cache_seq_end]
                value_cache_i[...] = quant_value_i
        return (key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale) if value_cache is not None else (key_cache, key_cache_quant_scale)

    def test_offline_quant_to_linear_cache(self):
        """Fuzz 100 random configurations and compare kernel vs reference caches."""
        test_cases = 100
        bs_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
        num_heads_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
        head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32)
        head_size_list *= 16  # head_size is drawn as a multiple of 16
        cache_memory_len_list = torch.randint(low=2, high=1024, size=(test_cases, ), dtype=torch.int32)
        packed_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
        mode_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
        dtype_list = torch.randint(low=0, high=3, size=(test_cases, ), dtype=torch.int32)
        dtype_map = [torch.half, torch.bfloat16, torch.float32]
        # The device name is invariant across cases; query it once instead of per-iteration.
        mlu_name = torch.mlu.get_device_name()

        for i in range(test_cases):
            q_heads = 1
            batch_size = bs_list[i].item()
            invalid_batch = batch_size // 10  # number of batches deliberately marked invalid
            num_heads = num_heads_list[i].item()
            head_size = head_size_list[i].item()
            cache_memory_len = cache_memory_len_list[i].item()
            packed = packed_list[i].item()
            quant_mode = mode_list[i].item()
            total_heads = q_heads + num_heads * 2
            dtype = dtype_map[dtype_list[i]]
            if "MLU3" in mlu_name and dtype == torch.bfloat16:
                dtype = torch.half
                print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))

            print("batch_size={}, num_heads={}, head_size={}, cache_memory_len={}, packed={}, mode ={}, dtype={} testing...".format(
                batch_size, num_heads, head_size, cache_memory_len, packed > 0, quant_mode, dtype))

            torch.manual_seed(1)
            max_bs = batch_size + 1
            context_lens = torch.randint(size=(batch_size, ), low=1,
                                         high=cache_memory_len // 2,
                                         dtype=torch.int32, device='mlu')
            max_context_len = context_lens.max().item()
            max_seq_offset = max_context_len // 3 + 1
            context_seq_offsets = torch.randint(size=(batch_size, ),
                                                low=0, high=max_seq_offset,
                                                dtype=torch.int32, device='mlu')
            cache_seq_offsets = torch.randint(size=(batch_size, ), low=0,
                                              high=(cache_memory_len - max_context_len) // 3 + 1,
                                              dtype=torch.int32, device='mlu')
            cache_seq_offsets[random.sample(range(batch_size), invalid_batch)] = -1
            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1, 0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            if packed > 0:
                context = torch.randn((total_seqlen, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            else:
                context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            context = context.to(dtype)
            key = context[..., q_heads:q_heads + num_heads, :]
            value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :]

            # prepare key_scale and value_scale
            if quant_mode == 0:  # per_channel
                cache_scale = torch.randn((2, num_heads, head_size), dtype=torch.float, device='mlu')
            else:  # per_token
                cache_scale = torch.randn((2, num_heads, cache_memory_len), dtype=torch.float, device='mlu')
            key_cache_scale = cache_scale[0]
            value_cache_scale = cache_scale[1]

            cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu')
            cache = (cache - 0.5) * 256
            cache = cache.to(torch.int8)

            key_cache = cache[0]
            value_cache = cache[1]

            ref_cache = cache.clone()
            ref_key_cache = ref_cache[0]
            ref_value_cache = ref_cache[1]

            cache_bs_id = random.sample(range(max_bs), batch_size)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
            cache_bs_id[random.sample(range(batch_size), invalid_batch)] = -1  # set invalid batch

            if packed > 0:
                ops.offline_quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale,
                                                  value_cache_scale, cu_context_lens, max_context_len,
                                                  quant_mode, packed > 0, None, cache_bs_id,
                                                  cache_seq_offsets)
                self.op_impl_base(key, value, ref_key_cache, ref_value_cache, key_cache_scale,
                                  value_cache_scale, cu_context_lens, max_context_len,
                                  quant_mode, packed > 0, None, cache_bs_id,
                                  cache_seq_offsets)
            else:
                ops.offline_quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale,
                                                  value_cache_scale, context_lens, max_context_len,
                                                  quant_mode, packed > 0, context_seq_offsets, cache_bs_id,
                                                  cache_seq_offsets)
                self.op_impl_base(key, value, ref_key_cache, ref_value_cache, key_cache_scale,
                                  value_cache_scale, context_lens, max_context_len,
                                  quant_mode, packed > 0, context_seq_offsets, cache_bs_id,
                                  cache_seq_offsets)
            # Compare flattened caches; |diff| of 1 can occur due to round_mode.
            cache = cache.cpu().flatten()
            ref_cache = ref_cache.cpu().flatten()
            diff = cache - ref_cache
            diff = diff.abs()
            assert torch.max(diff) < 2, "ref_cache must equal cache or absolute values differ by 1 due to round_mode!"

    def test_inductor(self):
        """Opcheck both packed and unpacked argument layouts."""
        batch_size, num_heads, head_size, cache_memory_len, dtype = 4, 8, 64, 128, torch.float16
        args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype)
        self.base_opcheck(torch.ops.torch_mlu_ops.offline_quant_to_linear_cache, args)

        args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 0, dtype)
        self.base_opcheck(torch.ops.torch_mlu_ops.offline_quant_to_linear_cache, args)
|
||||
|
||||
# Entry point: run the suite and propagate the unittest status code.
if __name__ == '__main__':
    exit(run_unittest(TestOfflineQuantToLinearCache))
|
||||
@@ -0,0 +1,209 @@
|
||||
import torch
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
import random
|
||||
from common_utils import *
|
||||
import numpy as np
|
||||
|
||||
def quant2int8(input_fp: torch.Tensor,
               scale_fp: torch.Tensor) -> torch.Tensor:
    """Quantize a floating tensor to int8 using element-wise scales.

    Computes round(input / scale) in float32 and saturates to [-128, 127].
    """
    scaled = torch.round(input_fp.to(torch.float32) / scale_fp)
    return torch.clip(scaled, -128, 127).to(torch.int8)
|
||||
|
||||
def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype):
    """Build the positional argument list for offline_quant_to_paged_cache.

    Rearranges the generic cache-test fixtures so that the key/value scales
    and the slot mapping sit between (key, value) and (key_cache, value_cache).
    """
    raw = generate_cache_func_args(batch_size, num_heads, head_size,
                                   cache_memory_len, packed, dtype, True, True)
    scales, slots = raw[10], raw[11]
    # Final layout: key, value, k_scale, v_scale, slot_mapping, k_cache, v_cache
    return raw[0:2] + [scales[0], scales[1], slots] + raw[2:4]
|
||||
class TestOfflineQuantToPagedCache(BtTestCase):
|
||||
    def op_impl_base(self, *args):
        """Pure-PyTorch reference for offline_quant_to_paged_cache.

        For every token whose slot is non-negative, quantizes its key (and,
        when present, value) head vectors to int8 via quant2int8 and writes
        them in place into the paged caches at the (block, offset) position
        derived from ``slot_mapping``. Returns the mutated cache tensor(s).
        """
        k, v, k_cache_scale, v_cache_scale, slot_mapping, k_cache, v_cache = args
        tokens_num = k.shape[0]
        block_size = k_cache.shape[2]
        for i in range(tokens_num):
            # Negative slots mark masked tokens that must be skipped.
            if slot_mapping[i] >= 0:
                key_i = k[i]
                # Linear slot index -> (block id, offset within block).
                block_id = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
                block_offset = slot_mapping[i] % block_size
                key_cache_i = k_cache[block_id, :, block_offset, :]
                quant_key_i = quant2int8(key_i, k_cache_scale)
                key_cache_i[...] = quant_key_i
                if v is not None:
                    value_i = v[i]
                    value_cache_i = v_cache[block_id, :, block_offset, :]
                    quant_value_i = quant2int8(value_i, v_cache_scale)
                    value_cache_i[...] = quant_value_i
        return (k_cache, v_cache) if v is not None else k_cache
|
||||
|
||||
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device")
    def test_offline_quant_to_paged_cache(self):
        """Randomized comparison of the MLU op against op_impl_base.

        Runs 10 cases with random token counts, head geometry, block size and
        dtype; half the cases (randomly) quantize only the key. The int8 cache
        contents may legitimately differ by 1 LSB due to rounding mode, hence
        the ``< 2`` absolute-diff threshold.
        """
        test_cases = 10
        token_list = torch.randint(low=1, high=512, size=(test_cases, ), dtype=torch.int32)
        head_num_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
        head_size_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
        block_size_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
        dtype_list = torch.randint(low=0, high=3, size=(test_cases, ), dtype=torch.int32)
        dtype_map = [torch.bfloat16, torch.half, torch.float32]
        only_quant_key_list = [True, False]

        for i in range(test_cases):
            tokens_num = token_list[i].item()
            head_num = head_num_list[i].item()
            head_size = head_size_list[i].item()
            block_size = block_size_list[i].item()
            dtype = dtype_map[dtype_list[i]]
            print("tokens_num={}, head_num={}, head_size={}, block_size={}, dtype={} testing...".format(
                tokens_num, head_num, head_size, block_size, dtype))
            np.random.seed(1)
            only_quant_key = random.choice(only_quant_key_list)
            key_data = np.random.uniform(-1, 1, size=[tokens_num, head_num, head_size])
            key_cache_scale_data = np.random.uniform(-10, 10, size=[head_num, head_size])
            key = torch.tensor(key_data, dtype=dtype, device="mlu")
            key_cache_scale = torch.tensor(key_cache_scale_data, dtype=torch.float32, device="mlu")
            value_data = np.random.uniform(-0.25, 0.25, size=[tokens_num, head_num, head_size])
            value_cache_scale_data = np.random.uniform(-10, 10, size=[head_num, head_size])
            value = torch.tensor(value_data, dtype=dtype, device="mlu")
            value_cache_scale = torch.tensor(value_cache_scale_data, dtype=torch.float32, device="mlu")

            # Ceil-division via C-style cast: enough blocks to hold all tokens.
            min_blocks = (int)((tokens_num + block_size - 1) / block_size)
            blocks_num = min(min_blocks + 10, 2 * min_blocks)
            num_slots = blocks_num * block_size
            # Distinct slots per token so writes never collide.
            slot_mapping = random.sample(range(num_slots), tokens_num)
            slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, device="mlu")
            slot_mapping[-1] = -1 # test mask: last token must be skipped by the op

            torch.manual_seed(0)
            # Caches start with random garbage so untouched slots stay verifiably unchanged.
            key_cache = torch.randint(-128, 127, (blocks_num, head_num, block_size, head_size), dtype=torch.int8).mlu()
            value_cache = torch.randint(-128, 127, (blocks_num, head_num, block_size, head_size), dtype=torch.int8).mlu()

            # python base result
            key_cache_base = key_cache.clone()
            value_cache_base = value_cache.clone()
            if only_quant_key:
                value, value_cache_scale, value_cache_base, value_cache = None, None, None, None
            self.op_impl_base(key, value, key_cache_scale, value_cache_scale,
                              slot_mapping, key_cache_base, value_cache_base)
            # mlu result
            ops.offline_quant_to_paged_cache(key, value, key_cache_scale, value_cache_scale,
                                             slot_mapping, key_cache, value_cache)
            # compute diff (1-LSB tolerance for rounding-mode differences)
            baseline_key_cache = key_cache_base.cpu().flatten()
            mlu_key_cache = key_cache.cpu().flatten()
            key_cache_diff = (mlu_key_cache - baseline_key_cache).abs()
            assert torch.max(key_cache_diff) < 2, "key_cache_diff exceed threshold"
            if not only_quant_key:
                baseline_value_cache = value_cache_base.cpu().flatten()
                mlu_value_cache = value_cache.cpu().flatten()
                value_cache_diff = (mlu_value_cache - baseline_value_cache).abs()
                assert torch.max(value_cache_diff) < 2, "value_cache_diff exceed threshold"
|
||||
|
||||
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_large_tensor(self):
        """Exercise caches sized just under the 4 GiB addressing limit, then
        confirm the op rejects caches that reach/exceed the limit."""
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for dtype in dtype_list:
            print("offline_quant_to_paged_cache: test_large_tensor...")
            head_num = 16
            head_size = 128
            token_nums = 200
            block_size = 16
            # Largest block count whose int8 cache stays below 2**32 bytes.
            block_nums = ((2**32 - 1) // 1 // head_num // head_size // block_size)
            num_slots = block_nums * block_size
            slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums]
            key = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu")
            value = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu")
            key_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu")
            value_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu")
            key_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
            value_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")

            # python base result
            key_cache_base = key_cache.clone()
            value_cache_base = value_cache.clone()
            self.op_impl_base(key, value, key_cache_scale, value_cache_scale,
                              slot_mapping, key_cache_base, value_cache_base)
            # mlu result
            ops.offline_quant_to_paged_cache(key, value, key_cache_scale, value_cache_scale,
                                             slot_mapping, key_cache, value_cache)
            # compute diff (1-LSB tolerance for rounding-mode differences)
            baseline_key_cache = key_cache_base.cpu().flatten()
            mlu_key_cache = key_cache.cpu().flatten()
            key_cache_diff = (mlu_key_cache - baseline_key_cache).abs()
            assert torch.max(key_cache_diff) < 2, "key_cache_diff exceed threshold"

            baseline_value_cache = value_cache_base.cpu().flatten()
            mlu_value_cache = value_cache.cpu().flatten()
            value_cache_diff = (mlu_value_cache - baseline_value_cache).abs()
            assert torch.max(value_cache_diff) < 2, "value_cache_diff exceed threshold"

            # Now size the cache to exactly 2**32 bytes: the op must refuse it.
            block_nums = (2**32 // 1 // head_num // head_size // block_size)
            num_slots = block_nums * block_size
            slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums]
            key_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
            value_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
            self.assertException("The addressing range of kv_cache cannot exceed 4G.", ops.offline_quant_to_paged_cache,
                                 key, value, key_cache_scale, value_cache_scale, slot_mapping, key_cache, value_cache)
|
||||
|
||||
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        """Negative tests: each invalid value-side argument (missing cache,
        wrong rank, non-contiguous strides) must raise with the exact message
        emitted by the op's parameter checks."""
        print("offline_quant_to_paged_cache: test prevent...")
        head_num = 16
        head_size = 128
        token_nums = 200
        block_size = 16
        block_nums = 20
        num_slots = block_nums * block_size
        dtype = torch.half
        slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums]
        key = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu")
        value = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu")
        key_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu")
        value_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu")
        key_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
        value_cache = None
        # value given but value_cache missing -> optional-argument consistency check.
        self.assertException("v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
        value_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
        # Wrong rank for value (4-D instead of 3-D).
        value = value.reshape(token_nums, head_num, head_size, 1)
        self.assertException("dim of v must be 3",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
        value = value.squeeze()
        # Wrong rank for value_cache (5-D instead of 4-D).
        value_cache = value_cache.reshape(block_nums, head_num, block_size, head_size, 1)
        self.assertException("dim of v_cache must be 4",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
        value_cache = value_cache.squeeze()
        # Wrong rank for value_cache_scale (3-D instead of 2-D).
        value_cache_scale = value_cache_scale.reshape(head_num, head_size, 1)
        self.assertException("dim of v_cache_scale must be 2",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
        value_cache_scale = value_cache_scale.squeeze()
        # Strided view breaking last-dim contiguity.
        value = value.as_strided(size=(token_nums, head_num, head_size), stride=(head_num, 1, head_size))
        self.assertException("v last dim must be contiguous.",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
        # Strided view breaking second-dim contiguity.
        value = value.as_strided(size=(token_nums, head_num, head_size), stride=(head_num, token_nums, 1))
        self.assertException("v second dim must be contiguous.",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device")
|
||||
def test_inductor(self):
|
||||
batch_size, num_heads, head_size, cache_memory_len, dtype = 4, 8, 64, 128, torch.float16
|
||||
args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype)
|
||||
self.base_opcheck(torch.ops.torch_mlu_ops.offline_quant_to_paged_cache, args)
|
||||
|
||||
# Entry point: run the suite and propagate the unittest status code.
if __name__ == '__main__':
    exit(run_unittest(TestOfflineQuantToPagedCache))
|
||||
72
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_per_token_smooth_quantize.py
Executable file
72
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_per_token_smooth_quantize.py
Executable file
@@ -0,0 +1,72 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as tmo
|
||||
from common_utils import *
|
||||
import random
|
||||
|
||||
|
||||
class TestPerTokenSmoothQuantizeOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
|
||||
x, smooth, zero, m_list = args
|
||||
x_shape = x.size()
|
||||
scale_shape = x.size()[0:-1]
|
||||
if m_list is None:
|
||||
smoothed = x * smooth
|
||||
else:
|
||||
input_list = x.split(tuple(m_list))
|
||||
experts = len(input_list)
|
||||
result = []
|
||||
for i in range(experts):
|
||||
result.append(input_list[i] * smooth[i])
|
||||
smoothed = torch.concat(result, dim=0)
|
||||
output, scale = QuantByRow(smoothed, 8)
|
||||
return output.reshape(x_shape), scale.reshape(scale_shape)
|
||||
|
||||
    def test_random_case(self):
        """Compare tmo.per_token_smooth_quantize with op_impl_base over 200
        unique random (experts, ci, co, dtype) configurations, both with and
        without expert grouping."""
        torch.manual_seed(0)
        case_list = set()
        while(len(case_list) < 200):
            dtype_list = [torch.half, torch.float]
            if torch_mlu.mlu.is_bf16_supported():
                dtype_list.append(torch.bfloat16)
            dtype = random.choice(dtype_list)
            has_group = random.choice([False, True])
            if has_group:
                experts = random.randint(1, 40)
                # Per-expert row counts; total rows ci is their sum.
                m_list = torch.randint(1, 100, (experts,), device="mlu", dtype=torch.int32)
                ci = m_list.sum().item()
            else:
                experts = None
                ci = random.randint(1, 4096)
            co = random.randint(1, 4096)
            case = (experts, ci, co, dtype)
            # Skip duplicates so exactly 200 distinct configurations run.
            if case in case_list:
                continue
            else:
                case_list.add(case)
            x = torch.randn(ci, co, device="mlu", dtype=dtype)
            if has_group:
                scale = torch.randn(experts, co, device="mlu", dtype=torch.float32)
            else:
                scale = torch.randn(co, device="mlu", dtype=torch.float32)
            print("experts={}, ci={}, co={}, dtype={}, testing...".format(experts, ci, co, dtype), flush=True)
            param = (x, scale, None, m_list if has_group else None)
            tmo_output, tmo_scale = tmo.per_token_smooth_quantize(*param)
            torch_output, torch_scale = self.op_impl_base(*param)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), 0.01, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_scale.cpu().float(), tmo_scale.cpu().float(), 0.01, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_inductor(self):
|
||||
m_list = torch.randint(1, 100, (8,), device="mlu", dtype=torch.int32)
|
||||
total_m = m_list.sum().item()
|
||||
x = torch.randn(total_m, 1024, device="mlu", dtype=torch.half)
|
||||
scale = torch.randn(8, 1024, device="mlu", dtype=torch.float32)
|
||||
output = torch.empty(x.size(), dtype=torch.int8, device="mlu")
|
||||
output_scale = torch.empty(x.size()[:-1], dtype=torch.float32, device="mlu")
|
||||
args = (x, scale, output, output_scale, None, m_list, None, None, 'per_token', True)
|
||||
self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
|
||||
|
||||
|
||||
# Entry point: run the suite and propagate the unittest status code.
if __name__ == '__main__':
    exit(run_unittest(TestPerTokenSmoothQuantizeOp))
|
||||
21
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_preload.py
Normal file
21
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_preload.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import torch
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
|
||||
class TestPreloadOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
|
||||
wegiht, size = args
|
||||
return super().op_impl_base(*args)
|
||||
|
||||
    def test_preload(self):
        """Smoke test: preload a half-precision weight tensor and synchronize.

        There is no output to compare; the test passes if the call and the
        subsequent device sync complete without error.
        """
        weight = torch.randn((1024, 8, 5, 1024)).half().mlu()
        # Preload the whole tensor: second argument is the size in bytes.
        ops.preload(weight, weight.element_size() * weight.numel())
        torch.mlu.synchronize()
|
||||
|
||||
    def test_inductor(self):
        """Run torch opcheck against the registered preload custom op."""
        weight = torch.randn((1024, 8, 5, 1024)).half().mlu()
        self.base_opcheck(torch.ops.torch_mlu_ops.preload, (weight, weight.element_size() * weight.numel()))
|
||||
|
||||
# Entry point: run the suite and propagate the unittest status code.
if __name__ == '__main__':
    exit(run_unittest(TestPreloadOp))
|
||||
137
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_ffn.py
Executable file
137
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_ffn.py
Executable file
@@ -0,0 +1,137 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
# Activation shared by every FFN variant in this file.
active = torch.nn.functional.silu
|
||||
|
||||
def torch_ffn(input, w1, scale1, bias1, w2, scale2, bias2, is_gated):
    """Reference FFN built from weight-only-quantized matmuls.

    Applies the fused up projection, silu-activates its first half, optionally
    gates with the second half, then applies the down projection. Output keeps
    the input's shape.
    """
    flat = input.flatten(0, -2).to(torch.float)
    half = w1.size(0) // (2 if is_gated else 1)
    hidden = weight_only_quant_matmul(flat, w1, scale1, bias1)
    gate = active(hidden[:, :half])
    if is_gated:
        gate = gate * hidden[:, half:]
    out = weight_only_quant_matmul(gate, w2, scale2, bias2)
    return out.reshape(input.shape)
|
||||
|
||||
def torch_smooth_quant_ffn(input, w1, scale1, bias1, w2, scale2, bias2, input_smooth, act_smooth, is_gated):
    """Reference smooth-quant FFN.

    Per-token int8 quantize (after smoothing) -> up projection -> silu with
    optional gating -> re-quantize -> down projection. Output keeps the
    input's shape.

    Fix: ``tmp_input`` is already 2-D after ``flatten(0, -2)``, so the second
    ``tmp_input.flatten(0, -2)`` in the QuantByRow call was a redundant no-op
    and is removed.
    """
    inner_size = w1.size(0) // (1 + is_gated)
    tmp_input = input.flatten(0, -2).to(torch.float)
    quant_input, input_scale = QuantByRow(tmp_input * input_smooth, 8)
    imm = smooth_quant_matmul(quant_input, input_scale, w1, scale1, input.dtype, bias1)
    acted = active(imm[:, :inner_size])
    # Gate with the second half of the fused up projection when requested.
    acted = acted * imm[:, inner_size:] if is_gated else acted
    quant_acted, acted_scale = QuantByRow(acted.flatten(0, -2) * act_smooth, 8)
    out = smooth_quant_matmul(quant_acted, acted_scale, w2, scale2, input.dtype, bias2)
    return out.reshape(input.shape)
|
||||
|
||||
def tmo_weight_only_quant_ffn(input, w1, scale1, bias1, w2, scale2, bias2, quant_bit):
    """FFN via torch_mlu_ops weight-only-quant matmuls (silu fused into the
    first matmul)."""
    rows = input.flatten(0, -2)
    hidden = ops.weight_only_quant_matmul(rows, w1, scale1, None, bias1, None, 'silu', quant_bit, True)
    projected = ops.weight_only_quant_matmul(hidden, w2, scale2, None, bias2, None, "none", quant_bit)
    return projected.reshape(input.shape)
|
||||
|
||||
def tmo_weight_only_group_quant_ffn(input, w1, scale1, bias1, w2, scale2, bias2, quant_bit):
    """FFN via torch_mlu_ops group-quantized matmuls; silu applied between
    projections through the module-level ``active`` alias."""
    rows = input.flatten(0, -2)
    hidden = ops.weight_only_quant_matmul(rows, w1, scale1, None, bias1, None, 'none', quant_bit)
    projected = ops.weight_only_quant_matmul(active(hidden), w2, scale2, None, bias2, None, "none", quant_bit)
    return projected.reshape(input.shape)
|
||||
|
||||
def tmo_weight_only_quant_gated_ffn(input, w1, scale1, bias1, w2, scale2, bias2, quant_bit):
    """Gated FFN via torch_mlu_ops: fused gated-silu activation between two
    weight-only-quant matmuls."""
    rows = input.flatten(0, -2)
    hidden = ops.weight_only_quant_matmul(rows, w1, scale1, None, bias1, None, 'none', quant_bit)
    gated = ops.active(hidden, 'silu', True)
    projected = ops.weight_only_quant_matmul(gated, w2, scale2, None, bias2, None, "none", quant_bit)
    return projected.reshape(input.shape)
|
||||
|
||||
def tmo_pertoken_smooth_quant_gated_ffn(input, w1, scale1, bias13, w2, scale2, bias2, input_smooth, act_smooth, dtype):
    """Gated FFN via torch_mlu_ops per-token smooth-quant pipeline: quantize,
    up-project, gated silu, re-quantize, down-project."""
    rows = input.flatten(0, -2)
    q_rows, q_scale = ops.per_token_smooth_quantize(rows, input_smooth, None)
    hidden = ops.smooth_quant_matmul(q_rows, q_scale, w1, scale1, dtype, bias13)
    gated = ops.active(hidden, 'silu', True)
    q_gated, g_scale = ops.per_token_smooth_quantize(gated, act_smooth, None)
    projected = ops.smooth_quant_matmul(q_gated, g_scale, w2, scale2, dtype, bias2)
    return projected.reshape(input.shape)
|
||||
|
||||
def init_tensors(batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=torch.half, group_num=1):
    """Create seeded FFN fixtures: input, row-quantized weights with scales,
    biases, and (optionally folded-in) smoothing vectors.

    The torch.randn call order is fixed so results are reproducible under
    manual_seed(1).
    """
    std = 0.1
    eps = 0.1  # keep smoothing vectors away from zero to avoid nan
    torch.manual_seed(1)
    input_smooth = torch.randn(hidden_size, dtype=torch.float, device="mlu").abs() + eps
    act_smooth = torch.randn(inner_size, dtype=torch.float, device="mlu").abs() + eps
    out_rows = (1 + is_gated) * inner_size
    w1 = torch.randn(out_rows, hidden_size, dtype=dtype, device="mlu") * std
    if with_smooth:
        # Fold the input smoothing into w1 so smoothing cancels numerically.
        w1 = w1 / input_smooth
    bias1 = torch.randn(out_rows, dtype=dtype, device="mlu") * std
    w2 = torch.randn(hidden_size, inner_size, dtype=dtype, device="mlu") * std
    if with_smooth:
        w2 = w2 / act_smooth
    bias2 = torch.randn(hidden_size, dtype=dtype, device="mlu") * std
    input = torch.randn(batch, seq, hidden_size, dtype=dtype, device="mlu")
    quant_w1, scale1 = QuantByRow(w1, quant_bit, group_num)
    quant_w2, scale2 = QuantByRow(w2, quant_bit, group_num)
    return input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, input_smooth, act_smooth
|
||||
|
||||
# Dtypes exercised by every FFN test; bf16 only where the device supports it.
dtype_list = [torch.half]
if torch_mlu.mlu.is_bf16_supported():
    dtype_list.append(torch.bfloat16)
|
||||
class TestQuantFFN(BtTestCase):
|
||||
    def op_impl_base(self, *args):
        # The FFN tests build their references from torch_ffn /
        # torch_smooth_quant_ffn directly; defer to the base-class hook.
        return super().op_impl_base(*args)
|
||||
|
||||
    def test_weight_only_quant_ffn(self):
        """Compare the torch reference against the tmo int8 weight-only-quant
        FFN (non-gated, no smoothing) for each supported dtype."""
        for dtype in dtype_list:
            print("test_weight_only_quant_ffn...")
            batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth = 3, 5, 512, 768, 8, False, False
            input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, _, _ = init_tensors(batch, seq,
                hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=dtype)
            torch_out = torch_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, False)
            tmo_out = tmo_weight_only_quant_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, quant_bit)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.005, use_MSE=True, use_RAE=True)
|
||||
|
||||
    def test_weight_only_quant_gated_ffn(self):
        """Compare the torch reference against the tmo int4 weight-only-quant
        gated FFN; weights are nibble-packed for the int4 path."""
        for dtype in dtype_list:
            print("test_weight_only_quant_gated_ffn...")
            batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth = 3, 5, 512, 768, 4, True, False
            input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, _, _ = init_tensors(batch, seq,
                hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=dtype)
            # Pack int8-held int4 values two-per-byte for the int4 kernel.
            quant_w1_int4 = PairlyPackInt8(quant_w1)
            quant_w2_int4 = PairlyPackInt8(quant_w2)
            torch_out = torch_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, True)
            tmo_out = tmo_weight_only_quant_gated_ffn(input, quant_w1_int4, scale1, bias1, quant_w2_int4, scale2, bias2, quant_bit)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.005, use_MSE=True, use_RAE=True)
|
||||
|
||||
    def test_weight_only_group_quant_ffn(self):
        """Compare the torch reference against the tmo group-quantized FFN
        (8 groups per row); scales are cast to the activation dtype."""
        for dtype in dtype_list:
            print("test_weight_only_group_quant_ffn...")
            batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth = 3, 5, 512, 512, 8, False, False
            input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, _, _ = init_tensors(batch, seq,
                hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=dtype, group_num=8)
            torch_out = torch_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, is_gated)
            tmo_out = tmo_weight_only_group_quant_ffn(input, quant_w1, scale1.to(dtype=input.dtype), bias1,
                                                      quant_w2, scale2.to(dtype=input.dtype), bias2, quant_bit)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.05, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_pertoken_smooth_quant_ffn(self):
|
||||
for dtype in dtype_list:
|
||||
print("test_pertoken_smooth_quant_ffn...")
|
||||
batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype = 3, 5, 512, 768, 8, True, True, dtype
|
||||
input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, input_smooth, act_smooth = \
|
||||
init_tensors(batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype)
|
||||
torch_out = torch_smooth_quant_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2,
|
||||
input_smooth, act_smooth, is_gated)
|
||||
tmo_out = tmo_pertoken_smooth_quant_gated_ffn(input, quant_w1, scale1, bias1, quant_w2,
|
||||
scale2, bias2, input_smooth, act_smooth, dtype)
|
||||
self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
|
||||
0.05, use_MSE=True, use_RAE=True)
|
||||
|
||||
    def test_inductor(self):
        # The FFN is composed from ops whose opchecks live in their own test
        # files; defer to the base-class default.
        return super().test_inductor()
|
||||
|
||||
# Entry point: run the suite and propagate the unittest status code.
if __name__ == '__main__':
    exit(run_unittest(TestQuantFFN))
|
||||
275
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_to_linear_cache.py
Executable file
275
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_to_linear_cache.py
Executable file
@@ -0,0 +1,275 @@
|
||||
import torch
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
import numpy as np
|
||||
import random
|
||||
import math
|
||||
from common_utils import *
|
||||
from typing import Optional
|
||||
|
||||
def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, quant_bit):
    """Build the positional argument list for quant_to_linear_cache.

    Splices the key/value cache scales in at positions 4 and 5 of the generic
    cache-test fixtures and inserts quant_bit at position 12.
    """
    raw = generate_cache_func_args(batch_size, num_heads, head_size,
                                   cache_memory_len, packed, dtype, True)
    scales = raw[10]
    out = raw[0:10]
    out[4:4] = [scales[0], scales[1]]  # scales land at indices 4 and 5
    out.insert(12, quant_bit)
    return out
|
||||
|
||||
class TestQuantToLinearCache(BtTestCase):
|
||||
    def op_impl_base(self, *args):
        """Pure-PyTorch reference for quant_to_linear_cache.

        Per batch entry, slices that sequence's keys/values, computes
        per-(head, token, group) symmetric quantization scales, writes the
        scales into the scale caches and the quantized int8/int4 data into the
        linear caches at (cache_bs_id, cache_seqlen_offset). int4 data is
        nibble-packed two values per byte; for the value cache the packing runs
        along the sequence dimension, so partially-covered boundary bytes are
        stitched with the nibbles already present in the cache.
        Entries with a negative cache_bs_id or offset are skipped.
        Returns the mutated caches (and scale caches).
        """
        key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, \
            context_lengths, max_context_len, packed, context_seq_offset, cache_bs_id, cache_seqlen_offset, \
            quant_bit = args
        def quant(context: torch.Tensor,
                  quant_bit: int,
                  group_size: int):
            """Return (scaled float context, per-group scales); scale maps the
            max-abs of each group to the signed int range."""
            # context:[head_num, seq, head_size]
            head_num = context.shape[0]
            head_size = context.shape[-1]
            if group_size != head_size:
                # Re-group the last dim so each scale covers group_size elements.
                context = context.reshape(head_num, -1, group_size)
            context_fp32 = context.to(torch.float32)
            max_value, _ = torch.max(context_fp32.abs(), dim=-1, keepdim=True)
            int_max = float(2 ** (quant_bit - 1) - 1)
            scale = max_value / int_max
            scaled_context = context_fp32 / scale
            return scaled_context.reshape(head_num, -1, head_size), scale[..., 0]

        # In packed layout context_lengths is a cumulative-offset array (len = bs + 1).
        batch_size = context_lengths.shape[0] - 1 if packed else context_lengths.shape[0]
        head_num = key.shape[-2]
        head_size = key.shape[-1]

        # 3-D scale cache => one scale per head vector; otherwise the last dim
        # of the scale cache encodes the number of groups per head vector.
        if key_cache_quant_scale.dim() == 3:
            group_size = head_size
        else:
            group_size = head_size // key_cache_quant_scale.size(-1)
        for i in range(batch_size):
            if packed:
                key_i = key[context_lengths[i]:context_lengths[i+1]].transpose(1, 0)
                value_i = value[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) if value is not None else None
                context_len_i = context_lengths[i+1] - context_lengths[i]
                context_seq_offset_i = 0
            else:
                key_i = key[i].transpose(1, 0)
                value_i = value[i].transpose(1, 0) if value is not None else None
                context_len_i = context_lengths[i]
                context_seq_offset_i = context_seq_offset[i]
            cache_bs_id_i = cache_bs_id[i]
            cache_seqlen_offset_i = cache_seqlen_offset[i] if cache_seqlen_offset is not None else 0
            # Negative ids/offsets mark invalid batches that must be skipped.
            if cache_bs_id_i < 0 or cache_seqlen_offset_i < 0:
                continue
            key_cache_i = \
                key_cache[cache_bs_id_i, :, \
                cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
            key_cache_scale_i = \
                key_cache_quant_scale[cache_bs_id_i, :, \
                cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]

            # key_i[head_num, context_len[i], head_size]
            key_i = key_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
            float_key_i, key_scale_i = quant(key_i, quant_bit, group_size)
            key_cache_scale_i = key_cache_scale_i.reshape(key_cache_scale_i.shape[0], -1)
            key_cache_scale_i[...] = key_scale_i
            rounded = torch.round(float_key_i)
            clipped = torch.clip(rounded, -2 ** (quant_bit - 1), 2 ** (quant_bit - 1) - 1)
            quant_key_i = clipped.to(torch.int8)
            if quant_bit == 4:
                # Key int4 packing: adjacent elements along head_size share a
                # byte (low nibble first), halving the last dimension.
                quant_key_flat = quant_key_i.flatten()
                d0 = quant_key_flat[0::2].to(torch.uint8)
                d1 = quant_key_flat[1::2].to(torch.uint8)
                dp = (d1 << 4) + (d0 & 0x0F)
                quant_key_i = dp.to(torch.int8).reshape(head_num, -1, head_size // 2)
            key_cache_i[...] = quant_key_i

            if value_cache is not None and value is not None:
                value_cache_scale_i = \
                    value_cache_quant_scale[cache_bs_id_i, :, \
                    cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
                value_i = value_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
                float_value_i, value_scale_i = quant(value_i, quant_bit, group_size)
                value_cache_scale_i = value_cache_scale_i.reshape(value_cache_scale_i.shape[0], -1)
                value_cache_scale_i[...] = value_scale_i
                rounded = torch.round(float_value_i)
                clipped = torch.clip(rounded, -2 ** (quant_bit - 1), 2 ** (quant_bit - 1) - 1)
                quant_value_i = clipped.to(torch.int8)
                if quant_bit == 8:
                    value_cache_i = \
                        value_cache[cache_bs_id_i, :, \
                        cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
                else:
                    # Value int4 packing runs along the *sequence* dimension:
                    # two consecutive tokens share a cache row, so the touched
                    # byte range spans offset//2 .. ceil((offset+len)/2).
                    value_cache_i = \
                        value_cache[cache_bs_id_i, :, \
                        cache_seqlen_offset_i // 2:math.ceil((cache_seqlen_offset_i + context_len_i) / 2)]
                    if cache_seqlen_offset_i % 2 == 1:
                        # Odd start: the first byte already holds a valid low
                        # nibble from the previous write — prepend it so the
                        # repack preserves it.
                        front_vec = value_cache[cache_bs_id_i, :, cache_seqlen_offset_i // 2, :]
                        front_low_bits = front_vec & (0x0F)
                        front_low_bits_expand = front_low_bits.unsqueeze(1) # [head_num, 1, head_size]
                        quant_value_i = torch.cat((front_low_bits_expand, quant_value_i), dim=1)
                    if (cache_seqlen_offset_i + context_len_i) % 2 == 1:
                        # Odd end: keep the existing high nibble of the last
                        # touched byte by appending it before repacking.
                        back_vec = value_cache[cache_bs_id_i, :, math.ceil((cache_seqlen_offset_i + context_len_i) / 2) - 1, :]
                        back_high_bits = (back_vec >> 4) & (0x0F)
                        back_high_bits_expand = back_high_bits.unsqueeze(1) # [head_num, 1, head_size]
                        quant_value_i = torch.cat((quant_value_i, back_high_bits_expand), dim=1)
                    # Interleave token pairs element-wise, then pack two
                    # sequence-adjacent values per byte (low nibble first).
                    value_temp = quant_value_i.reshape(head_num, -1, 2, head_size)
                    quant_value_flat = value_temp.permute(0, 1, 3, 2).flatten()
                    v0 = quant_value_flat[0::2].to(torch.uint8)
                    v1 = quant_value_flat[1::2].to(torch.uint8)
                    vp = (v1 << 4) + (v0 & 0x0F)
                    quant_value_i = vp.to(torch.int8).reshape(head_num, -1, head_size)
                value_cache_i[...] = quant_value_i
        return (key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale) if value_cache is not None else (key_cache, key_cache_quant_scale)
|
||||
|
||||
def int8_to_int4(self,
|
||||
input):
|
||||
input_flat = input.flatten()
|
||||
size = input_flat.size(0)
|
||||
output = torch.zeros(size * 2, dtype=torch.int8, device=input.device)
|
||||
high = input_flat >> 4
|
||||
low = input_flat << 4
|
||||
low = low >> 4
|
||||
output[0::2] = low
|
||||
output[1::2] = high
|
||||
return output
|
||||
|
||||
def test_quant_to_linear_cache(self):
    """Randomized sweep of ops.quant_to_linear_cache against op_impl_base.

    Draws 100 random configurations (batch, heads, head_size, cache length,
    packed/unpacked layout, dtype, quant_bit, group_size), runs the Python
    reference on cloned caches and the MLU op on the originals, then compares
    caches and scales. int8 caches are compared directly; int4 caches are
    unpacked via int8_to_int4 first and a |diff| of 1 is tolerated.
    """
    test_cases = 100
    bs_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
    num_heads_list = torch.randint(low=1, high=32, size=(test_cases, ), dtype=torch.int32)
    head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32)
    head_size_list *= 16  # head sizes are multiples of 16
    cache_memory_len_list = torch.randint(low=8, high=512, size=(test_cases, ), dtype=torch.int32)
    cache_memory_len_list = cache_memory_len_list * 2  # cache length forced even
    packed_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    dtype_list = torch.randint(low=0, high=3, size=(test_cases, ), dtype=torch.int32)
    dtype_map = [torch.half, torch.bfloat16, torch.float]
    quant_bit_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    quant_bit_map = [4, 8]

    for i in range(test_cases):
        batch_size = bs_list[i].item()
        invalid_batch = batch_size // 10  # number of batch entries marked invalid below
        num_heads = num_heads_list[i].item()
        head_size = head_size_list[i].item()
        group_size_factors = [i for i in range(4, head_size + 1) if head_size % i == 0] # group_size should > 1
        group_size = random.choice(group_size_factors)
        cache_memory_len = cache_memory_len_list[i].item()
        packed = packed_list[i].item()
        dtype = dtype_map[dtype_list[i]]
        quant_bit = quant_bit_map[quant_bit_list[i]]
        mlu_name = torch.mlu.get_device_name()
        if "MLU3" in mlu_name and dtype == torch.bfloat16:
            # bf16 unsupported on MLU3xx; fall back to fp16
            dtype = torch.half
            print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))

        print("case num={}, batch_size={}, num_heads={}, head_size={}, cache_memory_len={}, packed={}, dtype={}, quant_bit={}, group_size={} testing...".format(
            i, batch_size, num_heads, head_size, cache_memory_len, packed > 0, dtype, quant_bit, group_size))

        max_bs = batch_size + 1
        # context lengths forced even (randint * 2)
        context_lens = torch.randint(size=(batch_size, ), low=1,
                                     high=cache_memory_len // 4,
                                     dtype=torch.int32, device='mlu')
        context_lens = context_lens * 2
        max_context_len = context_lens.max().item()
        max_seq_offset = max_context_len // 3 + 1
        context_seq_offsets = torch.randint(size=(batch_size, ),
                                            low=0, high=max_seq_offset,
                                            dtype=torch.int32, device='mlu')
        cache_seq_offsets = torch.randint(size=(batch_size, ), low=0,
                                          high=(cache_memory_len - max_context_len) // 3 + 1,
                                          dtype=torch.int32, device='mlu')
        cache_seq_offsets[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid

        cu_context_lens = torch.cumsum(context_lens, dim=-1)
        cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
        total_seqlen = cu_context_lens[-1]
        if packed > 0:
            # packed layout: all batches' tokens concatenated along dim 0
            key = torch.randn((total_seqlen, num_heads, head_size),
                              dtype=torch.float, device='mlu')
            value = torch.randn((total_seqlen, num_heads, head_size),
                                dtype=torch.float, device='mlu')
        else:
            key = torch.randn((batch_size, max_context_len + max_seq_offset, num_heads, head_size),
                              dtype=torch.float, device='mlu')
            value = torch.randn((batch_size, max_context_len + max_seq_offset, num_heads, head_size),
                                dtype=torch.float, device='mlu')
        key = key.to(dtype)
        value = value.to(dtype)
        # Cache/scale shapes depend on quant_bit and group_size:
        # int4 halves one cache axis (head_size for key, seq axis for value);
        # group_size != head_size adds a head_size // group_size scale axis.
        if quant_bit == 8 and group_size == head_size:
            key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu()
            value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu()
            key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu()
            value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu()
        if quant_bit == 8 and group_size != head_size:
            key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu()
            value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu()
            key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu()
            value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu()
        if quant_bit == 4 and group_size == head_size:
            key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size // 2), dtype=torch.int8).mlu()
            value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len // 2, head_size), dtype=torch.int8).mlu()
            key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu()
            value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu()
        if quant_bit == 4 and group_size != head_size:
            key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size // 2), dtype=torch.int8).mlu()
            value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len // 2, head_size), dtype=torch.int8).mlu()
            key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu()
            value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu()
        # reference mutates clones so the MLU op sees pristine caches
        ref_key_cache = key_cache.clone()
        ref_value_cache = value_cache.clone()
        ref_key_cache_scale = key_cache_scale.clone()
        ref_value_cache_scale = value_cache_scale.clone()

        cache_bs_id = random.sample([*range(0, max_bs)], batch_size)
        cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
        cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid batch

        if packed > 0:
            self.op_impl_base(key, value, ref_key_cache, ref_value_cache, ref_key_cache_scale,
                              ref_value_cache_scale, cu_context_lens, max_context_len,
                              packed > 0, None, cache_bs_id, cache_seq_offsets,
                              quant_bit)
            ops.quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale,
                                      value_cache_scale, cu_context_lens, max_context_len,
                                      packed > 0, None, cache_bs_id, cache_seq_offsets,
                                      quant_bit)
        else:
            self.op_impl_base(key, value, ref_key_cache, ref_value_cache, ref_key_cache_scale,
                              ref_value_cache_scale, context_lens, max_context_len,
                              packed > 0, context_seq_offsets, cache_bs_id,
                              cache_seq_offsets, quant_bit)
            ops.quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale,
                                      value_cache_scale, context_lens, max_context_len,
                                      packed > 0, context_seq_offsets, cache_bs_id,
                                      cache_seq_offsets, quant_bit)
        if quant_bit == 8:
            self.assertTensorsEqual(key_cache.cpu().float(), ref_key_cache.cpu().float(), 0.003,
                                    "key_cache must equal ref_key_cache", True, True, True, True)
            self.assertTensorsEqual(value_cache.cpu().float(), ref_value_cache.cpu().float(), 0.003,
                                    "value_cache must equal ref_value_cache", True, True, True, True)
        else:
            # int4 path: unpack nibbles to int8 first; |diff| of 1 is tolerated
            key_cache_int8 = self.int8_to_int4(key_cache)
            ref_key_cache_int8 = self.int8_to_int4(ref_key_cache)
            value_cache_int8 = self.int8_to_int4(value_cache)
            ref_value_cache_int8 = self.int8_to_int4(ref_value_cache)

            key_cache_diff = (key_cache_int8.cpu() - ref_key_cache_int8.cpu()).abs()
            assert torch.max(key_cache_diff) < 2, "ref_key_cache must equal key_cache or absolute values differ by 1 due to round_mode!"
            value_cache_diff = (value_cache_int8.cpu() - ref_value_cache_int8.cpu()).abs()
            assert torch.max(value_cache_diff) < 2, "ref_value_cache must equal value_cache or absolute values differ by 1 due to round_mode!"

        # scales are checked for both quant widths
        self.assertTensorsEqual(key_cache_scale.cpu().float(), ref_key_cache_scale.cpu().float(), 0.003,
                                "key_cache_scale must equal ref_key_cache_scale", True, True, True, True)
        self.assertTensorsEqual(value_cache_scale.cpu().float(), ref_value_cache_scale.cpu().float(), 0.003,
                                "value_cache_scale must equal ref_value_cache_scale", True, True, True, True)
|
||||
|
||||
def test_inductor(self):
    """Opcheck the registered quant_to_linear_cache op for both layouts."""
    batch_size, num_heads, head_size, cache_memory_len = 4, 8, 64, 128
    dtype, quant_bit = torch.float16, 8
    # exercise the packed (1) layout first, then the unpacked (0) layout
    for packed in (1, 0):
        args = gen_args(batch_size, num_heads, head_size, cache_memory_len,
                        packed, dtype, quant_bit)
        self.base_opcheck(torch.ops.torch_mlu_ops.quant_to_linear_cache, args)
|
||||
|
||||
if __name__ == '__main__':
    # Standalone entry point: run this file's suite through the shared runner.
    exit(run_unittest(TestQuantToLinearCache))
|
||||
@@ -0,0 +1,213 @@
|
||||
import random
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
import copy
|
||||
|
||||
class TestQuantToPagedCache(BtTestCase):
|
||||
def run_gen_case(self, dic):
    """Replay one generated test case described by *dic*.

    When 'dump_data' is set the dict already carries live tensors and is
    forwarded as-is; otherwise each entry is a tensor-spec dict that is
    materialized with create_tensor_from_dic before launching.
    """
    if dic.pop('dump_data'):
        self.launch(*dic.values())
        return
    key = create_tensor_from_dic(dic['k'], is_uniform=True, low=-1, high=1)
    value = create_tensor_from_dic(dic['v'], is_uniform=True, low=-0.25, high=0.25)
    key_cache = create_tensor_from_dic(dic['k_cache'])
    value_cache = create_tensor_from_dic(dic['v_cache'])
    key_scale = create_tensor_from_dic(dic['k_cache_quant_scale'])
    value_scale = create_tensor_from_dic(dic['v_cache_quant_scale'])
    self.launch(key, value, key_cache, value_cache, key_scale, value_scale,
                dic['slot_mapping']['data'])
|
||||
|
||||
def launch(self, *args):
    """Run reference and MLU implementations on identical inputs and compare.

    The reference gets a deep copy because both implementations mutate their
    cache arguments in place. Outputs 0/1 are compared with tolerance 9e-3;
    when a value tensor is present, outputs 2/3 are compared with 3e-3.
    """
    value = args[1]
    ref_args = copy.deepcopy(args)
    ref_out = self.op_impl_base(*ref_args)
    dev_out = ops.quant_to_paged_cache(*args)
    self.assertTensorsEqual(ref_out[0].cpu().float(), dev_out[0].cpu().float(), 9e-3,
                            use_MSE=True, use_RAE=True)
    self.assertTensorsEqual(ref_out[1].cpu().float(), dev_out[1].cpu().float(), 9e-3,
                            use_MSE=True, use_RAE=True)
    if value is not None:
        for idx in (2, 3):
            self.assertTensorsEqual(ref_out[idx].cpu().float(), dev_out[idx].cpu().float(), 3e-3,
                                    use_MSE=True, use_RAE=True)
|
||||
|
||||
def op_impl_base(self, *args):
    """Pure-PyTorch reference for quant_to_paged_cache.

    For every token with a non-negative slot, each head row is symmetrically
    quantized to int8 with a per-(token, head) scale of max(|x|)/127 and
    scattered into the paged caches at block slot // block_size, offset
    slot % block_size. Negative slots are skipped (masked tokens). Returns
    the mutated caches and scales; value outputs only when *v* is given.
    """
    k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping = args

    # row: [head_num, head_size] -> (int8 values, per-head scale)
    def _quant(row):
        row32 = row.to(torch.float32)
        amax = row32.abs().amax(dim=-1, keepdim=True)
        scale = amax / 127.0
        # the int8 cast truncates toward zero, matching the original reference
        return (row32 / scale).to(torch.int8), scale[..., 0]

    block_size = k_cache.shape[2]
    for tok in range(k.shape[0]):  # k: [token_num, head_num, head_size]
        slot = slot_mapping[tok]
        if slot < 0:
            continue  # masked token: nothing is written
        blk = torch.div(slot, block_size, rounding_mode='floor')
        off = slot % block_size
        q_key, key_scale = _quant(k[tok])
        k_cache[blk, :, off, :] = q_key
        k_cache_quant_scale[blk, :, off] = key_scale
        if v is not None:
            q_val, val_scale = _quant(v[tok])
            v_cache[blk, :, off, :] = q_val
            v_cache_quant_scale[blk, :, off] = val_scale
    if v is None:
        return (k_cache, k_cache_quant_scale)
    return (k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale)
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device")
def test_quant_to_paged_cache(self):
    """Randomized check of ops.quant_to_paged_cache against the reference.

    One random geometry is drawn, then 100 iterations vary dtype and the
    key-only/key-and-value mode; the last token's slot is set to -1 to
    exercise the masking path.
    """
    token_nums = random.randint(1, 2048)
    head_num_kv = random.randint(1, 128)
    head_size = random.randint(1, 1024)
    block_size = random.randint(1, 50)
    min_blocks = (int)((token_nums + block_size - 1) / block_size)
    block_nums = min(min_blocks + 10, 2 * min_blocks)
    num_slots = block_nums * block_size
    # unique slots so no two tokens write the same cache position
    slot_mapping = random.sample(range(num_slots), token_nums)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
    slot_mapping[-1] = -1 # test mask
    dtype_list = [torch.half, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    only_quant_key_list = [True, False]

    for _ in range(100):
        print("test_quant_to_paged_cache...")
        dtype = random.choice(dtype_list)
        only_quant_key = random.choice(only_quant_key_list)
        key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1)
        key_cache = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu")
        key_cache_scale = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu")
        if only_quant_key:
            # key-only mode: value-side arguments are all None
            value, value_cache, value_cache_scale = None, None, None
        else:
            value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25)
            value_cache = torch.zeros_like(key_cache)
            value_cache_scale = torch.zeros_like(key_cache_scale)
        self.launch(key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, slot_mapping)
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device")
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_large_tensor due to ASan issues")
def test_large_tensor(self):
    """Exercise quant_to_paged_cache at the 4 GB cache addressing boundary.

    First allocates the largest cache whose flat extent stays below 2**32
    bytes and checks results against the reference; then grows the cache to
    exactly 2**32 bytes and expects the op to reject it.

    Fix vs. original: the TMO_MEM_CHECK skip reason said "test_prevent"
    (copy-paste from the sibling test); it now names this test.
    """
    dtype_list = [torch.half, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        print('quant_to_paged_cache: test_large_tensor')
        head_num_kv = 16
        head_size = 128
        token_nums = 20
        block_size = 16
        # largest block count keeping the int8 cache below 4 GB
        block_nums = ((2**32 - 1) // 1 // head_num_kv // head_size // block_size)
        num_slots = block_nums * block_size
        key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1)
        value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25)
        key_cache_torch = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu")
        value_cache_torch = torch.zeros_like(key_cache_torch)
        key_cache_scale_torch = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu")
        value_cache_scale_torch = torch.zeros_like(key_cache_scale_torch)
        slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums]
        key_cache_tmo = torch.zeros_like(key_cache_torch)
        value_cache_tmo = torch.zeros_like(value_cache_torch)
        key_cache_scale_tmo = torch.zeros_like(key_cache_scale_torch)
        value_cache_scale_tmo = torch.zeros_like(value_cache_scale_torch)
        self.op_impl_base(key, value, key_cache_torch, value_cache_torch, key_cache_scale_torch, value_cache_scale_torch, slot_mapping)
        ops.quant_to_paged_cache(key, value, key_cache_tmo, value_cache_tmo, key_cache_scale_tmo, value_cache_scale_tmo, slot_mapping)
        # int8 caches may differ by 1 count; scales are compared tightly
        self.assertTensorsEqual(key_cache_torch.cpu(), key_cache_tmo.cpu(), 1)
        self.assertTensorsEqual(value_cache_torch.cpu(), value_cache_tmo.cpu(), 1)
        self.assertTensorsEqual(key_cache_scale_torch.cpu().float(), key_cache_scale_tmo.cpu().float(), 3e-3,
                                use_MSE=True, use_RAE=True)
        self.assertTensorsEqual(value_cache_scale_torch.cpu().float(), value_cache_scale_tmo.cpu().float(), 3e-3,
                                use_MSE=True, use_RAE=True)

        # one block more: flat extent reaches 2**32 bytes -> must be rejected
        block_nums = (2**32 // 1 // head_num_kv // head_size // block_size)
        num_slots = block_nums * block_size
        slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums]
        key_cache_torch = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu")
        value_cache_torch = torch.zeros_like(key_cache_torch)
        key_cache_scale_torch = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu")
        value_cache_scale_torch = torch.zeros_like(key_cache_scale_torch)
        key_cache_tmo = torch.zeros_like(key_cache_torch)
        value_cache_tmo = torch.zeros_like(value_cache_torch)
        key_cache_scale_tmo = torch.zeros_like(key_cache_scale_torch)
        value_cache_scale_tmo = torch.zeros_like(value_cache_scale_torch)
        self.assertException("The addressing range of kv_cache cannot exceed 4G.", ops.quant_to_paged_cache,
                             key, value, key_cache_tmo, value_cache_tmo, key_cache_scale_tmo, value_cache_scale_tmo, slot_mapping)
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device")
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_prevent(self):
    """Negative tests: quant_to_paged_cache must reject invalid arguments.

    Checks three error paths: value-side arguments given inconsistently
    (value None but caches present), a value tensor whose last dim is not
    contiguous, and one whose second dim is not contiguous.
    """
    token_nums = random.randint(1, 2048)
    head_num_kv = random.randint(1, 128)
    head_size = random.randint(1, 1024)
    block_size = random.randint(1, 50)
    min_blocks = (int)((token_nums + block_size - 1) / block_size)
    block_nums = min(min_blocks + 10, 2 * min_blocks)
    num_slots = block_nums * block_size
    slot_mapping = random.sample(range(num_slots), token_nums)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
    slot_mapping[-1] = -1 # test mask
    dtype_list = [torch.half, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    dtype = random.choice(dtype_list)

    print("quant_to_paged_cache: test_prevent...")
    key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1)
    value = None
    key_cache = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu")
    value_cache = torch.zeros_like(key_cache)
    key_cache_scale = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu")
    value_cache_scale = torch.zeros_like(key_cache_scale)
    # value is None while its caches are provided -> inconsistent optional set
    self.assertException("v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().",
                         ops.quant_to_paged_cache, key, value, key_cache, value_cache, key_cache_scale,
                         value_cache_scale, slot_mapping)
    value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25)
    # fully transposed strides: last dim stride != 1
    value = value.as_strided(size=(token_nums, head_num_kv, head_size), stride=(1, token_nums, token_nums * head_num_kv))
    self.assertException("v last dim must be contiguous.",
                         ops.quant_to_paged_cache, key, value, key_cache, value_cache, key_cache_scale,
                         value_cache_scale, slot_mapping)
    # last dim contiguous but second dim stride mismatched
    value = value.as_strided(size=(token_nums, head_num_kv, head_size), stride=(head_size, head_num_kv, 1))
    self.assertException("v second dim must be contiguous.",
                         ops.quant_to_paged_cache, key, value, key_cache, value_cache, key_cache_scale,
                         value_cache_scale, slot_mapping)
|
||||
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device")
def test_inductor(self):
    """Opcheck the registered quant_to_paged_cache op on a small fixed case."""
    token_nums, head_num_kv, head_size, block_size = 20, 30, 20, 10
    min_blocks = (int)((token_nums + block_size - 1) / block_size)
    block_nums = min(min_blocks + 10, 2 * min_blocks)
    num_slots = block_nums * block_size
    # unique slots, with the final token masked out
    slots = random.sample(range(num_slots), token_nums)
    slot_mapping = torch.tensor(slots, dtype=torch.int).mlu()
    slot_mapping[-1] = -1

    dtype = torch.half
    key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1)
    value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25)
    key_cache = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu")
    value_cache = torch.zeros_like(key_cache)
    key_cache_scale = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu")
    value_cache_scale = torch.zeros_like(key_cache_scale)
    args = (key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, slot_mapping)
    self.base_opcheck(torch.ops.torch_mlu_ops.quant_to_paged_cache, args)
|
||||
|
||||
if __name__ == '__main__':
    # Standalone entry point: run this file's suite through the shared runner.
    exit(run_unittest(TestQuantToPagedCache))
|
||||
46
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quantize.py
Executable file
46
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quantize.py
Executable file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as tmo
|
||||
from common_utils import *
|
||||
import random
|
||||
|
||||
|
||||
class TestQuantizeOp(BtTestCase):
    """Tests for tmo.quantize / the smooth_quant custom op."""

    def op_impl_base(self, *args):
        """Reference: scale, round, clamp to int8 range, cast.

        The third argument ('zero') is unpacked but unused here —
        presumably a zero-point placeholder; the reference is symmetric.
        """
        x, smooth, zero = args
        return (x * smooth).round().clamp(-128.0, 127.0).to(torch.int8)

    def test_random_case(self):
        """Compare tmo.quantize with the reference over 100 unique (ci, co) shapes."""
        torch.manual_seed(0)
        case_list = set()
        while(len(case_list) < 100):
            dtype_list = [torch.half, torch.float]
            if torch_mlu.mlu.is_bf16_supported():
                dtype_list.append(torch.bfloat16)
            dtype = random.choice(dtype_list)
            ci = random.randint(1, 4096)
            co = random.randint(1, 4096)
            # NOTE(review): dedup key is (ci, co) only — the same shape is
            # never retried with a different dtype; confirm this is intended.
            case = (ci, co)
            if case in case_list:
                continue
            else:
                case_list.add((ci, co))
            x = torch.randn(ci, co, device="mlu", dtype=dtype)
            # per-channel scale broadcasts over rows of x
            scale = torch.randn(co, device="mlu", dtype=torch.float32)
            print("ci={}, co={}, dtype={}, testing...".format(ci, co, dtype), flush=True)
            param = (x, scale, None)
            tmo_output = tmo.quantize(*param)
            torch_output = self.op_impl_base(*param)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), 0.01, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Opcheck the registered smooth_quant op in per_token mode."""
        x = torch.randn(16,128, 1024, device="mlu", dtype=torch.half)
        scale = torch.randn(1024, device="mlu", dtype=torch.float32)
        output = torch.empty(x.size(), dtype=torch.int8, device="mlu")
        args = (x, scale, output, torch.Tensor(), None, None, None, None, 'per_token', False)
        self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Standalone entry point: run this file's suite through the shared runner.
    exit(run_unittest(TestQuantizeOp))
|
||||
179
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_reshape_linear_cache.py
Executable file
179
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_reshape_linear_cache.py
Executable file
@@ -0,0 +1,179 @@
|
||||
import torch
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
import random
|
||||
from common_utils import *
|
||||
from typing import Optional
|
||||
|
||||
def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype):
    """Build the reshape_linear_cache argument list.

    reshape_linear_cache takes only the first 10 of the generic cache-op
    arguments produced by generate_cache_func_args; the tail is dropped.
    """
    full_args = generate_cache_func_args(batch_size, num_heads, head_size,
                                         cache_memory_len, packed, dtype)
    return full_args[0:10]
|
||||
|
||||
class TestReshapeLinearCache(BtTestCase):
|
||||
def op_impl_base(self, *args):
|
||||
key, value, key_cache, value_cache, context_lengths, max_context_len, packed, \
|
||||
context_seq_offset, cache_bs_id, cache_seqlen_offset = args
|
||||
|
||||
batch_size = context_lengths.shape[0] - 1 if packed else context_lengths.shape[0]
|
||||
for i in range(batch_size):
|
||||
if packed:
|
||||
key_i = key[context_lengths[i]:context_lengths[i+1]].transpose(1, 0)
|
||||
value_i = value[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) if value is not None else None
|
||||
context_len_i = context_lengths[i+1] - context_lengths[i]
|
||||
context_seq_offset_i = 0
|
||||
else:
|
||||
key_i = key[i].transpose(1, 0)
|
||||
value_i = value[i].transpose(1, 0) if value is not None else None
|
||||
context_len_i = context_lengths[i]
|
||||
context_seq_offset_i = context_seq_offset[i]
|
||||
cache_bs_id_i = cache_bs_id[i]
|
||||
cache_seqlen_offset_i = cache_seqlen_offset[i]
|
||||
if cache_seqlen_offset_i < 0 or cache_bs_id_i < 0:
|
||||
continue
|
||||
|
||||
key_cache_i = \
|
||||
key_cache[cache_bs_id_i, :, \
|
||||
cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
|
||||
|
||||
key_cache_i[...] = key_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
|
||||
if value_cache is not None and value is not None:
|
||||
value_cache_i = \
|
||||
value_cache[cache_bs_id_i, :, \
|
||||
cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
|
||||
value_cache_i[...] = value_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
|
||||
return (key_cache, value_cache, ) if value_cache is not None else key_cache
|
||||
|
||||
def test_reshape_linear_cache(self):
|
||||
test_cases = 100
|
||||
bs_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
|
||||
num_heads_list = torch.randint(low=1, high=32, size=(test_cases, ), dtype=torch.int32)
|
||||
head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32)
|
||||
head_size_list *= 16
|
||||
cache_memory_len_list = torch.randint(low=16, high=1024, size=(test_cases, ), dtype=torch.int32)
|
||||
packed_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
|
||||
dtype_list = torch.randint(low=0, high=4, size=(test_cases, ), dtype=torch.int32)
|
||||
dtype_map = [torch.int8, torch.half, torch.bfloat16, torch.float]
|
||||
|
||||
for i in range(test_cases):
|
||||
q_heads = 1
|
||||
batch_size = bs_list[i].item()
|
||||
invalid_batch = batch_size // 10
|
||||
num_heads = num_heads_list[i].item()
|
||||
head_size = head_size_list[i].item()
|
||||
cache_memory_len = cache_memory_len_list[i].item()
|
||||
packed = packed_list[i].item()
|
||||
total_heads = q_heads + num_heads * 2
|
||||
dtype = dtype_map[dtype_list[i]]
|
||||
mlu_name = torch.mlu.get_device_name()
|
||||
if "MLU3" in mlu_name and dtype == torch.bfloat16:
|
||||
dtype = torch.half
|
||||
print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))
|
||||
|
||||
print("batch_size={}, num_heads={}, head_size={}, cache_memory_len={}, packed={}, dtype={} testing...".format(
|
||||
batch_size, num_heads, head_size, cache_memory_len, packed > 0, dtype))
|
||||
|
||||
max_bs = batch_size + 1
|
||||
context_lens = torch.randint(size=(batch_size, ), low=1,
|
||||
high=cache_memory_len // 2,
|
||||
dtype=torch.int32, device='mlu')
|
||||
# print(context_lens)
|
||||
max_context_len = context_lens.max().item()
|
||||
max_seq_offset = max_context_len // 3 + 1
|
||||
context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset,
|
||||
dtype=torch.int32, device='mlu')
|
||||
cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1,
|
||||
high=(cache_memory_len - max_context_len) // 3 + 1,
|
||||
dtype=torch.int32, device='mlu')
|
||||
cache_seq_offsets[random.sample([*range(0, batch_size)], invalid_batch)] = -1
|
||||
|
||||
cu_context_lens = torch.cumsum(context_lens, dim=-1)
|
||||
cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
|
||||
total_seqlen = cu_context_lens[-1]
|
||||
if packed > 0:
|
||||
context = torch.randn((total_seqlen, total_heads, head_size),
|
||||
dtype=torch.float, device='mlu')
|
||||
else:
|
||||
context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size),
|
||||
dtype=torch.float, device='mlu')
|
||||
cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu')
|
||||
context = context.to(dtype)
|
||||
cache = cache.to(dtype)
|
||||
ref_cache = cache.clone()
|
||||
key = context[..., q_heads:q_heads + num_heads, :]
|
||||
value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :]
|
||||
key_cache = cache[0]
|
||||
value_cache = cache[1]
|
||||
ref_key_cache = ref_cache[0]
|
||||
ref_value_cache = ref_cache[1]
|
||||
|
||||
cache_bs_id = None
|
||||
cache_bs_id = random.sample([*range(0, max_bs)], batch_size)
|
||||
cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
|
||||
cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid batch
|
||||
|
||||
if packed > 0:
|
||||
self.op_impl_base(key, value, ref_key_cache, ref_value_cache, cu_context_lens,
|
||||
max_context_len, packed > 0, None,
|
||||
cache_bs_id, cache_seq_offsets)
|
||||
ops.reshape_linear_cache(key, value, key_cache, value_cache, cu_context_lens,
|
||||
max_context_len, packed > 0, None,
|
||||
cache_bs_id, cache_seq_offsets)
|
||||
else:
|
||||
self.op_impl_base(key, value, ref_key_cache, ref_value_cache, context_lens,
|
||||
max_context_len, packed > 0, context_seq_offsets,
|
||||
cache_bs_id, cache_seq_offsets)
|
||||
ops.reshape_linear_cache(key, value, key_cache, value_cache, context_lens,
|
||||
max_context_len, packed > 0, context_seq_offsets,
|
||||
cache_bs_id, cache_seq_offsets)
|
||||
self.assertTensorsEqual(cache.cpu().float(), ref_cache.cpu().float(), 0, "ref_cache must equal cache",
|
||||
True, True, True, True)
|
||||
|
||||
def test_reshape_linear_key_cache(self):
|
||||
batch_size, num_heads, head_size, cache_memory_len = 2, 2, 16, 128
|
||||
print("[test_reshape_linear_key_cache] batch_size={}, num_heads={}, head_size={}, cache_memory_len={} testing...".format(
|
||||
batch_size, num_heads, head_size, cache_memory_len))
|
||||
|
||||
max_bs = batch_size + 1
|
||||
context_lens = torch.randint(size=(batch_size, ), low=1,
|
||||
high=cache_memory_len // 2,
|
||||
dtype=torch.int32, device='mlu')
|
||||
max_context_len = context_lens.max().item()
|
||||
max_seq_offset = max_context_len // 3 + 1
|
||||
context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset,
|
||||
dtype=torch.int32, device='mlu')
|
||||
cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1,
|
||||
high=(cache_memory_len - max_context_len) // 3 + 1,
|
||||
dtype=torch.int32, device='mlu')
|
||||
cu_context_lens = torch.cumsum(context_lens, dim=-1)
|
||||
cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
|
||||
|
||||
context = torch.randn((batch_size, max_context_len + max_seq_offset, num_heads, head_size),
|
||||
dtype=torch.float, device='mlu')
|
||||
key_cache = torch.randn((max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu')
|
||||
ref_key_cache = key_cache.clone()
|
||||
key = context[..., :]
|
||||
|
||||
cache_bs_id = None
|
||||
cache_bs_id = random.sample([*range(0, max_bs)], batch_size)
|
||||
cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
|
||||
|
||||
self.op_impl_base(key, None, ref_key_cache, None, context_lens,
|
||||
max_context_len, False, context_seq_offsets,
|
||||
cache_bs_id, cache_seq_offsets)
|
||||
ops.reshape_linear_cache(key, None, key_cache, None, context_lens,
|
||||
max_context_len, False, context_seq_offsets,
|
||||
cache_bs_id, cache_seq_offsets)
|
||||
self.assertTensorsEqual(key_cache.cpu().float(), ref_key_cache.cpu().float(), 0, "ref_cache must equal cache",
|
||||
True, True, True, True)
|
||||
|
||||
|
||||
def test_inductor(self):
    """Run torch.library opcheck on reshape_linear_cache for both the
    packed (1) and unpacked (0) argument layouts."""
    batch_size, num_heads, head_size, cache_memory_len, dtype = 4, 8, 64, 128, torch.float16
    # Same op, two layouts: packed first, then unpacked.
    for packed in (1, 0):
        layout_args = gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype)
        self.base_opcheck(torch.ops.torch_mlu_ops.reshape_linear_cache, layout_args)
|
||||
if __name__ == '__main__':
    # run_unittest (from common_utils) drives the suite; its result is the
    # process exit code.
    exit(run_unittest(TestReshapeLinearCache))
||||
141
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_reshape_paged_cache.py
Executable file
141
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_reshape_paged_cache.py
Executable file
@@ -0,0 +1,141 @@
|
||||
import random
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
|
||||
def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype):
    """Build the argument list for reshape_paged_cache opcheck runs.

    Takes the full tuple produced by generate_cache_func_args and keeps only
    (k, v, k_cache, v_cache) plus the slot_mapping at index 11.
    """
    full_args = generate_cache_func_args(batch_size, num_heads, head_size,
                                         cache_memory_len, packed, dtype)
    return full_args[0:4] + [full_args[11]]
||||
|
||||
class TestReshapePagedCacheOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """Pure-Python reference for reshape_paged_cache.

    Scatters each token's key (and optionally value) row into its paged
    cache slot. Tokens with a negative slot id are skipped. Returns the
    mutated cache(s): (k_cache, v_cache) when v is given, else k_cache.
    """
    key, value, key_cache, value_cache, slots = args
    tokens_per_block = key_cache.shape[2]
    for token_idx in range(key.shape[0]):
        slot = slots[token_idx]
        # Negative slot means "do not cache this token".
        if slot < 0:
            continue
        block_id = torch.div(slot, tokens_per_block, rounding_mode='floor')
        block_offset = slot % tokens_per_block
        key_cache[block_id, :, block_offset, :] = key[token_idx]
        if value is not None:
            value_cache[block_id, :, block_offset, :] = value[token_idx]
    return key_cache if value is None else (key_cache, value_cache)
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device")
def test_reshape_paged_cache(self):
    """Randomized sweep: 100 shape configurations x available dtypes,
    comparing ops.reshape_paged_cache against the Python reference."""
    test_cases = 100
    num_tokens_list = torch.randint(low=1, high=1024, size=(test_cases, ), dtype=torch.int32)
    num_heads_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
    head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32)
    # Head sizes and block sizes are multiples of 16.
    head_size_list *= 16
    block_size_list = torch.randint(low=1, high=4, size=(test_cases, ), dtype=torch.int32)
    block_size_list *= 16
    only_reshape_key_list = [True, False]

    for i in range(test_cases):
        num_tokens = num_tokens_list[i]
        num_heads = num_heads_list[i]
        head_size = head_size_list[i]
        block_size = block_size_list[i]
        # Enough blocks to hold all tokens, plus a little slack.
        min_blocks = (num_tokens + block_size - 1) // block_size
        num_blocks = min(min_blocks + 10, 2 * min_blocks)
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        only_reshape_key = random.choice(only_reshape_key_list)
        for dtype in dtype_list:
            print("num_tokens: {}, num_heads: {}, head_size: {}, num_blocks: {}, block_size: {}, testing...".format(
                num_tokens, num_heads, head_size, num_blocks, block_size), flush=True)
            qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device="mlu")
            # Q slice is unused; only K/V feed the cache.
            _, key, value = qkv.unbind(dim=1)

            key_cache = torch.randn(num_blocks, num_heads, block_size, head_size, dtype=dtype, device="mlu")
            value_cache = torch.randn(num_blocks, num_heads, block_size, head_size, dtype=dtype, device="mlu")

            # Distinct slots per token; the last token is marked -1 to
            # exercise the "skip token" path.
            num_slots = num_blocks * block_size
            slot_mapping = random.sample(range(num_slots.item()), num_tokens.item())
            slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="mlu")
            slot_mapping[-1] = -1

            ref_key_cache, ref_value_cache = key_cache.clone(), value_cache.clone()
            if only_reshape_key:
                # Key-only mode: drop every value-side tensor.
                value, ref_value_cache, value_cache = None, None, None
            self.op_impl_base(key, value, ref_key_cache, ref_value_cache, slot_mapping)
            ops.reshape_paged_cache(key, value, key_cache, value_cache, slot_mapping)
            # Tolerance 0: scatter must be bit-exact.
            self.assertTensorsEqual(key_cache.cpu().float(), ref_key_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
            if not only_reshape_key:
                self.assertTensorsEqual(value_cache.cpu().float(), ref_value_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device")
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_large_tensor(self):
    """Boundary test around the 4 GB kv_cache addressing limit: the largest
    cache that fits under 2**32-1 bytes must succeed; one block count that
    reaches 2**32 bytes must raise."""
    dtype_list = [torch.half, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        print("reshape_paged_cache: test_large_tensor")
        if dtype == torch.float:
            dtype_size = 4
        elif dtype == torch.half or dtype == torch.bfloat16:
            dtype_size = 2
        head_num = 16
        head_size = 128
        token_num = 20
        block_size = 16
        # Largest block count whose cache stays strictly below 4 GB.
        block_num = ((2**32 - 1) // dtype_size // head_num // head_size // block_size)
        k = torch.randn(token_num, head_num, head_size, dtype=dtype, device="mlu")
        v = torch.randn(token_num, head_num, head_size, dtype=dtype, device="mlu")
        k_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu")
        v_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu")
        num_slots = block_num * block_size
        slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_num]
        ref_key_cache, ref_value_cache = k_cache.clone(), v_cache.clone()
        self.op_impl_base(k, v, ref_key_cache, ref_value_cache, slot_mapping)
        ops.reshape_paged_cache(k, v, k_cache, v_cache, slot_mapping)
        # Compare slice-by-slice along the block_size dim (keeps each
        # comparison tensor small despite the huge cache).
        for i in range(block_size):
            self.assertTensorsEqual(k_cache[:, :, i, :].cpu().float(), ref_key_cache[:, :, i, :].cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
            self.assertTensorsEqual(v_cache[:, :, i, :].cpu().float(), ref_value_cache[:, :, i, :].cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
        # Now exactly 4 GB: the op is expected to reject this.
        block_num = (2**32 // dtype_size // head_num // head_size // block_size)
        k_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu")
        v_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu")
        num_slots = block_num * block_size
        slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_num]
        self.assertException("The addressing range of kv_cache cannot exceed 4G.", ops.reshape_paged_cache,
                             k, v, k_cache, v_cache, slot_mapping)
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device")
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_prevent(self):
    """Negative tests: each malformed input must raise with the exact
    message checked by assertException."""
    k = torch.randn(1024, 8, 128, dtype=torch.half, device="mlu")
    v = torch.randn(1024, 8, 128, dtype=torch.half, device="mlu")
    k_cache = torch.randn(1024, 8, 4, 128, dtype=torch.half, device="mlu")
    v_cache = None
    slot_mapping = random.sample(range(1024 * 4), 1024)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="mlu")
    # v given but v_cache missing: presence must match.
    self.assertException("v.has_value() == v_cache.has_value().", ops.reshape_paged_cache,
                         k, v, k_cache, v_cache, slot_mapping)
    v_cache = torch.randn(1024, 8, 4, 128, dtype=torch.half, device="mlu")
    # stride 127 (instead of 128) on dim 2 makes v_cache non-contiguous.
    v_cache = v_cache.as_strided(size=(1024, 8, 4, 128), stride=(8*4*128, 4* 128, 127, 1))
    self.assertException("v_cache need be contiguous.", ops.reshape_paged_cache,
                         k, v, k_cache, v_cache, slot_mapping)
    v_cache = v_cache.contiguous()
    # Transposed-style strides: innermost dim no longer contiguous.
    v = v.as_strided(size=(1024, 8, 128), stride=(1, 1024, 1024 * 8))
    self.assertException("v last dim must be contiguous.", ops.reshape_paged_cache,
                         k, v, k_cache, v_cache, slot_mapping)
    # Last dim contiguous, but dim 1 stride (8) is wrong for shape (.., 8, 128).
    v = v.as_strided(size=(1024, 8, 128), stride=(1024, 8, 1))
    self.assertException("v second dim must be contiguous.", ops.reshape_paged_cache,
                         k, v, k_cache, v_cache, slot_mapping)
|
||||
|
||||
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device")
def test_inductor(self):
    """Run torch.library opcheck on reshape_paged_cache (packed layout)."""
    batch_size, num_heads, head_size, cache_memory_len = 4, 8, 64, 128
    dtype = torch.float16
    opcheck_args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype)
    self.base_opcheck(torch.ops.torch_mlu_ops.reshape_paged_cache, opcheck_args)
||||
|
||||
if __name__ == '__main__':
    # run_unittest (from common_utils) drives the suite; its result is the
    # process exit code.
    exit(run_unittest(TestReshapePagedCacheOp))
||||
158
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_self_attn.py
Executable file
158
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_self_attn.py
Executable file
@@ -0,0 +1,158 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
class SelfAttn(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
qkv_weights,
|
||||
qkv_biass,
|
||||
o_weight,
|
||||
o_bias,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
input_size,
|
||||
head_size,
|
||||
query_factor,
|
||||
eps = 1e-5
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert len(qkv_weights) == 3 and len(qkv_biass) == 3, 'length of weights and biass must be 3'
|
||||
self.layernorm = torch.nn.LayerNorm(input_size)
|
||||
self.layernorm.eps = eps
|
||||
self.layernorm.weight = nn.Parameter(norm_weight)
|
||||
self.layernorm.bias = nn.Parameter(norm_bias)
|
||||
self.weights = qkv_weights
|
||||
self.biass = qkv_biass
|
||||
self.o_weight = o_weight
|
||||
self.o_bias = o_bias
|
||||
self.head_size = head_size
|
||||
self.head_num = self.weights[0].size(0) // head_size
|
||||
self.query_factor = query_factor
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
n = input.size(0)
|
||||
t = input.size(1)
|
||||
normed_input = self.layernorm(input)
|
||||
q = F.linear(normed_input, nn.Parameter(self.weights[0]), nn.Parameter(self.biass[0])).view(n, t, self.head_num, self.head_size)
|
||||
k = F.linear(normed_input, nn.Parameter(self.weights[1]), nn.Parameter(self.biass[1])).view(n, t, self.head_num, self.head_size)
|
||||
v = F.linear(normed_input, nn.Parameter(self.weights[2]), nn.Parameter(self.biass[2])).view(n, t, self.head_num, self.head_size)
|
||||
qk = torch.einsum('bthd,bshd->bhts', q, k) * self.query_factor
|
||||
attn = torch.softmax(qk, dim=-1, dtype=v.dtype)
|
||||
qkv = torch.einsum('bhts,bshd->bthd', attn, v).reshape(n, t, -1)
|
||||
output = F.linear(qkv, nn.Parameter(self.o_weight), nn.Parameter(self.o_bias)) + input
|
||||
return output
|
||||
|
||||
class BTSelfAttn(torch.nn.Module):
    """Fused MLU self-attention block built from torch_mlu_ops kernels.

    Mirrors SelfAttn's computation (LayerNorm + QKV projection, flash
    attention, output projection with residual) using three fused ops.
    """

    def __init__(
        self,
        qkv_weights,
        qkv_biass,
        o_weight,
        o_bias,
        norm_weight,
        norm_bias,
        head_size,
        query_factor,
        eps = 1e-5
    ) -> None:
        super().__init__()
        # Raw tensors are stored as-is; the fused ops consume them directly.
        self.weights = qkv_weights
        self.biass = qkv_biass
        self.o_weight = o_weight
        self.o_bias = o_bias
        self.norm_weight = norm_weight
        self.norm_bias = norm_bias
        self.head_size = head_size
        self.query_factor = query_factor
        self.eps = eps

    def forward(self, input: torch.Tensor):
        """Run fused attention on input (batch, seq, features); returns a
        tensor of the same shape with the residual added by attention_project."""
        # NOTE(review): n is unused below; only t feeds flash_attention.
        n, t = input.size(0), input.size(1)
        # Fused LayerNorm + Q/K/V projection; "nhtc" selects the output
        # layout expected by the transposes below.
        q, k, v = ops.fused_norm_attention_project(input,
                                                   self.weights[0],
                                                   self.biass[0],
                                                   self.weights[1],
                                                   self.biass[1],
                                                   self.weights[2],
                                                   self.biass[2],
                                                   self.norm_weight,
                                                   self.norm_bias,
                                                   self.eps,
                                                   "nhtc",
                                                   self.head_size,
                                                   False)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Non-causal flash attention; heads flattened back into features.
        attn_out = ops.flash_attention(q, k, v, None, None, None, None, None, t, t,
                                       self.query_factor, False, -1, -1, q.dtype).flatten(-2, -1)
        # Output projection; the trailing `input` arg carries the residual.
        output = ops.attention_project(attn_out, self.o_weight, self.o_bias, input)
        return output
|
||||
# Dtypes exercised by every test in this file; bfloat16 only where the
# current MLU device supports it.
dtype_list = [torch.half, torch.float]
if torch_mlu.mlu.is_bf16_supported():
    dtype_list.append(torch.bfloat16)
||||
class TestSelfAttn(BtTestCase):
    """Compares the fused MLU self-attention (BTSelfAttn / ops.flash_attention)
    against the eager PyTorch reference (SelfAttn / einsum softmax)."""

    def op_impl_base(self, *args):
        # Delegates to the shared base-class reference implementation.
        return super().op_impl_base(*args)

    def test_self_attn(self):
        """End-to-end block comparison for each supported dtype."""
        N, T, input_size, hidden_size, head_size, eps, query_factor = 5, 2048, 512, 768, 64, 1e-5, 0.125
        for dtype in dtype_list:
            # Fix: the original format string had 4 placeholders but 5
            # .format() args, silently dropping dtype from the log line.
            print("N: {}, T: {}, input_size: {}, hidden_size: {}, dtype: {}, testing...".format(
                N, T, input_size, hidden_size, dtype), flush=True)
            input = torch.randn(N, T, input_size, dtype=dtype, device="mlu")
            # /10 keeps projection outputs small so fp16 softmax stays stable.
            weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu") / 10
            bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu")
            norm_weight = torch.randn(input_size, dtype=dtype, device="mlu")
            norm_bias = torch.randn(input_size, dtype=dtype, device="mlu")
            # Split the stacked QKV weight/bias into the three projections.
            weights = torch.chunk(weight, 3)
            biass = torch.chunk(bias, 3)
            o_weight = torch.randn(input_size, hidden_size, dtype=dtype, device="mlu") / 10
            o_bias = torch.randn(input_size, dtype=dtype, device="mlu")

            torch_self_attn = SelfAttn(weights, biass, o_weight, o_bias, norm_weight, norm_bias,
                                       input_size, head_size, query_factor, eps)
            tmo_self_attn = BTSelfAttn(weights, biass, o_weight, o_bias, norm_weight, norm_bias,
                                       head_size, query_factor, eps)

            # test self_attn
            print("test self_attn...")
            torch_out = torch_self_attn(input)
            tmo_out = tmo_self_attn(input)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.011, use_MSE=True, use_RAE=True)

    def test_sd_attention(self):
        """Stable-diffusion-style cross attention (Tq != Tk) comparing
        ops.flash_attention on non-contiguous (transposed) inputs against
        the einsum reference."""
        N, Tq, Tk, hq, hk, head_size, query_factor = 2, 4096, 77, 8, 8, 40, 0.125
        for dtype in dtype_list:
            q = torch.randn(N, Tq, hq, head_size, dtype=dtype, device="mlu")
            k = torch.randn(N, Tk, hk, head_size, dtype=dtype, device="mlu")
            v = torch.randn(N, Tk, hk, head_size, dtype=dtype, device="mlu")

            qk = torch.einsum('bthd,bshd->bhts', q, k) * query_factor
            attn = torch.softmax(qk, dim=-1, dtype=v.dtype)
            torch_out = torch.einsum('bhts,bshd->bthd', attn, v).reshape(N, Tq, -1)

            # contiguous() then transpose back: same values as q/k/v but with
            # transposed (non-standard) strides, to exercise stride handling.
            qt = q.transpose(1, 2).contiguous()
            kt = k.transpose(1, 2).contiguous()
            vt = v.transpose(1, 2).contiguous()
            tmo_out = ops.flash_attention(qt.transpose(1, 2),
                                          kt.transpose(1, 2),
                                          vt.transpose(1, 2),
                                          None, None, None, None, None,
                                          Tq, Tk, query_factor, False).flatten(-2, -1)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        # Delegates to the shared base-class inductor opcheck.
        return super().test_inductor()
|
||||
if __name__ == '__main__':
    # run_unittest (from common_utils) drives the suite; its result is the
    # process exit code.
    exit(run_unittest(TestSelfAttn))
||||
348
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_session_cache_attn.py
Normal file
348
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_session_cache_attn.py
Normal file
@@ -0,0 +1,348 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
|
||||
class TestSessionCacheAttnOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """Python reference for session-cache attention.

    Pipeline: dequantize up to two quantized linear caches (int4 or int8,
    per-token or per-channel scales), concatenate them with the current
    session K/V, then run grouped-head scaled-dot-product attention.
    Returns (attn_output, concatenated_key_cache, concatenated_value_cache).
    """
    def dequant_from_cache(key_cache, value_cache, key_cache_scale, value_cache_scale, cache_bs_id, cache_lens,
                           cache_seq_offset, quant_bit, quant_layout, max_cache_len):
        # Dequantize one cache into float, gathering batches via cache_bs_id
        # and trimming each to [cache_offset, cache_offset + max_cache_len).
        batch = cache_bs_id.size(0)
        head_num = key_cache.size(1)
        head_size = value_cache.size(-1)
        cache_len = key_cache.size(-2)
        cache_shape = (batch, max_cache_len, head_num, head_size)
        key_cache_mem = torch.zeros(cache_shape, dtype=torch.float, device='mlu')
        value_cache_mem = torch.zeros_like(key_cache_mem)
        for i in range(batch):
            batch_id = cache_bs_id[i]
            cache_offset = 0 if cache_seq_offset is None else cache_seq_offset[i]
            key_cache_data = key_cache[batch_id] # head_num_kv, cache_mem_len, head_size
            value_cache_data = value_cache[batch_id]
            if quant_bit == 4:
                # int4 unpack: two nibbles per int8 byte. `<< 4 >> 4`
                # sign-extends the low nibble; `>> 4` takes the high nibble.
                # Keys pack along head_size; values pack along cache_len.
                key_quant_data_int8 = torch.zeros(head_num, cache_len, head_size, dtype=torch.int8, device="mlu")
                value_quant_data_int8 = torch.zeros(head_num, cache_len, head_size, dtype=torch.int8, device="mlu")
                key_quant_data_int8[...,::2] = key_cache_data << 4 >> 4
                key_quant_data_int8[...,1::2] = key_cache_data >> 4
                key_quant_data_fp32 = key_quant_data_int8.clone().to(torch.float)
                value_quant_data_int8[:,0::2,:] = value_cache_data << 4 >> 4
                value_quant_data_int8[:,1::2,:] = value_cache_data >> 4
                value_quant_data_fp32 = value_quant_data_int8.clone().to(torch.float)
            else:
                # int8: values are used directly, only scaling remains.
                key_quant_data_fp32 = key_cache_data.clone().to(torch.float)
                value_quant_data_fp32 = value_cache_data.clone().to(torch.float)

            if quant_layout == 'per_channel':
                # Scale shape broadcasts over the token (cache_len) dim.
                key_quant_data_fp32 = key_quant_data_fp32 * key_cache_scale[..., None, :]
                value_quant_data_fp32 = value_quant_data_fp32 * value_cache_scale[..., None, :]
            else: # per token
                # Scale is indexed per batch slot and broadcasts over head_size.
                key_cache_scale_data = key_cache_scale[batch_id]
                value_cache_scale_data = value_cache_scale[batch_id]
                key_quant_data_fp32 = key_quant_data_fp32 * key_cache_scale_data[..., None]
                value_quant_data_fp32 = value_quant_data_fp32 * value_cache_scale_data[..., None]
            key_quant_data_fp32 = key_quant_data_fp32[:, cache_offset:cache_offset+max_cache_len, :] #cut vaild cache
            value_quant_data_fp32 = value_quant_data_fp32[:, cache_offset:cache_offset+max_cache_len, :]
            # head_num_kv, max_cache_len, headsize-> max_cache_len, head_num_kv, headsize
            key_cache_mem[i] = key_quant_data_fp32.transpose(0,1)
            value_cache_mem[i] = value_quant_data_fp32.transpose(0,1)
        return key_cache_mem, value_cache_mem

    def scale_dot_attn(q_sess, key_cache, value_cache, cu_seq_lens_q, cu_seq_lens_cache, alibi_slope, is_causal, softmax_scale):
        # Variable-length grouped-head attention over packed sequences.
        # cache_shape: [batch, cache_len, head_kv,head_size]
        batch = cu_seq_lens_q.size(0) - 1
        head_num_q = q_sess.size(-2)
        head_num_kv = key_cache.size(-2)
        assert head_num_q >= head_num_kv and head_num_q % head_num_kv == 0
        group = head_num_q // head_num_kv
        inf = 1e6
        device= 'mlu'
        out_list = []
        for i in range(batch):
            # Slice this sequence's queries and its K/V window out of the
            # packed buffers using the cumulative-length offsets.
            q = q_sess[cu_seq_lens_q[i]:cu_seq_lens_q[i+1], ...]
            k = key_cache[cu_seq_lens_cache[i]:cu_seq_lens_cache[i+1], ...] # [cache_len, head_num_kv, head_size]
            v = value_cache[cu_seq_lens_cache[i]:cu_seq_lens_cache[i+1], ...]
            # Expand KV heads to match query heads (GQA/MQA grouping).
            k = torch.repeat_interleave(k, group, dim=1) #[cache_len, head_num_q, head_size]
            v = torch.repeat_interleave(v, group, dim=1)
            qk = torch.einsum('qhd,khd->hqk', q, k) * softmax_scale
            seq_q, seq_k = q.size(0), k.size(0)
            if alibi_slope is not None:
                # ALiBi positional bias; causal and non-causal use different
                # relative-position formulas.
                slope = alibi_slope.reshape(1, head_num_q, 1, 1)
                slope_bias = torch.zeros(1, head_num_q, seq_q, seq_k).to(device=device)
                if is_causal:
                    relative_pos = torch.arange(-seq_k + 1, 1, dtype=torch.float32).to(device=device)
                    slope_bias = relative_pos * slope
                else:
                    row_idx = torch.arange(seq_q, dtype=torch.long).reshape(-1, 1)
                    col_idx = torch.arange(seq_k, dtype=torch.long)
                    relative_pos = torch.abs(row_idx + seq_k - seq_q - col_idx).to(device=device)
                    slope_bias = -slope * relative_pos.to(dtype=slope.dtype)
                qk += (slope_bias.squeeze(0))
            if is_causal:
                # Bottom-right aligned causal mask: queries attend to all
                # cached positions plus their own prefix of the session.
                assert seq_q <= seq_k, "seq_q <= seq_k if causal=True"
                zeros = torch.zeros(seq_q, seq_k-seq_q, dtype=torch.float, device="mlu")
                tri = torch.full((seq_q, seq_q), -inf, dtype=torch.float, device="mlu").triu(diagonal=1)
                mask = torch.cat([zeros, tri], dim=1) # (q, k-q) + (q, q) => (q, k)
                qk += mask
            attn = torch.softmax(qk, dim=-1, dtype=torch.float).to(q.dtype)
            qkv = torch.einsum('hqk,khd->qhd', attn, v)
            out_list.append(qkv)
        output = torch.cat(out_list, dim=0)
        return output

    # 26 positional args: session QKV, two (possibly absent) quantized
    # caches with their scales/lengths/offsets/quant configs, then the
    # attention parameters.
    q_sess, k_sess, v_sess, key_cache1, value_cache1, key_cache_scale1, value_cache_scale1, cache_lens1,\
        cache_seq_offset1, quant_bit1, quant_layout1, max_cache_len1,\
        key_cache2, value_cache2, key_cache_scale2, value_cache_scale2, cache_lens2,\
        cache_seq_offset2, quant_bit2, quant_layout2, max_cache_len2,\
        sess_lens, cache_bs_id, cu_seq_lens_q, is_causal, softmax_scale = args
    #1. dequant cache
    #input cache shape [max_batch, head_num_kv, cache_mem_len, head_size]
    #output cache shape [batch, max_cache_len, head_num_kv, head_size]
    key_cache_mem1, value_cache_mem1 = dequant_from_cache(key_cache1, value_cache1, key_cache_scale1,
                                                          value_cache_scale1, cache_bs_id, cache_lens1, cache_seq_offset1,
                                                          quant_bit1, quant_layout1, max_cache_len1)
    key_cache_mem2, value_cache_mem2 = None, None
    if key_cache2 is not None:
        key_cache_mem2, value_cache_mem2 = dequant_from_cache(key_cache2, value_cache2, key_cache_scale2,
                                                              value_cache_scale2, cache_bs_id, cache_lens2, cache_seq_offset2,
                                                              quant_bit2, quant_layout2, max_cache_len2)
    #2. concat cache
    batch = cache_bs_id.size(0)
    head_num_kv = k_sess.size(-2)
    head_size = k_sess.size(-1)
    # concat cache1 and cache2
    if key_cache2 is not None:
        cache_lens = cache_lens1 + cache_lens2
        max_cache_len = max_cache_len1 + max_cache_len2
        key_cache_mem = torch.zeros(batch, max_cache_len, head_num_kv, head_size, dtype=torch.float, device='mlu')
        value_cache_mem = torch.zeros(batch, max_cache_len, head_num_kv, head_size, dtype=torch.float, device='mlu')
        for i in range(batch): #concat
            len1 = cache_lens1[i]
            len2 = cache_lens2[i]
            key_cache_mem[i, :len1, ...] = key_cache_mem1[i, :len1, ...]
            key_cache_mem[i, len1:len1+len2, ...] = key_cache_mem2[i, :len2, ...]
            value_cache_mem[i, :len1, ...] = value_cache_mem1[i, :len1, ...]
            value_cache_mem[i, len1:len1+len2, ...] = value_cache_mem2[i, :len2, ...]
    else:
        key_cache_mem, value_cache_mem = key_cache_mem1, value_cache_mem1
        cache_lens = cache_lens1
    #concat mem cache and sess cache
    concat_cache_lens = cache_lens + sess_lens
    cu_seq_lens_cache = torch.zeros((batch+1), dtype=torch.int32)
    cu_seq_lens_cache[1:] = torch.cumsum(concat_cache_lens, dim=-1)
    total_cache_len = cu_seq_lens_cache[-1]
    # Packed layout: all batches' [cache | session] K/V rows back to back.
    concate_key_cache = torch.zeros((total_cache_len, head_num_kv, head_size), dtype=torch.float, device='mlu')
    concate_value_cache = torch.zeros((total_cache_len, head_num_kv, head_size), dtype=torch.float, device='mlu')
    for i in range(batch): #concat
        mem_len = cache_lens[i]
        sess_len = sess_lens[i]
        off = cu_seq_lens_cache[i]
        concate_key_cache[off:off+mem_len, ...] = key_cache_mem[i, :mem_len, ...]
        concate_key_cache[off+mem_len:off+mem_len+sess_len, ...] = k_sess[i, :sess_len, ...]
        concate_value_cache[off:off+mem_len, ...] = value_cache_mem[i, :mem_len, ...]
        concate_value_cache[off+mem_len:off+mem_len+sess_len, ...] = v_sess[i, :sess_len, ...]
    #3. attn
    attn_out = scale_dot_attn(q_sess, concate_key_cache, concate_value_cache, cu_seq_lens_q, cu_seq_lens_cache, None, is_causal, softmax_scale)
    return attn_out, concate_key_cache, concate_value_cache
|
||||
def test_session_cache_attn_quant_kv(self):
    """Single quantized cache: sweep causal x dtype x quant_bit x
    quant_layout, checking dequant_from_linear_cache and flash_attention
    against the Python reference (op_impl_base)."""
    max_batch = 64
    cache_mem_len = 4096
    head_num_kv = 1
    head_num_q = 8
    head_size = 128
    batch = 32
    max_sess_len = 100
    max_cache_len = 3072
    dtype_list=[torch.float16]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    quant_bit_list = [4, 8]
    quant_layout_list = ['per_token', 'per_channel']
    is_causal_list = [True, False]
    arg = product(is_causal_list, dtype_list, quant_bit_list, quant_layout_list)
    softmax_scale = 1 / math.sqrt(head_size)
    for is_causal, dtype, quant_bit, quant_layout in arg:
        print(f"is_causal: {is_causal}, dtype: {dtype}, quant_bit: {quant_bit}, quant_layout: {quant_layout}")
        if quant_bit == 8:
            cache_shape_key = (max_batch, head_num_kv, cache_mem_len, head_size)
            cache_shape_value = cache_shape_key
        else:
            # int4 packs two values per byte: keys along head_size,
            # values along cache_mem_len.
            cache_shape_key = (max_batch, head_num_kv, cache_mem_len, head_size//2)
            cache_shape_value = (max_batch, head_num_kv, cache_mem_len//2, head_size)
        if quant_layout == "per_token":
            scale_shape = (max_batch, head_num_kv, cache_mem_len)
            quant_mode = 1
        else: # per channel
            scale_shape = (head_num_kv, head_size)
            quant_mode = 0
        key_cache = torch.randint(-127, 128, cache_shape_key, device='mlu').to(torch.int8)
        value_cache = torch.randint(-127, 128, cache_shape_value, device='mlu').to(torch.int8)
        key_cache_scale = torch.randn(scale_shape, device='mlu', dtype=torch.float)
        value_cache_scale = torch.randn(scale_shape, device='mlu', dtype=torch.float)
        # Constant small scales keep dequantized magnitudes comparable
        # across configs (tolerance 0.003 below).
        key_cache_scale = torch.fill(key_cache_scale, 0.01)
        value_cache_scale = torch.fill(value_cache_scale, 0.01)
        cache_lens = torch.randint(1, max_cache_len + 1, (batch,), dtype=torch.int32, device='mlu')
        sess_lens = torch.randint(1, max_sess_len + 1, (batch,), dtype=torch.int32, device='mlu')
        context_lens = cache_lens + sess_lens
        max_cache_len_new = torch.max(context_lens)
        cu_seq_lens_q = torch.zeros(batch+1, dtype=torch.int32)
        cu_seq_lens_q[1:] = torch.cumsum(sess_lens, dim=-1)
        total_sess_len = cu_seq_lens_q[-1]
        cu_seq_lens_q=cu_seq_lens_q.mlu()
        cache_bs_id = random.sample([*range(0, max_batch)], batch)
        cache_bs_id = torch.IntTensor(cache_bs_id).mlu() #block_tables
        mem_cache_seq_offset = torch.randint(0, cache_mem_len-max_cache_len, (batch,), dtype=torch.int32, device='mlu')

        q_sess = torch.randn(total_sess_len, head_num_q, head_size, dtype=dtype, device='mlu')
        k_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu')
        v_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu')
        #fill sess cache
        cache_seq_offset = torch.zeros((batch + 1), dtype=torch.int32)
        cache_seq_offset[1:]= torch.cumsum(context_lens, dim=-1)
        cache_seq_offset = cache_seq_offset.mlu()
        context_seq_offset = cache_seq_offset[:-1]
        sess_seq_offset = context_seq_offset + cache_lens
        total_mem_len = cache_seq_offset[-1]
        cache_shape_mem = (total_mem_len, head_num_kv, head_size)#NT,H,C
        key_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu')
        value_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu')
        # Pre-place session K/V after each batch's cache region; the op
        # fills only the dequantized cache portion.
        for i in range(batch):
            offset = sess_seq_offset[i]
            # NOTE(review): local `len` shadows the builtin for the rest of
            # this loop body.
            len = sess_lens[i]
            key_cache_mem[offset:offset+len, ...] = k_sess[i, :len, ...]
            value_cache_mem[offset:offset+len, ...] = v_sess[i, :len, ...]

        # baseline
        # Second cache slots are all None/-1: single-cache configuration.
        base_output, base_key_cache, base_value_cache = self.op_impl_base(q_sess.float(), k_sess.float(), v_sess.float(),
            key_cache, value_cache, key_cache_scale, value_cache_scale, cache_lens, mem_cache_seq_offset,
            quant_bit, quant_layout, max_cache_len,
            None, None, None, None, None, None, -1, None, -1,
            sess_lens, cache_bs_id, cu_seq_lens_q, is_causal, softmax_scale)

        #tmo
        #1. dequant cache
        ops.dequant_from_linear_cache(key_cache_mem, value_cache_mem, key_cache, value_cache, key_cache_scale,
            value_cache_scale, cache_lens, max_cache_len, context_seq_offset,
            cache_bs_id, mem_cache_seq_offset, quant_mode, quant_bit)
        self.assertTensorsEqual(base_key_cache.cpu().float(), key_cache_mem.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
        self.assertTensorsEqual(base_value_cache.cpu().float(), value_cache_mem.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
        #2. flash_attn
        tmo_output = ops.flash_attention(q_sess, key_cache_mem, value_cache_mem, None, cu_seq_lens_q, cache_seq_offset,
                                         None, None, max_sess_len, max_cache_len_new, softmax_scale,
                                         is_causal, -1, -1, torch.float, False, None, None, None)
        self.assertTensorsEqual(base_output.cpu().float(), tmo_output.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
|
||||
def test_session_cache_attn_kv_mixquant(self):
|
||||
max_batch = 64
|
||||
cache_mem_len_int4 = 4064
|
||||
cache_mem_len_int8 = 32
|
||||
head_num_kv = 1
|
||||
head_num_q = 8
|
||||
head_size = 128
|
||||
batch = 32
|
||||
max_sess_len = 100
|
||||
max_cache_len_int4 = 3072
|
||||
max_cache_len_int8 = 32
|
||||
dtype_list=[torch.float16]
|
||||
if torch_mlu.mlu.is_bf16_supported():
|
||||
dtype_list.append(torch.bfloat16)
|
||||
is_causal_list = [True, False]
|
||||
arg = product(is_causal_list, dtype_list)
|
||||
softmax_scale = 1 / math.sqrt(head_size)
|
||||
for is_causal, dtype in arg:
|
||||
print(f"is_causal: {is_causal}, dtype: {dtype}")
|
||||
cache_shape_int8 = (max_batch, head_num_kv, cache_mem_len_int8, head_size)
|
||||
cache_shape_int4_key = (max_batch, head_num_kv, cache_mem_len_int4, head_size//2)
|
||||
cache_shape_int4_value = (max_batch, head_num_kv, cache_mem_len_int4//2, head_size)
|
||||
scale_shape_int4 = (max_batch, head_num_kv, cache_mem_len_int4) #per_token
|
||||
quant_mode_int4 = 1
|
||||
scale_shape_int8 = (head_num_kv, head_size) # per channel
|
||||
quant_mode_int8 = 0
|
||||
key_cache_int8 = torch.randint(-127, 128, cache_shape_int8, device='mlu').to(torch.int8)
|
||||
value_cache_int8 = torch.randint(-127, 128, cache_shape_int8, device='mlu').to(torch.int8)
|
||||
key_cache_int4 = torch.randint(-127, 128, cache_shape_int4_key, device='mlu').to(torch.int8)
|
||||
value_cache_int4 = torch.randint(-127, 128, cache_shape_int4_value, device='mlu').to(torch.int8)
|
||||
key_cache_scale_int4 = torch.randn(scale_shape_int4, device='mlu', dtype=torch.float)
|
||||
value_cache_scale_int4 = torch.randn(scale_shape_int4, device='mlu', dtype=torch.float)
|
||||
key_cache_scale_int8 = torch.randn(scale_shape_int8, device='mlu', dtype=torch.float)
|
||||
value_cache_scale_int8 = torch.randn(scale_shape_int8, device='mlu', dtype=torch.float)
|
||||
key_cache_scale_int4 = torch.fill(key_cache_scale_int4, 0.01)
|
||||
value_cache_scale_int4 = torch.fill(value_cache_scale_int4, 0.01)
|
||||
key_cache_scale_int8 = torch.fill(key_cache_scale_int8, 0.01)
|
||||
value_cache_scale_int8 = torch.fill(value_cache_scale_int8, 0.01)
|
||||
cache_lens_int4 = torch.randint(1, max_cache_len_int4 + 1, (batch,), dtype=torch.int32, device='mlu')
|
||||
cache_lens_int8 = torch.randint(1, max_cache_len_int8 + 1, (batch,), dtype=torch.int32, device='mlu')
|
||||
sess_lens = torch.randint(max_sess_len-1, max_sess_len, (batch,), dtype=torch.int32, device='mlu')
|
||||
cu_seq_lens_q = torch.zeros(batch+1, dtype=torch.int32)
|
||||
cu_seq_lens_q[1:] = torch.cumsum(sess_lens, dim=-1)
|
||||
total_sess_len = cu_seq_lens_q[-1]
|
||||
cu_seq_lens_q=cu_seq_lens_q.mlu()
|
||||
context_lens = cache_lens_int4 + cache_lens_int8 + sess_lens
|
||||
max_cache_len_new = torch.max(context_lens)
|
||||
cache_bs_id = random.sample([*range(0, max_batch)], batch)
|
||||
cache_bs_id = torch.IntTensor(cache_bs_id).mlu() #block_tables
|
||||
cache_seq_offset_int4 = torch.randint(0, cache_mem_len_int4-max_cache_len_int4, (batch,), dtype=torch.int32, device='mlu')
|
||||
cache_seq_offset_int8 = None
|
||||
|
||||
q_sess = torch.randn(total_sess_len, head_num_q, head_size, dtype=dtype, device='mlu')
|
||||
k_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu')
|
||||
v_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu')
|
||||
#fill sess cache
|
||||
cache_seq_offset1 = torch.zeros((batch + 1), dtype=torch.int32)
|
||||
cache_seq_offset2 = torch.zeros((batch + 1), dtype=torch.int32)
|
||||
cache_seq_offset1[1:]= torch.cumsum(context_lens, dim=-1)
|
||||
cache_seq_offset2[:-1] = cache_seq_offset1[:-1] + cache_lens_int4.cpu()
|
||||
cache_seq_offset1 = cache_seq_offset1.mlu()
|
||||
cache_seq_offset2 = cache_seq_offset2.mlu()
|
||||
|
||||
context_seq_offset1 = cache_seq_offset1[:-1]
|
||||
context_seq_offset2 = cache_seq_offset2[:-1]
|
||||
sess_seq_offset = context_seq_offset2 + cache_lens_int8
|
||||
total_mem_len = cache_seq_offset1[-1]
|
||||
cache_shape_mem = (total_mem_len, head_num_kv, head_size)#NT,H,C
|
||||
key_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu')
|
||||
value_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu')
|
||||
for i in range(batch):
|
||||
offset = sess_seq_offset[i]
|
||||
len = sess_lens[i]
|
||||
key_cache_mem[offset:offset+len, ...] = k_sess[i, :len, ...]
|
||||
value_cache_mem[offset:offset+len, ...] = v_sess[i, :len, ...]
|
||||
|
||||
# baseline
|
||||
base_output, base_key_cache, base_value_cache = self.op_impl_base(q_sess.float(), k_sess.float(), v_sess.float(),
|
||||
key_cache_int4, value_cache_int4, key_cache_scale_int4, value_cache_scale_int4, cache_lens_int4,
|
||||
cache_seq_offset_int4, 4, 'per_token', max_cache_len_int4,
|
||||
key_cache_int8, value_cache_int8, key_cache_scale_int8, value_cache_scale_int8, cache_lens_int8,
|
||||
cache_seq_offset_int8, 8, 'per_channel', max_cache_len_int8,
|
||||
sess_lens, cache_bs_id, cu_seq_lens_q, is_causal, softmax_scale)
|
||||
#tmo
|
||||
#1. dequant cache
|
||||
ops.dequant_from_linear_cache(key_cache_mem, value_cache_mem, key_cache_int4, value_cache_int4,
|
||||
key_cache_scale_int4, value_cache_scale_int4, cache_lens_int4, max_cache_len_int4,
|
||||
context_seq_offset1, cache_bs_id, cache_seq_offset_int4, quant_mode_int4, 4)
|
||||
ops.dequant_from_linear_cache(key_cache_mem, value_cache_mem, key_cache_int8, value_cache_int8,
|
||||
key_cache_scale_int8, value_cache_scale_int8, cache_lens_int8, max_cache_len_int8,
|
||||
context_seq_offset2, cache_bs_id, cache_seq_offset_int8, quant_mode_int8, 8)
|
||||
|
||||
self.assertTensorsEqual(base_key_cache.cpu().float(), key_cache_mem.cpu().float(),
|
||||
0.003, use_MSE=True, use_RAE=True)
|
||||
self.assertTensorsEqual(base_value_cache.cpu().float(), value_cache_mem.cpu().float(),
|
||||
0.003, use_MSE=True, use_RAE=True)
|
||||
#2. flash_attn
|
||||
tmo_output = ops.flash_attention(q_sess, key_cache_mem, value_cache_mem, None, cu_seq_lens_q, cache_seq_offset1,
|
||||
None, None, max_sess_len, max_cache_len_new, softmax_scale,
|
||||
is_causal, -1, -1, torch.float, False, None, None, None)
|
||||
self.assertTensorsEqual(base_output.cpu().float(), tmo_output.cpu().float(),
|
||||
0.003, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_inductor(self):
    # Delegate to the inductor opcheck inherited from the BtTestCase base
    # class (common_utils) -- presumably it re-runs the op under
    # torch.compile and validates the custom-op schema; TODO confirm.
    return super().test_inductor()
|
||||
|
||||
# Script entry point: run this file's test case through the shared
# run_unittest helper from common_utils and propagate its exit status.
if __name__ == '__main__':
    exit(run_unittest(TestSessionCacheAttnOp))
|
||||
454
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_single_query_cached_kv_attn.py
Executable file
454
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_single_query_cached_kv_attn.py
Executable file
@@ -0,0 +1,454 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
|
||||
def gen_args(batch, head_size, is_pagedattn, has_alibi, kv_data_type, max_seqlen, seq_q, head_num, num_kv_heads):
    """Build the positional argument tuple for the fused single-query
    cached-KV attention op (used by the inductor opcheck).

    Returns a 16-tuple matching the op schema:
        (q, key_cache, value_cache, out, block_tables, context_lens, None,
         key_cache_scale, value_cache_scale, alibi_slopes, max_context_len,
         window_size_left=-1, window_size_right=-1, softmax_scale,
         return_lse=False, quant_bit=-1)
    When kv_data_type is torch.int8 the caches are quantized int8 tensors
    with float32 scales of shape (num_blocks, num_kv_heads, block_size);
    otherwise the caches are fp16 and the scales are None.
    """
    # Carve q out of a larger packed allocation so strides are non-trivial;
    # .contiguous() is applied on return.
    input_qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size)).mlu().half()
    input_q = input_qkv[:, 0:seq_q, 0:head_num, :].view(batch, seq_q, head_num, head_size)
    context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu()
    max_context_len = int(max(context_lens))
    if is_pagedattn:
        block_size = 16
    else:
        # Non-paged layout: one block holds a whole sequence.
        block_size = max_seqlen
    num_blocks = batch * ((max_seqlen + block_size - 1) // block_size)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
    scale_shape = (num_blocks, num_kv_heads, block_size)
    # sample() guarantees distinct block ids across the whole table.
    block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
    block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
    if kv_data_type is not torch.int8:
        key_cache = torch.randn(size=cache_shape, dtype=torch.float16).mlu()
        value_cache = torch.randn(size=cache_shape, dtype=torch.float16).mlu()
        key_cache_scale = None
        value_cache_scale = None
    else:
        # Upper bound 127 (not 128): uniform_ can return the bound itself and
        # 128.0 wraps to -128 under .to(torch.int8); also keeps this
        # consistent with gen_params below.
        key_cache = torch.zeros(cache_shape).uniform_(-128, 127).to(kv_data_type).mlu()
        value_cache = torch.zeros(cache_shape).uniform_(-128, 127).to(kv_data_type).mlu()
        key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    alibi_slopes = None
    if has_alibi:
        alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu()
        alibi_slopes.uniform_(0, 0.125)
    softmax_scale = 1 / math.sqrt(head_size)
    return (input_q.contiguous(), key_cache, value_cache, torch.empty_like(input_q), block_tables, context_lens, None,
            key_cache_scale, value_cache_scale, alibi_slopes, max_context_len, -1, -1, softmax_scale, False, -1)
|
||||
def gen_params(batch, head_size, head_size_v, is_pagedattn, has_alibi, dtype, kv_dtype, max_seqlen, seq_q, head_num, num_kv_heads):
    """Build the shared inputs for the single-query-attention tests.

    Returns a 9-tuple:
        (input_q, key_cache, value_cache, block_tables, context_lens,
         key_cache_scale, value_cache_scale, alibi_slopes, max_context_len)
    When kv_dtype is torch.int8 the caches are quantized int8 tensors with
    float32 scales of shape (num_blocks, num_kv_heads, block_size);
    otherwise the caches use `dtype` and the scales are None.
    """
    # q is a strided view into a bigger packed qkv buffer (exercises
    # non-contiguous query handling in callers).
    packed_qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size), dtype=dtype).mlu()
    input_q = packed_qkv[:, 0:seq_q, 0:head_num, :].view(batch, seq_q, head_num, head_size)
    context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu()
    max_context_len = int(max(context_lens))
    # Paged layout: fixed 16-token blocks; otherwise one block per sequence.
    blk = 16 if is_pagedattn else max_seqlen
    total_blocks = batch * ((max_seqlen + blk - 1) // blk)
    blocks_per_seq = (max_context_len + blk - 1) // blk
    k_shape = (total_blocks, num_kv_heads, blk, head_size)
    v_shape = (total_blocks, num_kv_heads, blk, head_size_v)
    scale_shape = (total_blocks, num_kv_heads, blk)
    # Distinct block ids so no two logical slots alias the same physical block.
    table_ids = random.sample(range(0, total_blocks), batch * blocks_per_seq)
    block_tables = torch.tensor(table_ids, dtype=torch.int32).mlu().view(batch, blocks_per_seq)
    if kv_dtype is torch.int8:
        key_cache = torch.zeros(k_shape).uniform_(-128, 127).to(torch.int8).mlu()
        value_cache = torch.zeros(v_shape).uniform_(-128, 127).to(torch.int8).mlu()
        key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    else:
        key_cache = torch.randn(size=k_shape, dtype=dtype).mlu()
        value_cache = torch.randn(size=v_shape, dtype=dtype).mlu()
        key_cache_scale = None
        value_cache_scale = None
    alibi_slopes = None
    if has_alibi:
        alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu()
        alibi_slopes.uniform_(0, 0.125)
    return input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len
|
||||
class TestSingleQueryAttnOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
    """Reference path: forward to the python baseline
    `single_query_cached_kv_attn` imported from common_utils.

    Accepts the same 15 positional arguments the tests hand to the fused
    op; `out`, the max-context length and the quant-bit size are accepted
    only for signature parity and are not forwarded to the baseline.
    """
    (q, k_cache, v_cache, _out, block_tables, context_lens,
     k_quant_scale, v_quant_scale, alibi_slopes, _max_context_len,
     win_left, win_right, softmax_scale, return_lse, _quant_bits) = args
    return single_query_cached_kv_attn(
        q, k_cache, v_cache, block_tables, context_lens, k_quant_scale,
        v_quant_scale, alibi_slopes, win_left, win_right, softmax_scale,
        return_lse)
|
||||
def test_single_query_attention(self):
    """Sweep shapes/flags and compare the fused op against the python
    baseline (op_impl_base), including the return-lse path for seq_q == 1."""
    head_num = 16
    batch_list = [5, 12]
    num_kv_heads = 4          # head_num must be a multiple of this (GQA)
    head_size_list = [(128, 128), (192, 384)]   # (head_size_qk, head_size_v)
    seq_len_list = [512]
    is_pagedattn_list = [False, True]
    mlu_name = torch.mlu.get_device_name()
    # NOTE(review): "MLU3" substring match presumably targets the MLU370
    # family -- confirm it does not accidentally match other names.
    if "MLU3" in mlu_name:
        print("pagedattn is not implement on mlu370, skip it")
        is_pagedattn_list = [False]
    has_alibi_list = [True, False]
    seq_q_list = [1, 5]
    window_size_list = [(-1, -1), (10, -1)]     # (-1, -1) = no windowing
    data_type_list = [torch.float16, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        data_type_list.append(torch.bfloat16)

    args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, seq_len_list, seq_q_list, window_size_list, data_type_list)
    for batch, (head_size, head_size_v), is_pagedattn, has_alibi, max_seqlen, seq_q, (window_size_left, window_size_right), dtype in args:
        print("batch: {}, max_seqlen: {}, head_size: {}, head_size_v: {}, is_pagedattn: {}, has_alibi {}, seq_q {}, window_size_left {},\
            window_size_right {}, dtype {}, testing...".format(
            batch, max_seqlen, head_size, head_size_v, is_pagedattn, has_alibi, seq_q, window_size_left, window_size_right, dtype))
        # prepare input (kv cache in the same dtype as q: no quantization here)
        params = gen_params(batch, head_size, head_size_v, is_pagedattn, has_alibi, dtype, dtype, max_seqlen, seq_q, head_num, num_kv_heads)
        input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len = params
        softmax_scale = 1 / math.sqrt(head_size)
        # Baseline vs fused op, no lse.
        torch_output = self.op_impl_base(input_q, key_cache, value_cache,
                                         None, block_tables, context_lens, key_cache_scale,
                                         value_cache_scale, alibi_slopes, max_context_len,
                                         window_size_left, window_size_right, softmax_scale, False, -1)
        tmo_output1 = ops.single_query_cached_kv_attn(input_q, key_cache, value_cache,
                                         None, block_tables, context_lens, key_cache_scale,
                                         value_cache_scale, alibi_slopes, max_context_len,
                                         window_size_left, window_size_right, softmax_scale)
        self.assertTensorsEqual(torch_output.cpu().float(), tmo_output1.cpu().float(),
                                0.003, use_MSE=True)
        # return_lse is only supported for seq_q == 1 (see test_prevent).
        if seq_q == 1:
            torch_output, torch_lse = self.op_impl_base(input_q, key_cache, value_cache,
                                         None, block_tables, context_lens, key_cache_scale,
                                         value_cache_scale, alibi_slopes, max_context_len,
                                         window_size_left, window_size_right, softmax_scale, True, -1)
            tmo_output, tmo_lse = ops.single_query_cached_kv_attn(input_q, key_cache, value_cache,
                                         None, block_tables, context_lens, key_cache_scale,
                                         value_cache_scale, alibi_slopes, max_context_len,
                                         window_size_left, window_size_right, softmax_scale, True)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                    0.003, use_MSE=True)
            # Tighter tolerance for log-sum-exp values.
            self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(),
                                    0.0001, use_MSE=True)
|
||||
|
||||
# @unittest.skip("not test")
|
||||
def test_single_query_attention_quantize_kv(self):
    """Same sweep as test_single_query_attention but with an int8-quantized
    KV cache (fp16 query) and per-cache float scales."""
    head_num = 16
    batch_list = [5, 12]
    num_kv_heads = 4
    head_size_list = [(128, 128), (16, 384)]
    seq_len_list = [512]
    is_pagedattn_list = [False, True]
    mlu_name = torch.mlu.get_device_name()
    if "MLU3" in mlu_name:
        print("pagedattn is not implement on mlu370, skip it")
        is_pagedattn_list = [False]
    has_alibi_list = [True, False]
    # NOTE(review): quant_mode is iterated below but never passed to
    # gen_params or the op, so it only duplicates every case -- verify
    # whether it should be forwarded or dropped.
    quant_mode_list = ['per_token', 'per_channel']
    kv_data_type_list = [torch.int8]
    seq_q_list = [1, 5]
    window_size_list = [(-1, -1), (10, -1)]
    args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, kv_data_type_list, quant_mode_list, seq_len_list, seq_q_list, window_size_list)
    for batch, (head_size, head_size_v), is_pagedattn, has_alibi, kv_data_type, quant_mode, max_seqlen, seq_q, (window_size_left, window_size_right) in args:
        print("batch: {}, max_seqlen: {}, head_size: {}, head_size_v: {}, is_pagedattn: {}, has_alibi {}, kv_datatype {}, \
            quant_mode {}, seq_q {}, window_size_left {}, window_size_right {}, testing...".format(
            batch, max_seqlen, head_size, head_size_v, is_pagedattn, has_alibi, kv_data_type, quant_mode, seq_q, window_size_left, window_size_right))
        # prepare input: fp16 query, int8 kv cache with float32 scales
        params = gen_params(batch, head_size, head_size_v, is_pagedattn, has_alibi, torch.float16, kv_data_type, max_seqlen, seq_q, head_num, num_kv_heads)
        input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len = params
        softmax_scale = 1 / math.sqrt(head_size)
        torch_output_contiguous = self.op_impl_base(input_q.contiguous(),
                                         key_cache, value_cache, None, block_tables, context_lens,
                                         key_cache_scale, value_cache_scale, alibi_slopes,
                                         max_context_len, window_size_left, window_size_right, softmax_scale, False, -1)
        tmo_output_contiguous = ops.single_query_cached_kv_attn(input_q.contiguous(),
                                         key_cache, value_cache, None, block_tables, context_lens,
                                         key_cache_scale, value_cache_scale, alibi_slopes,
                                         max_context_len, window_size_left, window_size_right, softmax_scale)
        self.assertTensorsEqual(torch_output_contiguous.cpu().float(), tmo_output_contiguous.cpu().float(),
                                0.003, use_MSE=True)
        # lse path only for seq_q == 1; note the looser 0.005 tolerance on
        # the quantized output compared with the fp test above.
        if seq_q == 1:
            torch_output, torch_lse = self.op_impl_base(input_q, key_cache, value_cache,
                                         None, block_tables, context_lens, key_cache_scale, value_cache_scale,
                                         alibi_slopes, max_context_len, window_size_left, window_size_right, softmax_scale, True, -1)
            tmo_output, tmo_lse = ops.single_query_cached_kv_attn(input_q, key_cache, value_cache,
                                         None, block_tables, context_lens, key_cache_scale, value_cache_scale,
                                         alibi_slopes, max_context_len, window_size_left, window_size_right, softmax_scale, True)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                    0.005, use_MSE=True)
            self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(), 0.0003, use_MSE=True)
|
||||
|
||||
def test_single_query_attention_int4_kv(self):
    """Int4 KV-cache path: generate an int8 reference cache restricted to
    the int4 value range, pack it pairwise into int4 (PairlyPackInt8), and
    check the fused int4 op against the baseline run on the unpacked int8
    cache (both share the same float scales, so results must match)."""
    head_num = 16
    batch_list = [5, 12]
    num_kv_heads = 4
    head_size_list = [(64, 128), (256, 128), (64, 384)]
    seq_len_list = [512]
    is_pagedattn_list = [False, True]
    mlu_name = torch.mlu.get_device_name()
    if "MLU3" in mlu_name:
        print("pagedattn is not implement on mlu370, skip it")
        is_pagedattn_list = [False]
    has_alibi_list = [True, False]
    quant_mode_list = ['per_token', 'per_channel', 'per_token_group']
    data_type_list = [torch.float, torch.half]
    if torch_mlu.mlu.is_bf16_supported():
        data_type_list.append(torch.bfloat16)
    kv_data_type = torch.int8
    seq_q_list = [1, 5]
    # int4 range: values are generated in [-8, 7] so they survive packing
    quant_bit = 4
    int_max = float(2 ** (quant_bit - 1) - 1)
    int_min = -float(2 ** (quant_bit - 1))
    window_size_list = [(-1, -1), (20, -1)]
    args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, quant_mode_list, seq_len_list, seq_q_list, window_size_list, data_type_list)
    for batch, (head_size, head_size_v), is_pagedattn, has_alibi, quant_mode, max_seqlen, seq_q, (window_size_left, window_size_right), data_type in args:
        print("kv4: batch: {}, head_size: {}, head_size_v: {}, is_pagedattn: {}, has_alibi {}, quant_mode {}, max_seqlen: {}, seq_q {}, \
            window_size_left {}, window_size_right {}, data_type {}, testing...".format(
            batch, head_size, head_size_v, is_pagedattn, has_alibi, quant_mode, max_seqlen, seq_q, window_size_left, window_size_right, data_type))
        # prepare input
        input_qkv = torch.randn((batch, seq_q, 3 * head_num, head_size), dtype=data_type).mlu()
        input_q = input_qkv[..., 0:head_num,:]

        context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu()
        max_context_len = context_lens.max().item()
        # int4 packing pairs two values per byte, so lengths must be even.
        if max_seqlen % 2 == 1:
            max_seqlen += 1
        if is_pagedattn:
            block_size = 16
        else:
            block_size = max_seqlen
        num_blocks = batch * ((max_seqlen + block_size - 1) // block_size)
        max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
        cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size)
        cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v)
        # K is packed along head_size; V is packed along block_size (via a
        # transpose round-trip through cache_shape_v_int4_tmp).
        cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, int(head_size/2))
        cache_shape_v_int4 = (num_blocks, num_kv_heads, int(block_size/2), head_size_v)
        cache_shape_v_int4_tmp = (num_blocks, num_kv_heads, head_size_v, int(block_size/2))

        if quant_mode == "per_channel":
            scale_shape_k = (num_kv_heads, head_size)
            scale_shape_v = (num_kv_heads, head_size_v)
        elif quant_mode == "per_token":
            scale_shape_k = (num_blocks, num_kv_heads, block_size)
            scale_shape_v = (num_blocks, num_kv_heads, block_size)
        elif quant_mode == "per_token_group":
            scale_shape_k = (num_blocks, num_kv_heads, block_size, 1) #group_size = head_size
            scale_shape_v = (num_blocks, num_kv_heads, block_size, 1)
        block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
        block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)

        # Reference int8 caches (values limited to the int4 range) and the
        # packed int4 caches derived from them.
        key_cache = torch.zeros(cache_shape_k).uniform_(int_min, int_max).to(kv_data_type).mlu()
        value_cache = torch.zeros(cache_shape_v).uniform_(int_min, int_max).to(kv_data_type).mlu()
        key_cache_scale = torch.randn(size=scale_shape_k, dtype=torch.float32).mlu()
        value_cache_scale = torch.randn(size=scale_shape_v, dtype=torch.float32).mlu()
        key_cache_view = key_cache.reshape(-1, head_size)
        value_cache_view = value_cache.transpose(2, 3).reshape(-1, block_size)
        key_cache_int4 = PairlyPackInt8(key_cache_view).view(cache_shape_k_int4)
        value_cache_int4 = PairlyPackInt8(value_cache_view).view(cache_shape_v_int4_tmp).transpose(2,3).contiguous()

        alibi_slopes = None
        if has_alibi:
            alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu()
            alibi_slopes.uniform_(0, 0.125)
        softmax_scale = 1 / math.sqrt(head_size)

        # Baseline consumes the unpacked int8 cache; fused op the int4 one.
        torch_output = self.op_impl_base(input_q.contiguous(),
                                         key_cache, value_cache,
                                         None, block_tables, context_lens,
                                         key_cache_scale, value_cache_scale,
                                         alibi_slopes, max_context_len,
                                         window_size_left, window_size_right, softmax_scale, False, 4)
        tmo_output_contigous = ops.single_query_cached_kv_attn(input_q.contiguous(),
                                         key_cache_int4, value_cache_int4,
                                         None, block_tables, context_lens,
                                         key_cache_scale, value_cache_scale,
                                         alibi_slopes, max_context_len,
                                         window_size_left, window_size_right, softmax_scale, False, 4)
        self.assertTensorsEqual(torch_output.cpu().float(), tmo_output_contigous.cpu().float(),
                                0.003, use_MSE=True)
        # Same call but writing into a pre-allocated output tensor.
        tmo_output_inplace = torch.empty((batch, seq_q, head_num, head_size_v), dtype=data_type, device="mlu")
        ops.single_query_cached_kv_attn(input_q, key_cache_int4, value_cache_int4,
                                        tmo_output_inplace, block_tables, context_lens,
                                        key_cache_scale, value_cache_scale,
                                        alibi_slopes, max_context_len,
                                        window_size_left, window_size_right, softmax_scale, False, 4)
        self.assertTensorsEqual(torch_output.cpu().float(), tmo_output_inplace.cpu().float(),
                                0.003, use_MSE=True)
        # lse path only for seq_q == 1.
        if seq_q == 1:
            torch_output, torch_lse = self.op_impl_base(input_q,
                                         key_cache, value_cache,
                                         None, block_tables, context_lens,
                                         key_cache_scale, value_cache_scale,
                                         alibi_slopes, max_context_len,
                                         window_size_left, window_size_right, softmax_scale, True, 4)
            tmo_output1, tmo_lse = ops.single_query_cached_kv_attn(input_q,
                                         key_cache_int4, value_cache_int4,
                                         None, block_tables, context_lens,
                                         key_cache_scale, value_cache_scale,
                                         alibi_slopes, max_context_len,
                                         window_size_left, window_size_right, softmax_scale, True, 4)
            self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(),
                                    0.003, use_MSE=True)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output1.cpu().float(),
                                    0.003, use_MSE=True)
|
||||
|
||||
def test_single_query_attention_inplace(self):
    """Check the pre-allocated-output (in-place) call path: results for a
    strided q and a contiguous q must be bit-identical (0.0 tolerance), and
    both must match the python baseline."""
    head_num = 16
    batch_list = [5]
    num_kv_heads = 4
    head_size_list = [64]
    seq_len_list = [512]
    is_pagedattn_list = [False, True]
    mlu_name = torch.mlu.get_device_name()
    if "MLU3" in mlu_name:
        print("pagedattn is not implement on mlu370, skip it")
        is_pagedattn_list = [False]
    has_alibi_list = [True, False]
    kv_data_type_list = [torch.int8, torch.float16]
    seq_q_list = [1, 5]
    window_size_list = [(-1, -1), (20, -1)]
    args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, kv_data_type_list, seq_len_list, seq_q_list, window_size_list)
    for batch, head_size, is_pagedattn, has_alibi, kv_data_type, max_seqlen, seq_q, (window_size_left, window_size_right) in args:
        print("batch: {}, max_seqlen: {}, head_size: {}, is_pagedattn: {}, has_alibi {}, kv_datatype {}, seq_q {}, window_size_left {}, window_size_right {}, testing...".format(
            batch, max_seqlen, head_size, is_pagedattn, has_alibi, kv_data_type, seq_q, window_size_left, window_size_right))
        # prepare input (head_size_v == head_size here)
        params = gen_params(batch, head_size, head_size, is_pagedattn, has_alibi, torch.float16, kv_data_type, max_seqlen, seq_q, head_num, num_kv_heads)
        input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len = params
        softmax_scale = 1 / math.sqrt(head_size)
        tmo_output = torch.empty_like(input_q)
        tmo_output_contigous = torch.empty_like(input_q)
        torch_output = self.op_impl_base(input_q, key_cache, value_cache, None,
                                         block_tables, context_lens, key_cache_scale,
                                         value_cache_scale, alibi_slopes, max_context_len,
                                         window_size_left, window_size_right, softmax_scale, False, -1)
        # Fused op writing into the caller-provided buffers: once with the
        # strided view of q, once with a contiguous copy.
        ops.single_query_cached_kv_attn(input_q, key_cache, value_cache, tmo_output,
                                        block_tables, context_lens, key_cache_scale,
                                        value_cache_scale, alibi_slopes, max_context_len,
                                        window_size_left, window_size_right, softmax_scale)
        ops.single_query_cached_kv_attn(input_q.contiguous(), key_cache, value_cache,
                                        tmo_output_contigous, block_tables, context_lens,
                                        key_cache_scale, value_cache_scale, alibi_slopes,
                                        max_context_len, window_size_left, window_size_right, softmax_scale)
        # Strided vs contiguous q must agree exactly.
        self.assertTensorsEqual(tmo_output.cpu().float(), tmo_output_contigous.cpu().float(),
                                0.000, use_MSE=True)
        self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                0.003, use_MSE=True)
|
||||
|
||||
# Fool-proofing (argument validation) tests: each call below is expected to
# be rejected by the op with the quoted error message.
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_prevent(self):
    """Feed deliberately invalid arguments to ops.single_query_cached_kv_attn
    and assert each specific error message via assertException (common_utils)."""
    func = ops.single_query_cached_kv_attn
    batch, seq_q, head_num, num_kv_heads, head_size_qk, head_size_v, max_seqlen, softmax_scale = 5, 1, 8, 3, 64, 128, 512, 0.625
    dtype = torch.float16
    input = torch.randn((batch, seq_q, head_num, head_size_qk), dtype = dtype).mlu()
    context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu()
    max_context_len = int(max(context_lens))
    block_size = 16
    num_blocks = batch * ((max_seqlen + block_size - 1) // block_size)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size_qk)
    cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v)
    block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
    block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
    key_cache = torch.randn(size=cache_shape_k, dtype=dtype).mlu()
    value_cache = torch.randn(size=cache_shape_v, dtype=dtype).mlu()
    key_cache_scale = None
    value_cache_scale = None
    # 1) right-window must be negative (unsupported otherwise)
    window_size_left, window_size_right = 10, 10
    self.assertException("only support windows_size_right < 0 currently.",
                         func, input, key_cache, value_cache,
                         None, block_tables, context_lens,
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, True, -1)
    window_size_right = -1
    # 2) GQA constraint: head_num (8) not divisible by num_kv_heads (3)
    self.assertException("num_heads need be mutiple of num_kv_heads.",
                         func, input, key_cache, value_cache,
                         None, block_tables, context_lens,
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, True, -1)
    num_kv_heads = 1
    cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size_qk)
    cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v)
    key_cache = torch.randn(size=cache_shape_k, dtype=dtype).mlu()
    value_cache = torch.randn(size=cache_shape_v, dtype=dtype).mlu()
    # 3) invalid quant bit size (16)
    self.assertException("illegal quant bit size, only support 4, 8 or -1.",
                         func, input, key_cache, value_cache,
                         None, block_tables, context_lens,
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, True, 16)
    # 4) q whose last dim is strided (every other element)
    input1 = torch.randn((batch, seq_q, head_num, head_size_qk * 2), dtype = dtype).mlu()
    input = input1[..., 0::2]
    self.assertException("q last two dim need be contiguous.",
                         func, input, key_cache, value_cache,
                         None, block_tables, context_lens,
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, True, -1)
    input = torch.randn((batch, seq_q, head_num, head_size_qk), dtype = dtype).mlu()
    # 5) non-contiguous kv cache (sliced along block_size)
    key_cache1 = key_cache[..., :8, :]
    value_cache1 = value_cache[..., :8, :]
    self.assertException("k_cache and v_cache need be contiguous.",
                         func, input, key_cache1, value_cache1,
                         None, block_tables, context_lens,
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, True, -1)
    # 6) q on CPU instead of MLU
    self.assertException("q_ori need be mlu tensor.",
                         func, input.cpu(), key_cache, value_cache,
                         None, block_tables, context_lens,
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, True, -1)
    # 7) return_lse with seq_q > 1
    input = torch.randn((batch, 5, head_num, head_size_qk), dtype = dtype).mlu()
    self.assertException("return lse only support seq_q = 1 currently.",
                         func, input, key_cache, value_cache,
                         None, block_tables, context_lens,
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, True, -1)
    # 8) wrong block_tables dtype
    self.assertException("block_tables type need be torch::kInt32 or torch::kLong.",
                         func, input, key_cache, value_cache,
                         None, block_tables.float(), context_lens,
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, False, -1)
    # 9) wrong context_lens dtype
    self.assertException("context_lens type need be torch::kInt32.",
                         func, input, key_cache, value_cache,
                         None, block_tables, context_lens.to(torch.int64),
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, False, -1)
    # 10) strided (non-contiguous) context_lens
    self.assertException("context_lens need be contiguous.",
                         func, input, key_cache, value_cache,
                         None, block_tables, context_lens[0::2],
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, False, -1)
    # 11) 1-d quant scale is rejected. NOTE(review): (num_blocks) is a bare
    # int, not a tuple -- torch.randn(int) still yields the intended 1-d
    # tensor, but (num_blocks,) would state the intent more clearly.
    scale_shape = (num_blocks)
    key_cache = torch.zeros(cache_shape_k).uniform_(-128, 127).to(torch.int8).mlu()
    value_cache = torch.zeros(cache_shape_v).uniform_(-128, 127).to(torch.int8).mlu()
    key_cache_scale = torch.randn(scale_shape, dtype=torch.float32).mlu()
    value_cache_scale = torch.randn(scale_shape, dtype=torch.float32).mlu()
    self.assertException("k_cache_quant_scale must be 2d or 3d or 4d.",
                         func, input, key_cache, value_cache,
                         None, block_tables, context_lens,
                         key_cache_scale, value_cache_scale,
                         None, max_context_len,
                         window_size_left, window_size_right, softmax_scale, True, -1)
|
||||
|
||||
def test_inductor(self):
    """Run the inherited opcheck (presumably torch.compile/inductor schema
    validation -- TODO confirm against common_utils) over a small sweep of
    paged/alibi flags with an int8 KV cache."""
    batch, seq_q, head_num, num_kv_heads, head_size, max_seqlen = 1, 5, 16, 16, 128, 512
    # Paged attention is unavailable on the MLU370 family.
    is_pagedattn_list = [False, True] if "MLU3" not in torch.mlu.get_device_name() else [False]
    has_alibi_list = [True, False]
    test_flags = product(is_pagedattn_list, has_alibi_list)
    for is_pagedattn, has_alibi in test_flags:
        # gen_args returns the full 16-argument tuple matching the op schema.
        args = gen_args(batch, head_size, is_pagedattn, has_alibi, torch.int8, max_seqlen, seq_q, head_num, num_kv_heads)
        self.base_opcheck(torch.ops.torch_mlu_ops.single_query_cached_kv_attn, args)
|
||||
|
||||
# Script entry point: run this file's test case through the shared
# run_unittest helper from common_utils and propagate its exit status.
if __name__ == '__main__':
    exit(run_unittest(TestSingleQueryAttnOp))
|
||||
@@ -0,0 +1,393 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
|
||||
def gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len, data_type, quant_bit, quant_mode, is_normal=True):
    """Generate a KV cache for the cached-attention tests.

    quant_bit selects the cache format:
        -1 -- unquantized cache in `data_type`, scales are None
         8 -- int8 cache with float32 scales
         4 -- int4 cache (pairwise-packed int8 via PairlyPackInt8) plus the
              unpacked int8 reference caches appended to the output
    quant_mode is 'per_token' (scale per block/head/token) or per-channel
    (scale per head/channel). When is_normal is False and batch > 3, about a
    third of the context lengths are zeroed to exercise empty sequences.

    Returns [key_cache, value_cache, key_scale, value_scale, context_lens,
    block_tables] (+ [key_cache_int8, value_cache_int8] when quant_bit == 4).

    Raises:
        ValueError: if quant_bit is not one of {-1, 4, 8}.
    """
    # Representable range for the chosen bit width (unused when quant_bit == -1).
    int_max = float(2 ** (quant_bit - 1) - 1)
    int_min = -float(2 ** (quant_bit - 1))
    context_lens = torch.randint(seq_q, seq_len + 1, (batch, ), dtype=torch.int32).mlu()
    if is_normal is False and batch > 3:  # replace some batch's context with 0
        num = batch // 3
        index = torch.randint(0, batch, (num,))
        context_lens[index] = 0
    max_context_len = context_lens.max().item()
    block_size = 16
    if is_pagedattn is False:
        # Non-paged layout: one block covers a whole sequence.
        block_size = seq_len
    num_blocks = batch * ((seq_len + block_size - 1) // block_size)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
    if quant_mode == "per_token":
        scale_shape = (num_blocks, num_kv_heads, block_size, 1)
    else:  # per channel
        scale_shape = (num_kv_heads, head_size)

    if quant_bit == 4:
        # K packs along head_size; V packs along block_size (transposed
        # round-trip through cache_shape_v_int4_tmp).
        cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, head_size // 2)
        cache_shape_v_int4_tmp = (num_blocks, num_kv_heads, head_size, block_size // 2)
        key_cache_int8 = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        value_cache_int8 = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        # pre-process the int4 kv cache: pack two int8 values per byte
        key_cache_view = key_cache_int8.reshape(-1, head_size)
        value_cache_view = value_cache_int8.transpose(2, 3).reshape(-1, block_size)
        key_cache = PairlyPackInt8(key_cache_view).view(cache_shape_k_int4)
        value_cache = PairlyPackInt8(value_cache_view).view(cache_shape_v_int4_tmp).transpose(2, 3).contiguous()
        key_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    elif quant_bit == 8:
        key_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        value_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        key_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    elif quant_bit == -1:
        key_cache = torch.randn(cache_shape, dtype=data_type).mlu()
        value_cache = torch.randn(cache_shape, dtype=data_type).mlu()
        key_scale = None
        value_scale = None
    else:
        # Fail fast: the original fell through with only a print and then
        # crashed later with a NameError on key_cache.
        raise ValueError("gen case error, quant_bit must be in {-1, 4, 8}, got %r" % (quant_bit,))
    block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
    block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
    output = [key_cache, value_cache, key_scale, value_scale, context_lens, block_tables]
    if quant_bit == 4:
        # Also expose the unpacked int8 caches as the dequant reference.
        output.append(key_cache_int8)
        output.append(value_cache_int8)
    return output
|
||||
def concate_cache_linear(cachek1, cachek2, cachev1, cachev2, context1, context2, block_tables1, block_tables2,
                         scalek1, scalek2, scalev1, scalev2):
    """Merge two contiguous (non-paged) KV caches into one cache along the sequence dim.

    For each batch entry, the first ``context1[i]`` tokens come from cache1 and the next
    ``context2[i]`` tokens from cache2. Quantized caches are dequantized in place by
    multiplying with their scales before merging.

    Args:
        cachek1/cachev1, cachek2/cachev2: [batch-ish num_blocks, num_head, seq, head_size] caches.
        context1, context2: per-batch valid sequence lengths for each cache.
        block_tables1, block_tables2: per-batch block-id lookup tables.
        scalek*/scalev*: optional dequant scales; 2-D => per-channel, 3-D => per-token.
    Returns:
        (new_cache_k, new_cache_v, new_block_table, new_context) — all moved to MLU.
    """
    # Dequantize cache1 if scales were provided; reshape scales so they broadcast
    # against the 4-D cache layout.
    if scalek1 is not None:
        if scalek1.dim() == 2: # per_channel: [kv_head_num, head_size]
            scalek1 = scalek1.reshape(1, scalek1.shape[0], 1, scalek1.shape[1])
            scalev1 = scalev1.reshape(1, scalev1.shape[0], 1, scalev1.shape[1])
        elif scalek1.dim() == 3: # per_token: [num_blocks, k_head_num, block_size]
            scalek1 = scalek1.reshape(*scalek1.shape, 1)
            scalev1 = scalev1.reshape(*scalev1.shape, 1)
        cachek1 *= scalek1
        cachev1 *= scalev1
    # Same dequantization for cache2.
    if scalek2 is not None:
        if scalek2.dim() == 2: # per_channel: [kv_head_num, head_size]
            scalek2 = scalek2.reshape(1, scalek2.shape[0], 1, scalek2.shape[1])
            scalev2 = scalev2.reshape(1, scalev2.shape[0], 1, scalev2.shape[1])
        elif scalek2.dim() == 3: # per_token: [num_blocks, k_head_num, block_size]
            scalek2 = scalek2.reshape(*scalek2.shape, 1)
            scalev2 = scalev2.reshape(*scalev2.shape, 1)
        cachek2 *= scalek2
        cachev2 *= scalev2
    new_context = context1 + context2
    seq_len1 = cachek1.shape[2]
    seq_len2 = cachek2.shape[2]
    new_max_context = seq_len1 + seq_len2
    batch = cachek1.shape[0]
    num_head = cachek1.shape[1]
    head_size = cachek1.shape[3]
    # Initialized with random values: positions beyond new_context[i] are never
    # read by the attention reference, so garbage there is intentional.
    new_cache_k = torch.randn(batch, num_head, new_max_context, head_size, dtype=torch.float32)
    new_cache_v = torch.randn(batch, num_head, new_max_context, head_size, dtype=torch.float32)
    # Linear layout: one "block" per batch entry, identity mapping.
    new_block_table = torch.arange(0, batch)
    new_block_table = new_block_table.view(batch, 1)
    for i in range(batch):
        len1 = context1[i]
        len2 = context2[i]
        block_id1 = block_tables1[i]
        block_id2 = block_tables2[i]
        # Copy cache1's valid prefix, then append cache2's valid prefix after it.
        new_cache_k[i, :, :len1, :] = cachek1[block_id1, :, :len1, :]
        new_cache_v[i, :, :len1, :] = cachev1[block_id1, :, :len1, :]
        new_cache_k[i, :, len1:len1 + len2, :] = cachek2[block_id2, :, :len2, :]
        new_cache_v[i, :, len1:len1 + len2, :] = cachev2[block_id2, :, :len2, :]
    return new_cache_k.mlu(), new_cache_v.mlu(), new_block_table.mlu(), new_context.mlu()
|
||||
|
||||
# cache1 and cache2 are float
|
||||
def concat_cache_paged(cachek1, cachek2, cachev1, cachev2, context1, context2, block_tables1, block_tables2,
                       scalek1, scalek2, scalev1, scalev2):
    """Merge two paged KV caches into one paged cache.

    cache2's blocks are appended after cache1's blocks (their block ids are offset by
    ``num_block1``), and when a batch entry's cache1 length is not block-aligned the
    borrowed cache2 data is shifted left so the merged sequence is contiguous.

    Args:
        cachek1/cachev1, cachek2/cachev2: [num_blocks, num_head, block_size, head_size] caches (float).
        context1, context2: per-batch valid sequence lengths.
        block_tables1, block_tables2: [batch, max_num_blocks_per_seq] block-id tables.
        scalek*/scalev*: optional dequant scales; 2-D => per-channel, 3-D => per-token.
    Returns:
        (new_cache_k, new_cache_v, new_block_table, new_context) — all moved to MLU.
    """
    batch = context1.shape[0]
    block_size = cachek1.shape[2]
    # Dequantize cache1 in place if scales were provided; reshape scales to broadcast.
    if scalek1 is not None:
        if scalek1.dim() == 2: # per_channel: [kv_head_num, head_size]
            scalek1 = scalek1.reshape(1, scalek1.shape[0], 1, scalek1.shape[1])
            scalev1 = scalev1.reshape(1, scalev1.shape[0], 1, scalev1.shape[1])
        elif scalek1.dim() == 3: # per_token: [num_blocks, k_head_num, block_size]
            scalek1 = scalek1.reshape(*scalek1.shape, 1)
            scalev1 = scalev1.reshape(*scalev1.shape, 1)
        cachek1 *= scalek1
        cachev1 *= scalev1
    # Same dequantization for cache2.
    if scalek2 is not None:
        if scalek2.dim() == 2: # per_channel: [kv_head_num, head_size]
            scalek2 = scalek2.reshape(1, scalek2.shape[0], 1, scalek2.shape[1])
            scalev2 = scalev2.reshape(1, scalev2.shape[0], 1, scalev2.shape[1])
        elif scalek2.dim() == 3: # per_token: [num_blocks, k_head_num, block_size]
            scalek2 = scalek2.reshape(*scalek2.shape, 1)
            scalev2 = scalev2.reshape(*scalev2.shape, 1)
        cachek2 *= scalek2
        cachev2 *= scalev2
    new_context = context1 + context2
    # BUG FIX: was block_tables1.shape[1] + block_tables1.shape[1], which doubled
    # table 1's width instead of summing both tables' widths.
    max_num_blocks_per_seq = block_tables1.shape[1] + block_tables2.shape[1]
    # Physical blocks of cache2 live right after cache1's blocks.
    new_cache_k = torch.concat((cachek1, cachek2), dim = 0)
    new_cache_v = torch.concat((cachev1, cachev2), dim = 0)
    num_block1 = cachek1.shape[0]
    new_block_table = torch.zeros(batch, max_num_blocks_per_seq, dtype=torch.int32)
    for i in range(batch):
        len1 = context1[i]
        len2 = context2[i]
        block_num1 = (len1 + block_size - 1) // block_size
        block_num2 = (len2 + block_size - 1) // block_size
        len1_pad = block_num1 * block_size
        block1 = block_tables1[i]
        block2 = block_tables2[i]
        # cache1's blocks keep their ids; cache2's ids are offset past cache1.
        new_block_table[i, :block_num1] = block1[:block_num1]
        new_block_table[i, block_num1:block_num1 + block_num2] = block2[:block_num2] + num_block1
        if len1 != len1_pad:
            # cache1's last block is only partially filled: shift cache2's data
            # left by pad_len so the merged sequence has no gap.
            reg_block_id = new_block_table[i, block_num1 - 1] # last block of cache1
            cat_block_id = new_block_table[i, block_num1] # first block of cache2
            reg_len = len1 % block_size
            pad_len = len1_pad - len1
            new_cache_k[reg_block_id, :, reg_len:, :] = new_cache_k[cat_block_id, :, :pad_len, :]
            new_cache_v[reg_block_id, :, reg_len:, :] = new_cache_v[cat_block_id, :, :pad_len, :]
            # Propagate the shift through the remaining cache2 blocks.
            for j in range(block_num2-1):
                block_id1 = new_block_table[i, block_num1 + j] # current
                block_id2 = new_block_table[i, block_num1 + j + 1] # next
                new_cache_k[block_id1, :, :reg_len, :] = new_cache_k[block_id1, :, pad_len:, :]
                new_cache_k[block_id1, :, reg_len:, :] = new_cache_k[block_id2, :, :pad_len, :]
                new_cache_v[block_id1, :, :reg_len, :] = new_cache_v[block_id1, :, pad_len:, :]
                new_cache_v[block_id1, :, reg_len:, :] = new_cache_v[block_id2, :, :pad_len, :]
            # The very last block only shifts its own contents left.
            block_id = new_block_table[i, block_num1 + block_num2 - 1]
            new_cache_k[block_id, :, :reg_len, :] = new_cache_k[block_id, :, pad_len:, :]
            new_cache_v[block_id, :, :reg_len, :] = new_cache_v[block_id, :, pad_len:, :]
    return new_cache_k.mlu(), new_cache_v.mlu(), new_block_table.mlu(), new_context.mlu()
|
||||
|
||||
class TestSingleQueryMixedKVAttnOp(BtTestCase):
|
||||
def run_gen_case(self, dic):
    """Replay one generated test case described by dict *dic*.

    When ``dic['dump_data']`` is truthy the dict already holds live arguments and
    is forwarded positionally; otherwise each entry is materialized (tensors via
    ``create_tensor_from_dic``, scalars via the ``['data']`` field) before launch.
    """
    dump_data = dic.pop('dump_data')
    if dump_data:
        # Dict values are already in launch() argument order.
        self.launch(*dic.values())
    else:
        q = create_tensor_from_dic(dic['q'])
        k_cache_lp = create_tensor_from_dic(dic['k_cache_lp'])
        v_cache_lp = create_tensor_from_dic(dic['v_cache_lp'])
        k_cache_hp = create_tensor_from_dic(dic['k_cache_hp'])
        v_cache_hp = create_tensor_from_dic(dic['v_cache_hp'])
        out = create_tensor_from_dic(dic['out'])
        # Non-tensor payloads are stored raw under ['data'].
        block_tables_lp = dic['block_tables_lp']['data']
        block_tables_hp = dic['block_tables_hp']['data']
        context_lens_lp = dic['context_lens_lp']['data']
        context_lens_hp = dic['context_lens_hp']['data']
        k_cache_quant_scale_lp = create_tensor_from_dic(dic['k_cache_quant_scale_lp'])
        v_cache_quant_scale_lp = create_tensor_from_dic(dic['v_cache_quant_scale_lp'])
        k_cache_quant_scale_hp = create_tensor_from_dic(dic['k_cache_quant_scale_hp'])
        v_cache_quant_scale_hp = create_tensor_from_dic(dic['v_cache_quant_scale_hp'])
        alibi_slopes = create_tensor_from_dic(dic['alibi_slopes'])
        max_contxt_len_lp = dic['max_contxt_len_lp']['data']
        max_contxt_len_hp = dic['max_contxt_len_hp']['data']
        softmax_scale = dic['softmax_scale']['data']
        return_lse = dic['return_lse']['data']
        kv_cache_quant_bit_size_lp = dic['kv_cache_quant_bit_size_lp']['data']
        kv_cache_quant_bit_size_hp = dic['kv_cache_quant_bit_size_hp']['data']
        self.launch(q, k_cache_lp, v_cache_lp, k_cache_hp, v_cache_hp, out, block_tables_lp,
                    block_tables_hp, context_lens_lp, context_lens_hp, k_cache_quant_scale_lp,
                    v_cache_quant_scale_lp, k_cache_quant_scale_hp, v_cache_quant_scale_hp,
                    alibi_slopes, max_contxt_len_lp, max_contxt_len_hp, softmax_scale,
                    return_lse, kv_cache_quant_bit_size_lp, kv_cache_quant_bit_size_hp)
|
||||
|
||||
def launch(self, *args):
    """Run the torch reference and the tmo kernel on the same arguments and
    compare both the log-sum-exp and the attention output."""
    # Reference (CPU/torch) result first, then the device kernel under test.
    ref_out, ref_lse = self.op_impl_base(*args)
    dev_out, dev_lse = ops.single_query_mixed_cached_kv_attn(*args)
    # lse gets the tighter tolerance; the attention output the looser one.
    self.assertTensorsEqual(ref_lse.cpu().float(), dev_lse.cpu().float(),
                            0.0003, use_MSE=True)
    self.assertTensorsEqual(ref_out.cpu().float(), dev_out.cpu().float(),
                            0.003, use_MSE=True)
|
||||
|
||||
def op_impl_base(self, *args):
    """Torch reference for mixed low/high-precision cached-KV attention.

    Runs the single-cache reference on the lp and hp caches separately, then
    merges the two partial results with ``update_out_and_lse_torch``.

    NOTE(review): when return_lse is False this returns a single tensor while
    launch() always unpacks two values — confirm no caller passes False.
    """
    q, k_cache_lp, v_cache_lp, k_cache_hp, v_cache_hp, out, block_tables_lp, block_tables_hp, \
        context_lens_lp, context_lens_hp, k_cache_quant_scale_lp, v_cache_quant_scale_lp, k_cache_quant_scale_hp, \
        v_cache_quant_scale_hp, alibi_slopes, max_contxt_len_lp, max_contxt_len_hp, softmax_scale, return_lse, \
        kv_cache_quant_bit_size_lp, kv_cache_quant_bit_size_hp = args

    if kv_cache_quant_bit_size_lp == 4:
        # int4 caches store two values per byte; unpack to per-element int8
        # before handing them to the float reference.
        num_blocks, num_kv_heads, block_size_lp, head_size = v_cache_lp.size()
        block_size = block_size_lp * 2
        k_cache_lp = UnpackInt4(k_cache_lp).reshape(num_blocks, num_kv_heads, block_size, head_size)
        # v is packed along the sequence axis, hence the transpose round-trip.
        v_cache_lp = UnpackInt4(v_cache_lp.transpose(2,3)).reshape(num_blocks, num_kv_heads, head_size, block_size).transpose(2,3)

    torch_output_lp, torch_lse_lp = single_query_cached_kv_attn(q.contiguous().float(), k_cache_lp.float(), v_cache_lp.float(),
        block_tables_lp, context_lens_lp, k_cache_quant_scale_lp, v_cache_quant_scale_lp, alibi_slopes, -1, -1, softmax_scale, return_lse)
    torch_output_hp, torch_lse_hp = single_query_cached_kv_attn(q.contiguous().float(), k_cache_hp.float(), v_cache_hp.float(),
        block_tables_hp, context_lens_hp, k_cache_quant_scale_hp, v_cache_quant_scale_hp, alibi_slopes, -1, -1, softmax_scale, return_lse)
    # Combine the two partial attentions into the final output + lse.
    torch_output, torch_lse = update_out_and_lse_torch(torch_output_lp, torch_lse_lp, torch_output_hp, torch_lse_hp, None, None, None)
    return (torch_output, torch_lse) if return_lse else torch_output
|
||||
|
||||
def test_single_query_mixedkv_attention(self):
    """Sweep paged/linear layouts, alibi, dtypes, quant-bit pairs and the
    degenerate zero-length-lp case; compare the tmo kernel with the torch
    reference built from two separate single-cache attentions."""
    head_num = 16
    batch = 12
    num_kv_heads = 4
    seq_q = 1
    head_size = 128
    seq_len_lp = 512
    seq_len_hp = 128
    is_pagedattn_list = [True, False]
    has_alibi_list = [True, False]
    is_normal_list = [True, False] # if false, lp_k/v_len of some batch is 0
    # (lp_bits, hp_bits); -1 means unquantized.
    quant_bit_list = [(-1, -1), (4, 8), (4, -1), (8, -1), (8, 8)]
    data_type_list = [torch.float, torch.float16]
    if torch_mlu.mlu.is_bf16_supported():
        data_type_list.append(torch.bfloat16)
    args = product(is_pagedattn_list, has_alibi_list, data_type_list, quant_bit_list, is_normal_list)
    for is_pagedattn, has_alibi, data_type, quant_bit, is_normal in args:
        quant_bit_lp, quant_bit_hp = quant_bit
        print("test separate:{} + {}: batch:{}, seq_len_lp:{}, seq_len_hp:{}, head_size:{}, is_pagedattn:{}, has_alibi:{}, data_type:{}, is_normal:{} ...".format(
            quant_bit_lp, quant_bit_hp, batch, seq_len_lp, seq_len_hp, head_size, is_pagedattn, has_alibi, data_type, is_normal))
        if is_pagedattn:
            # Paged attention is unsupported on MLU3xx hardware.
            mlu_name = torch.mlu.get_device_name()
            if "MLU3" in mlu_name:
                print("pagedattn is not implement on mlu370, skip it")
                continue
        torch.manual_seed(1)
        # Oversized qkv buffer; only the q slice is used, strided on purpose.
        input_qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size), dtype=data_type).mlu()
        input_q = input_qkv[:, 0:seq_q, 0:head_num,:].view(batch, seq_q, head_num, head_size)
        #gen k/v cache
        params_lp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_lp, data_type, quant_bit_lp,
                              "per_token", is_normal)
        params_hp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_hp, data_type, quant_bit_hp,
                              "per_channel")
        # int4 cases return two extra int8 reference caches (unused here).
        if quant_bit_lp == 4:
            key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp, _, _ = params_lp
        else:
            key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp = params_lp
        key_cache_hp, value_cache_hp, key_scale_hp, value_scale_hp, context_lens_hp, block_tables_hp = params_hp
        max_context_len_lp = context_lens_lp.max().item()
        max_context_len_hp = context_lens_hp.max().item()

        alibi_slopes = None
        if has_alibi:
            alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu()
            alibi_slopes.uniform_(0, 0.125)
        softmax_scale = 1 / math.sqrt(head_size)
        # Reference result (always requests lse).
        torch_output, torch_lse = self.op_impl_base(input_q,
                                                    key_cache_lp, value_cache_lp,
                                                    key_cache_hp, value_cache_hp,
                                                    None, #output
                                                    block_tables_lp, block_tables_hp,
                                                    context_lens_lp, context_lens_hp,
                                                    key_scale_lp, value_scale_lp,
                                                    key_scale_hp, value_scale_hp,
                                                    alibi_slopes,
                                                    max_context_len_lp, max_context_len_hp,
                                                    softmax_scale, True,
                                                    quant_bit_lp, quant_bit_hp)

        # Kernel under test with identical arguments.
        tmo_output, tmo_lse = ops.single_query_mixed_cached_kv_attn(input_q,
                                                    key_cache_lp, value_cache_lp,
                                                    key_cache_hp, value_cache_hp,
                                                    None, #output
                                                    block_tables_lp, block_tables_hp,
                                                    context_lens_lp, context_lens_hp,
                                                    key_scale_lp, value_scale_lp,
                                                    key_scale_hp, value_scale_hp,
                                                    alibi_slopes,
                                                    max_context_len_lp, max_context_len_hp,
                                                    softmax_scale, True,
                                                    quant_bit_lp, quant_bit_hp)
        self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(),
                                0.0003, use_MSE=True)
        self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                0.003, use_MSE=True)
|
||||
|
||||
def test_single_query_mixedkv_attention_concate(self):
    """Cross-check the mixed-KV kernel against a single attention run over the
    two caches concatenated (dequantized) into one — a second, independent
    reference besides the split-then-merge path."""
    head_num = 16
    batch = 12
    num_kv_heads = 4
    seq_q = 1
    head_size = 128
    seq_len_lp = 512
    seq_len_hp = 128
    is_pagedattn_list = [True, False]
    has_alibi = False #only support without alibi
    is_normal_list = [True, False] # if false, lp_k/v_len of some batch is 0
    # (lp_bits, hp_bits); -1 means unquantized.
    quant_bit_list = [(-1, -1), (4, 8), (4, -1), (8, -1), (8, 8)]
    data_type_list = [torch.float, torch.float16]
    if torch_mlu.mlu.is_bf16_supported():
        data_type_list.append(torch.bfloat16)
    args = product(is_pagedattn_list, data_type_list, quant_bit_list, is_normal_list)
    for is_pagedattn, data_type, quant_bit, is_normal in args:
        quant_bit_lp, quant_bit_hp = quant_bit
        # BUG FIX: the format string had 7 placeholders for 9 arguments, so
        # every label after the quant bits showed the wrong value. Now every
        # argument has its own labeled slot.
        print("test concate {} + {}: batch:{}, seq_len_lp:{}, seq_len_hp:{}, head_size:{}, is_pagedattn:{}, data_type:{}, is_normal:{} ...".format(
            quant_bit_lp, quant_bit_hp, batch, seq_len_lp, seq_len_hp, head_size, is_pagedattn, data_type, is_normal))
        if is_pagedattn:
            # Paged attention is unsupported on MLU3xx hardware.
            mlu_name = torch.mlu.get_device_name()
            if "MLU3" in mlu_name:
                print("pagedattn is not implement on mlu370, skip it")
                continue
        torch.manual_seed(1)
        # Oversized qkv buffer; only the q slice is used, strided on purpose.
        input_qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size), dtype=data_type).mlu()
        input_q = input_qkv[:, 0:seq_q, 0:head_num,:].view(batch, seq_q, head_num, head_size)
        #gen k/v cache
        params_lp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_lp, data_type, quant_bit_lp,
                              "per_token", is_normal)
        params_hp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_hp, data_type, quant_bit_hp,
                              "per_channel")
        if quant_bit_lp == 4:
            # int4 cases also return the unpacked int8 caches for the torch path.
            key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp, key_cache_lp_torch, value_cache_lp_torch = params_lp
        else:
            key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp = params_lp
            key_cache_lp_torch, value_cache_lp_torch = key_cache_lp, value_cache_lp
        key_cache_hp, value_cache_hp, key_scale_hp, value_scale_hp, context_lens_hp, block_tables_hp = params_hp
        max_context_len_lp = context_lens_lp.max().item()
        max_context_len_hp = context_lens_hp.max().item()

        alibi_slopes = None
        if has_alibi:  # dead in this test (has_alibi is fixed False), kept for symmetry
            alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu()
            alibi_slopes.uniform_(0, 0.125)
        softmax_scale = 1 / math.sqrt(head_size)
        # concate cache
        if is_pagedattn:
            concat_cache_k, concat_cache_v, concat_block_tables, concat_context = concat_cache_paged(
                key_cache_lp_torch.float().cpu(), key_cache_hp.float().cpu(),
                value_cache_lp_torch.float().cpu(), value_cache_hp.float().cpu(),
                context_lens_lp.cpu(), context_lens_hp.cpu(),
                block_tables_lp.cpu(), block_tables_hp.cpu(),
                key_scale_lp.cpu() if key_scale_lp is not None else None,
                key_scale_hp.cpu() if key_scale_hp is not None else None,
                value_scale_lp.cpu() if value_scale_lp is not None else None,
                value_scale_hp.cpu() if value_scale_hp is not None else None)
        else:
            concat_cache_k, concat_cache_v, concat_block_tables, concat_context = concate_cache_linear(
                key_cache_lp_torch.float().cpu(), key_cache_hp.float().cpu(),
                value_cache_lp_torch.float().cpu(), value_cache_hp.float().cpu(),
                context_lens_lp.cpu(), context_lens_hp.cpu(),
                block_tables_lp.cpu(), block_tables_hp.cpu(),
                key_scale_lp.cpu() if key_scale_lp is not None else None,
                key_scale_hp.cpu() if key_scale_hp is not None else None,
                value_scale_lp.cpu() if value_scale_lp is not None else None,
                value_scale_hp.cpu() if value_scale_hp is not None else None)

        # Reference: one attention over the merged (already dequantized) cache.
        torch_output_concat, torch_lse_concat = single_query_cached_kv_attn(input_q.contiguous().float(),
            concat_cache_k.float(), concat_cache_v.float(), concat_block_tables,
            concat_context, None, None, alibi_slopes, -1, -1, softmax_scale, True)
        # Kernel under test on the original split caches.
        tmo_output, tmo_lse = ops.single_query_mixed_cached_kv_attn(input_q,
                                                    key_cache_lp, value_cache_lp,
                                                    key_cache_hp, value_cache_hp,
                                                    None, #output
                                                    block_tables_lp, block_tables_hp,
                                                    context_lens_lp, context_lens_hp,
                                                    key_scale_lp, value_scale_lp,
                                                    key_scale_hp, value_scale_hp,
                                                    alibi_slopes,
                                                    max_context_len_lp, max_context_len_hp,
                                                    softmax_scale, True,
                                                    quant_bit_lp, quant_bit_hp)
        if is_normal: # only compare lse when context_len_lp is normal, seq=0 case nan-value in lse
            self.assertTensorsEqual(torch_lse_concat.cpu().float(), tmo_lse.cpu().float(),
                                    0.0003, use_MSE=True)
        self.assertTensorsEqual(torch_output_concat.cpu().float(), tmo_output.cpu().float(),
                                0.003, use_MSE=True)
|
||||
|
||||
def test_inductor(self):
    """Delegate to the shared inductor-compilation test defined on the base class."""
    return super().test_inductor()
|
||||
|
||||
# Script entry point: run the suite and propagate its status as the exit code.
if __name__ == '__main__':
    exit(run_unittest(TestSingleQueryMixedKVAttnOp))
|
||||
555
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant.py
Executable file
555
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant.py
Executable file
@@ -0,0 +1,555 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as tmo
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
import random
|
||||
import os
|
||||
|
||||
def generate_token_count(num_expert,
                         total_token_count):
    """Randomly split *total_token_count* tokens across *num_expert* experts.

    Draws a random weight per expert, rescales so the weights sum to roughly the
    total, then clamps the running cumsum so it never exceeds the total and pins
    the final entry to it exactly.

    Returns:
        (token_count, cusum_token_count): int32 tensors of shapes (num_expert,)
        and (num_expert + 1,); cusum_token_count[0] == 0 and
        cusum_token_count[-1] == total_token_count, with token_count the
        consecutive differences of the cumsum.
    """
    token_count = torch.randint(low=1, high=1024, size=(num_expert, ), \
                                dtype=torch.int32).to(dtype=torch.float32)
    # FIX: renamed local from `sum`, which shadowed the builtin.
    total = torch.sum(token_count, dim=-1) * 1.0
    # Rescale the random counts so they add up to (approximately) the target,
    # then truncate back to integers.
    token_count *= total_token_count / total.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    # Cap the cumsum at the total and force an exact final sum; per-expert
    # counts are re-derived from the adjusted cumsum.
    cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
    cusum_token_count[-1] = total_token_count
    token_count = cusum_token_count[1:] - cusum_token_count[0:-1]
    return token_count, cusum_token_count
|
||||
|
||||
def gen_case(num_tokens,
             topk,
             hidden_size,
             multi_scale,
             need_gather,
             num_expert,
             expert_size,
             start_expert_id,
             dtype,
             device):
    """Build inputs for a moe_quantize test case.

    Returns (input, scale, gather_ids, token_count, cusum_token_count,
    gather_index_start_position, effective_count). When *multi_scale* is set the
    scale is per-expert and token counts partition the gathered rows among
    experts; otherwise a single per-channel scale is used.
    """
    input = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device)
    # With gathering each token appears topk times.
    effective_count = num_tokens * topk if need_gather else num_tokens
    if multi_scale:
        token_count, cusum_token_count = generate_token_count(num_expert, effective_count)
        token_count = token_count.to(device=device)
        cusum_token_count = cusum_token_count.to(device=device)
        scale = torch.randn((num_expert, hidden_size), dtype=torch.float32, device=device)
    else:
        token_count = None
        cusum_token_count = None
        scale = torch.randn((hidden_size), dtype=torch.float32, device=device)

    if need_gather:
        # Each source token id repeated topk times, in random order.
        # NOTE(review): hard-codes .mlu() instead of honoring `device` — confirm intended.
        gather_ids = torch.randperm(num_tokens * topk, dtype=torch.int32) // topk
        gather_ids = gather_ids.mlu()
    else:
        gather_ids = None

    if expert_size < num_expert and multi_scale:
        # Simulate an expert-parallel shard starting at start_expert_id.
        # NOTE(review): slices keep all experts from start_expert_id onward, not
        # just expert_size of them — confirm that is the intended shard shape.
        token_count = token_count[start_expert_id:]
        cusum_token_count = cusum_token_count[start_expert_id:]
        gather_index_start_position = cusum_token_count[0:1]
        scale = scale[start_expert_id:]
        effective_count = cusum_token_count[-1] - cusum_token_count[0]
    else:
        gather_index_start_position = None

    if not need_gather:
        # A start position is only meaningful together with a gather index.
        gather_index_start_position = None

    return input, scale, gather_ids, token_count, cusum_token_count, gather_index_start_position, effective_count
|
||||
|
||||
def per_token_smooth_quantize_base(x: torch.Tensor,
                                   smooth: torch.Tensor,
                                   zero: torch.Tensor = None,
                                   token_count: torch.Tensor = None):
    """Reference for per-token smooth quantization.

    Scales every row of *x* by *smooth* and quantizes row-wise to int8.
    ``zero`` and ``token_count`` are accepted for signature parity but unused.

    Returns:
        (quantized, per_row_scale) with quantized shaped like x and the scale
        shaped like x minus its last dimension.
    """
    # Collapse all leading dims into rows, apply the smoothing factors, and
    # quantize each row independently to 8 bits.
    rows = x.flatten(0, -2) * smooth
    quant, row_scale = QuantByRow(rows, 8)
    # Restore the caller-visible shapes.
    return quant.reshape(x.size()), row_scale.reshape(x.size()[0:-1])
|
||||
|
||||
def quantize_base(x: torch.Tensor,
                  scale: torch.Tensor,
                  zero: torch.Tensor = None
                  ) -> torch.Tensor:
    """Static-scale int8 quantization reference: round(x * scale) clamped to
    the int8 range. ``zero`` is accepted for signature parity but unused."""
    scaled = x * scale
    clipped = scaled.round().clamp(-128.0, 127.0)
    return clipped.to(torch.int8)
|
||||
|
||||
class TestSmoothQuantOp(BtTestCase):
|
||||
def run_gen_case(self, dic):
    """Replay one dumped moe_quantize case described by dict *dic*.

    When ``dic['dump_data']`` is truthy the dict already holds live arguments
    and is forwarded positionally; otherwise tensors are rebuilt via
    ``create_tensor_from_dic`` and scalars read from the ``['data']`` field,
    with None entries preserved.
    """
    dump_data = dic.pop('dump_data')
    if dump_data:
        # Dict values are already in launch() argument order.
        self.launch(*dic.values())
    else:
        x = create_tensor_from_dic(dic['x'])
        smooth = create_tensor_from_dic(dic['smooth'])
        zero = None if dic['zero']['data'] is None else create_tensor_from_dic(dic['zero'])
        token_count = None if dic['token_count']['data'] is None else dic['token_count']['data']
        gather_index = None if dic['gather_index']['data'] is None else dic['gather_index']['data']
        gather_index_start_position = None if dic['gather_index_start_position']['data'] is None else dic['gather_index_start_position']['data']
        output = None if dic['output']['data'] is None else create_tensor_from_dic(dic['output'])
        output_scale = None if dic['output_scale']['data'] is None else create_tensor_from_dic(dic['output_scale'])
        dynamic_quant = dic['dynamic_quant']['data']
        self.launch(x, smooth, zero, token_count, gather_index, gather_index_start_position, output, output_scale, dynamic_quant)
|
||||
|
||||
def launch(self, *args):
    """Run tmo.moe_quantize and the torch reference on the same arguments and
    compare outputs (args[-1] is dynamic_quant; args[6]/args[7] are the
    optional in-place output and output_scale buffers)."""
    tmo_out = tmo.moe_quantize(*args)
    # NOTE(review): the clones are taken AFTER tmo.moe_quantize ran, so if the
    # kernel wrote args[6]/args[7] in place, out_base/scale_base start from the
    # kernel's results rather than the original buffers — regions the reference
    # does not overwrite then compare trivially equal. Confirm this is intended
    # (i.e. only the valid region is meant to be checked).
    out_base = None if args[6] is None else args[6].clone()
    scale_base = None if args[7] is None else args[7].clone()
    args = list(args)
    args[6] = out_base
    args[7] = scale_base
    torch_out = self.op_impl_base(*args)

    if args[-1]:
        # Dynamic quant: compare quantized values and the per-row scales.
        self.assertTensorsEqual(torch_out[0].cpu().reshape(-1).float(),
                                tmo_out[0].cpu().reshape(-1).float(),
                                0.01, use_MSE=True, use_RAE=True)
        self.assertTensorsEqual(torch_out[1].cpu().reshape(-1).float(),
                                tmo_out[1].cpu().reshape(-1).float(),
                                0.003, use_MSE=True, use_RAE=True)
    else:
        # Static quant: single quantized tensor to compare.
        self.assertTensorsEqual(torch_out.cpu().reshape(-1).float(),
                                tmo_out.cpu().reshape(-1).float(),
                                0.01, use_MSE=True, use_RAE=True)
|
||||
|
||||
|
||||
def op_impl_base(self, *args):
    """CPU/torch reference for tmo.moe_quantize.

    Optionally gathers rows of x via gather_index, applies either a per-expert
    scale (segmented by token_count) or a single smooth scale, then quantizes
    statically (round+clamp to int8) or dynamically (QuantByRow). When an
    output buffer is supplied, results are written into it in place.
    """
    x, smooth, zero, token_count, gather_index, gather_index_start_position, \
        output, output_scale, dynamic_quant = args
    input = x.to(dtype=torch.float32).cpu()
    input_scale = smooth.cpu()
    cusum_token_count = None
    if token_count is not None:
        # Prefix sums delimit each expert's contiguous row segment.
        token_count = token_count.cpu()
        cusum_token_count = torch.zeros(token_count.shape[0] + 1, dtype=torch.int32)
        cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)

    if gather_index_start_position is not None:
        gather_index_start_position = gather_index_start_position.cpu()

    gather_index_start = 0
    if gather_index_start_position is not None:
        gather_index_start = gather_index_start_position[0]

    # End of the processed span: expert segments if known, else the whole
    # gather index, else the whole input.
    # NOTE(review): the final fallback uses input.numel() (element count) as a
    # row-slice bound; it only slices to "everything" because it exceeds the
    # row count — confirm input.shape[0] was intended.
    if cusum_token_count is not None:
        gather_index_end = cusum_token_count[-1] + gather_index_start
    elif gather_index is not None:
        gather_index_end = gather_index.numel()
    else:
        gather_index_end = input.numel()

    if gather_index is not None:
        # Fancy-index copy: gathered_input owns its storage, so the in-place
        # scaling below does not touch `input`.
        gather_index = gather_index.cpu()
        gathered_input = input[gather_index[gather_index_start : gather_index_end]]
    else:
        gathered_input = input[gather_index_start : gather_index_end]

    if cusum_token_count is not None:
        # Per-expert scaling of each contiguous row segment.
        for i in range(token_count.shape[0]):
            gathered_input[cusum_token_count[i] : cusum_token_count[i+1]] *= input_scale[i]
    else:
        gathered_input *= input_scale

    if output is None:
        if not dynamic_quant:
            # Static quant: fixed scale already applied, just round+clamp.
            return gathered_input.round().clamp(-128.0, 127.0).to(torch.int8), None
        else:
            # Dynamic quant: per-row scale computed by QuantByRow.
            return QuantByRow(gathered_input, 8)
    else:
        if not dynamic_quant:
            output.copy_(gathered_input.round().clamp(-128.0, 127.0).to(torch.int8))
            output_scale = None
        else:
            # Write only the valid prefix into the (possibly larger) buffers.
            out, scale = QuantByRow(gathered_input, 8)
            output_fl = output.flatten()
            output_fl[:out.numel()].copy_(out.flatten())
            # output = output_fl
            output_scale_fl = output_scale.flatten()
            output_scale_fl[:scale.numel()].copy_(scale.flatten())
            # output_scale = output_scale_fl
        return (output, output_scale) if dynamic_quant else (output,)
|
||||
|
||||
def test_random_case(self):
    """Fuzz moe_quantize over 100 random shape/flag combinations and compare
    the kernel against the torch reference on the effective row range."""
    torch.manual_seed(333)
    test_cases = 100
    num_tokens_list = torch.randint(low=1, high=4096, size=(test_cases, ), dtype=torch.int32)
    topk_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32)
    hidden_size_list = torch.randint(low=128, high=8193, size=(test_cases, ), dtype=torch.int32)
    num_expert_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32)
    expert_size_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32)
    expert_size_list = torch.minimum(expert_size_list, num_expert_list)
    # Clamp start ids into [0, num_expert - expert_size].
    start_expert_id_list = torch.randint(low=0, high=129, size=(test_cases, ), dtype=torch.int32)
    start_expert_id_list = torch.minimum(start_expert_id_list, num_expert_list - expert_size_list)
    start_expert_id_list = torch.maximum(start_expert_id_list, torch.zeros(test_cases, dtype=torch.int32))
    multi_scale_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    need_gather_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    input_with_stride_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    dynamic_quant_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    # NOTE(review): this randint assignment is dead — dtype_list is rebuilt by
    # random.choices just below. It does, however, advance the seeded torch RNG
    # stream, so removing it would change every generated case.
    dtype_list = torch.randint(low=0, high=10, size=(test_cases, ), dtype=torch.int32)
    # half/bf16 listed twice to be drawn more often than float32.
    dtypes = [torch.half, torch.half, torch.float32]
    if torch_mlu.mlu.is_bf16_supported():
        dtypes += [torch.bfloat16, torch.bfloat16]
    dtype_list = random.choices(dtypes, k=test_cases)

    device = "mlu"
    for i in range(test_cases):
        num_tokens = num_tokens_list[i].item()
        topk = topk_list[i].item()
        hidden_size = hidden_size_list[i].item()
        num_expert = num_expert_list[i].item()
        expert_size = expert_size_list[i].item()
        start_expert_id = start_expert_id_list[i].item()
        multi_scale = multi_scale_list[i].item() == 1
        need_gather = need_gather_list[i].item() == 1
        input_with_stride = input_with_stride_list[i].item() == 1
        dynamic_quant = dynamic_quant_list[i].item() == 1
        dtype = dtype_list[i]

        # Gathering is only exercised with multi-scale on bf16-capable devices.
        if not multi_scale or not torch_mlu.mlu.is_bf16_supported():
            need_gather = False

        inputs = gen_case(num_tokens,
                          topk,
                          hidden_size,
                          multi_scale,
                          need_gather,
                          num_expert,
                          expert_size,
                          start_expert_id,
                          dtype,
                          device)

        input = inputs[0]
        input_scale = inputs[1]
        gather_ids = inputs[2]
        token_count = inputs[3]
        cusum_token_count = inputs[4]
        gather_index_start_position = inputs[5]
        effective_count = inputs[6]

        if input_with_stride:
            # Take a non-contiguous column slice to exercise strided input.
            # NOTE(review): slicing [..., hidden_size:] after shrinking
            # hidden_size keeps the LAST 64 columns — confirm intended.
            hidden_size = hidden_size - 64
            input = input[..., hidden_size : ]
            input_scale = input_scale[..., hidden_size : ].contiguous()

        print("num_tokens={}, topk={}, hidden_size={}, num_expert={}, expert_size={}, "
              "start_expert_id={}, multi_scale={}, need_gather={}, input_with_stride={}, "
              "dynamic_quant={}, dtype={}, testing...".format(
                  num_tokens, topk, hidden_size, num_expert, expert_size, start_expert_id, \
                  multi_scale, need_gather, input_with_stride, dynamic_quant, dtype))

        torch_quant, torch_output_scale = self.op_impl_base(input,
                                                            input_scale,
                                                            None,
                                                            token_count,
                                                            gather_ids,
                                                            gather_index_start_position,
                                                            None,
                                                            None,
                                                            dynamic_quant)
        tmo_output_scale = None
        if dynamic_quant:
            tmo_output, tmo_output_scale = \
                tmo.moe_quantize(input, input_scale, None, token_count,
                                 gather_ids, gather_index_start_position,
                                 None, None, dynamic_quant)
        else:
            tmo_output, = \
                tmo.moe_quantize(input, input_scale, None, token_count,
                                 gather_ids, gather_index_start_position,
                                 None, None, dynamic_quant)
        # Only the effective prefix is defined; trim before comparing.
        tmo_output = tmo_output[:effective_count]
        if tmo_output_scale is not None:
            tmo_output_scale = tmo_output_scale[:effective_count]

        self.assertTensorsEqual(torch_quant.cpu().reshape(-1).float(),
                                tmo_output.cpu().reshape(-1).float(),
                                0.01, use_MSE=True, use_RAE=True)

        if dynamic_quant:
            self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(),
                                    tmo_output_scale.cpu().reshape(-1).float(),
                                    0.003, use_MSE=True, use_RAE=True)
|
||||
|
||||
    def test_interface(self):
        """Smoke-test the public quantize interfaces on small shapes.

        Covers tmo.per_token_smooth_quantize, tmo.quantize and the in-place
        (preallocated-output) form of tmo.moe_quantize, comparing each MLU
        result against the pure-torch reference helpers from common_utils.

        NOTE(review): indentation was reconstructed from a flattened source;
        the moe_quantize section is assumed to sit at method level (after the
        per-dtype loop) since it hard-codes dtype=torch.half — confirm against
        the original file.
        """
        channel = 16
        # fp16/fp32 always; bf16 only where the device supports it.
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for dtype in dtype_list:
            # 4-D activation and a per-channel float scale on the device.
            input = torch.randn((2, 3, 2, channel), dtype=dtype, device="mlu")
            input_scale = torch.randn((channel)).float().mlu()

            print("test tmo.per_token_smooth_quantize...")
            torch_quant, torch_scale = per_token_smooth_quantize_base(input, input_scale)
            tmo_quant, tmo_scale = tmo.per_token_smooth_quantize(input, input_scale)

            self.assertTensorsEqual(torch_quant.cpu().float(), tmo_quant.cpu().float(),
                                    0.01, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_scale.cpu().float(), tmo_scale.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)

            print("test tmo.quantize ...")
            # Enlarge the scale so static quantization covers a wider range.
            input_scale = input_scale * 100.0
            torch_quant = quantize_base(input, input_scale)
            tmo_quant = tmo.quantize(input, input_scale)
            self.assertTensorsEqual(torch_quant.cpu().float(), tmo_quant.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)

        print("test tmo.moe_quantize inplace ...")
        num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert, expert_size, start_expert_id, dtype, dynamic_quant = \
            1024, 5, 512, True, True, 32, 4, 4, torch.half, True
        inputs = gen_case(num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert,
                          expert_size, start_expert_id, dtype, 'mlu')
        input = inputs[0]
        input_scale = inputs[1]
        gather_ids = inputs[2]
        token_count = inputs[3]
        cusum_token_count = inputs[4]  # unused below; kept to document gen_case's tuple layout
        gather_index_start_position = inputs[5]
        effective_count = inputs[6]
        # Preallocated outputs; only the first `effective_count` rows are meaningful.
        tmo_output = torch.empty(num_tokens*topk, hidden_size, dtype=torch.int8, device='mlu')
        tmo_output_scale = torch.empty(num_tokens*topk, dtype=torch.float, device='mlu')
        # presumably the in-place path is unsupported on MLU370 — confirm with op docs
        if 'MLU370' not in torch_mlu.mlu.get_device_name():
            tmo.moe_quantize(input, input_scale, None, token_count,
                             gather_ids, gather_index_start_position,
                             tmo_output, tmo_output_scale, dynamic_quant)
            tmo_output = tmo_output[:effective_count]
            tmo_output_scale = tmo_output_scale[:effective_count]
            torch_output = torch.empty_like(tmo_output)
            torch_output_scale = torch.empty_like(tmo_output_scale)
            # Reference implementation writes into the preallocated buffers too.
            self.op_impl_base(input, input_scale,
                              None, token_count, gather_ids, gather_index_start_position,
                              torch_output, torch_output_scale, dynamic_quant)
            self.assertTensorsEqual(torch_output.cpu().reshape(-1).float(),
                                    tmo_output.cpu().reshape(-1).float(),
                                    0.01, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(),
                                    tmo_output_scale.cpu().reshape(-1).float(),
                                    0.003, use_MSE=True, use_RAE=True)

            # Same op on a >2-D view with a single shared scale, no gather.
            input = input.reshape(2, 4, 8, 16, -1)
            input_scale = input_scale[0]
            tmo_output = torch.empty(input.size(), dtype=torch.int8, device='mlu')
            tmo_output_scale = torch.empty(input.size()[:-1], dtype=torch.float, device='mlu')
            tmo.moe_quantize(input, input_scale, None, None, None, None,
                             tmo_output, tmo_output_scale, dynamic_quant)
            torch_quant = torch.empty_like(tmo_output)
            torch_output_scale = torch.empty_like(tmo_output_scale)
            self.op_impl_base(input, input_scale,
                              None, None, None, None, torch_quant, torch_output_scale, dynamic_quant)
            self.assertTensorsEqual(torch_quant.cpu().reshape(-1).float(),
                                    tmo_output.cpu().reshape(-1).float(),
                                    0.01, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(),
                                    tmo_output_scale.cpu().reshape(-1).float(),
                                    0.003, use_MSE=True, use_RAE=True)
|
||||
|
||||
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        """Check that tmo.moe_quantize rejects invalid arguments.

        Each assertException call passes one deliberately broken argument and
        expects the given error-message fragment (None means: any exception,
        message unchecked).
        """
        num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert, expert_size, start_expert_id, dtype, dynamic_quant = \
            1024, 5, 512, True, True, 32, 4, 4, torch.half, True
        inputs = gen_case(num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert,
                          expert_size, start_expert_id, dtype, 'mlu')
        input = inputs[0]
        input_scale = inputs[1]
        gather_ids = inputs[2]
        token_count = inputs[3]
        cusum_token_count = inputs[4]  # unused; kept to document gen_case's tuple layout
        gather_index_start_position = inputs[5]
        effective_count = inputs[6]  # unused here
        tmo_output = torch.empty(input.size(), dtype=torch.int8, device='mlu')
        tmo_output_scale = torch.empty(input.size()[:-1], dtype=torch.float, device='mlu')
        func = tmo.moe_quantize

        # Input must live on the MLU device.
        self.assertException("input must be mlu tensor.",
                             func, input.cpu(), input_scale, None, None, None, None,
                             None, None, dynamic_quant)
        # CPU gather_index_start_position should raise as well.
        self.assertException(None,
                             func, input, input_scale, None, token_count, gather_ids,
                             gather_index_start_position.cpu(), None, None, dynamic_quant)
        # output_scale is only meaningful for dynamic quantization.
        self.assertException("not support output_scale if dynamic_quant = false",
                             func, input, input_scale, None, token_count, gather_ids,
                             gather_index_start_position, None, tmo_output_scale, False)
        # Gathered/token-counted inputs must be 2-D.
        self.assertException("input.dim() == 2 if has gather_index or token_count",
                             func, input.reshape(32, 32, -1), input_scale, None, None, gather_ids,
                             gather_index_start_position, None, None, dynamic_quant)
        self.assertException("input.dim() >= 2",
                             func, input.reshape(-1), input_scale, None, None, None, None,
                             None, None, dynamic_quant)
        self.assertException("output.dim() >= 2",
                             func, input, input_scale, None, token_count, gather_ids,
                             gather_index_start_position, tmo_output.reshape(-1), tmo_output_scale, True)
        self.assertException("input and output must have the same shape",
                             func, input.reshape(2, 512, -1), input_scale, None, None, None, None,
                             tmo_output.reshape(32, 32, -1), None, dynamic_quant)
        self.assertException("output_scale_shape must be equal to input_shape[0:-1]",
                             func, input.reshape(2, 512, -1), input_scale[0], None, None, None, None,
                             None, tmo_output_scale.reshape(32, 32), dynamic_quant)
        # gather_index_start_position without gather_index is inconsistent.
        self.assertException("gather_index must exist if gather_index_start_position has value",
                             func, input, input_scale, None, token_count, None,
                             gather_index_start_position, None, None, True)
        self.assertException("gather_index.dim() == 1",
                             func, input, input_scale, None, token_count, gather_ids.reshape(1, -1),
                             gather_index_start_position, None, None, True)
|
||||
|
||||
|
||||
    def test_perf_case(self):
        """Time tmo.moe_quantize on production-like shapes and verify results.

        For each configuration the MLU op is run `loop` times between device
        events, the average hardware time is printed, and the last result is
        checked against the torch reference.
        """
        num_tokens_list = [1, 72, 512]
        topk = 5
        hidden_size_list = [2048, 4096, 5120, 8192]
        # [num_expert, start_expert_id, expert_size]
        expert_options_list = [[8, 0, 8], [32, 24, 8]]
        multi_scale_list = [True, False]
        need_gather_list = [True, False]
        # Only dynamic quantization is timed here; note the `else` branch below
        # would leave tmo_output_scale undefined if False were ever added.
        dynamic_quant_list = [True]
        dtype_list = [torch.half, torch.bfloat16]

        device = 'mlu'
        args = product(num_tokens_list, hidden_size_list, expert_options_list,\
                       multi_scale_list, need_gather_list, dynamic_quant_list, dtype_list)
        for num_tokens, hidden_size, expert_options, multi_scale, need_gather, dynamic_quant, dtype in args:
            num_expert = expert_options[0]
            start_expert_id = expert_options[1]
            expert_size = expert_options[2]

            # Gathering is only exercised together with multi_scale on bf16-capable devices.
            if not multi_scale or not torch_mlu.mlu.is_bf16_supported():
                need_gather = False

            # Skip the whole sweep on devices without bf16 support.
            if not torch_mlu.mlu.is_bf16_supported():
                continue

            # Fixed seed so every configuration sees reproducible data.
            torch.manual_seed(444)
            inputs = gen_case(num_tokens,
                              topk,
                              hidden_size,
                              multi_scale,
                              need_gather,
                              num_expert,
                              expert_size,
                              start_expert_id,
                              dtype,
                              device)

            input = inputs[0]
            input_scale = inputs[1]
            gather_ids = inputs[2]
            token_count = inputs[3]
            cusum_token_count = inputs[4]  # unused; kept to document gen_case's tuple layout
            gather_index_start_position = inputs[5]
            effective_count = inputs[6]

            print("num_tokens={}, hidden_size={}, num_expert={}, expert_size={}, "
                  "start_expert_id={}, multi_scale={}, need_gather={}, input_with_stride={}, "
                  "dynamic_quant={}, dtype={}, testing...".format(
                  num_tokens, hidden_size, num_expert, expert_size, start_expert_id, \
                  multi_scale, need_gather, False, dynamic_quant, dtype))

            # Torch reference result, computed once outside the timed loop.
            torch_quant, torch_output_scale = self.op_impl_base(input,
                                                                input_scale,
                                                                None,
                                                                token_count,
                                                                gather_ids,
                                                                gather_index_start_position,
                                                                None,
                                                                None,
                                                                dynamic_quant)
            notify_start = torch.mlu.Event(enable_timing=True)
            notify_end = torch.mlu.Event(enable_timing=True)
            notify_start.record()
            loop = 10
            for _ in range(loop):
                if dynamic_quant:
                    tmo_output, tmo_output_scale = \
                        tmo.moe_quantize(input, input_scale, None, token_count,
                                         gather_ids, gather_index_start_position,
                                         None, None, dynamic_quant)
                else:
                    tmo_output, = \
                        tmo.moe_quantize(input, input_scale, None, token_count,
                                         gather_ids, gather_index_start_position,
                                         None, None, dynamic_quant)

            notify_end.record()
            notify_end.synchronize()
            # Average per-call device time in microseconds.
            time = notify_start.hardware_time(notify_end) / loop

            # Only the first `effective_count` rows of the output are valid.
            tmo_output = tmo_output[:effective_count]
            if tmo_output_scale is not None:
                tmo_output_scale = tmo_output_scale[:effective_count]
            print("time is: {:.1f}us".format(time))

            self.assertTensorsEqual(torch_quant.cpu().reshape(-1).float(),
                                    tmo_output.cpu().reshape(-1).float(),
                                    0.01, use_MSE=True, use_RAE=True)
            if dynamic_quant:
                self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(),
                                        tmo_output_scale.cpu().reshape(-1).float(),
                                        0.003, use_MSE=True, use_RAE=True)
|
||||
|
||||
    def test_inductor(self):
        """Run opcheck on the torch_mlu_ops.smooth_quant custom op schemas.

        Exercises the plain quantize path, the per-token smooth-quantize path
        and the moe_quantize path (the last across multi_scale / need_gather /
        dynamic_quant combinations) through base_opcheck so the inductor
        registration is validated.
        """
        multi_scale_list = [True, False]
        # Gathering is not exercised on MLU370 devices.
        need_gather_list = [True, False] if 'MLU370' not in torch_mlu.mlu.get_device_name() else [False]
        dynamic_quant_list = [True, False]
        num_tokens, hidden_size, num_expert, start_expert_id, expert_size, topk, \
            dtype, device = 1, 2048, 32, 24, 8, 5, torch.float16, 'mlu'
        params = product(multi_scale_list, need_gather_list, dynamic_quant_list)
        # quantize
        print(f"check ops.quantize...")
        input, input_scale, gather_ids, token_count, _, _, _ = gen_case(num_tokens,
                                                                        topk,
                                                                        hidden_size,
                                                                        False,
                                                                        False,
                                                                        num_expert,
                                                                        expert_size,
                                                                        start_expert_id,
                                                                        dtype,
                                                                        device)
        output = torch.empty(input.size(), dtype=torch.int8, device=device)
        # Empty tensor stands in for the (unused) output_scale slot.
        args = (input, input_scale, output, torch.Tensor(), None, None, None, None, 'per_token', False)
        self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)

        # per_token_smooth_quantize
        print(f"check ops.per_token_smooth_quantize...")
        output_scale = torch.empty(input.size()[:-1], dtype=input_scale.dtype, device=device)
        args = (input, input_scale, output, output_scale, None, token_count, None, None, 'per_token', True)
        self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)

        # moe_quantize
        for multi_scale, need_gather, dynamic_quant in params:
            # Gather requires per-expert scales.
            if not multi_scale and need_gather:
                continue
            print(f"check ops.moe_quantize multi_scale: {multi_scale}, need_gather: {need_gather}, dynamic_quant: {dynamic_quant} ...")
            input, input_scale, gather_ids, token_count, _, \
                gather_index_start_position, _ = gen_case(num_tokens,
                                                          topk,
                                                          hidden_size,
                                                          multi_scale,
                                                          need_gather,
                                                          num_expert,
                                                          expert_size,
                                                          start_expert_id,
                                                          dtype,
                                                          device)
            # When gathering, the output row count follows the gather index, not the input.
            output_shape = list(input.size())
            output_scale_shape = list(input.size()[:-1])
            if gather_ids is not None:
                output_tokens = gather_ids.size(0)
                output_shape[0] = output_tokens
                output_scale_shape[0] = output_tokens
            output = torch.empty(output_shape, dtype=torch.int8, device=device)
            output_scale = torch.empty(output_scale_shape, dtype=input_scale.dtype, device=device) if dynamic_quant else None
            args = (input, input_scale,
                    output, output_scale, None,
                    token_count, gather_ids, gather_index_start_position,
                    'per_token', dynamic_quant)
            self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
|
||||
|
||||
if __name__ == '__main__':
    # Run the suite through the shared unittest wrapper from common_utils.
    exit(run_unittest(TestSmoothQuantOp))
|
||||
217
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant_group_gemm.py
Executable file
217
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant_group_gemm.py
Executable file
@@ -0,0 +1,217 @@
|
||||
import torch
|
||||
from torch_mlu import mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def gen_mix_w4w8_param(bc, seq, k, n, experts_num, topk, data_type, has_bias, quant_wise):
    """Generate inputs for the mixed w4/w8 grouped smooth-quant GEMM test.

    Each of the experts_num * (k // quant_wise) weight groups is randomly
    tagged 4-bit or 8-bit; the flat int8 weight buffer `b` is sized so 4-bit
    groups occupy half the bytes of 8-bit ones. Returns the positional
    argument tuple consumed by ops.smooth_quant_group_gemm / op_impl_base
    (gather index slot is None: activations are pre-gathered here).
    """
    bs = bc * seq
    token_topk = bs * topk
    # Random expert assignment per (token, topk) slot, sorted so rows for the
    # same expert are contiguous; gather_idx maps sorted rows back to tokens.
    expert_id = torch.randint(experts_num, (token_topk,), device="mlu")
    sorted_expert_id, indices = expert_id.sort()
    gather_idx = indices // topk
    token_count = torch.bincount(sorted_expert_id, minlength=experts_num).to(torch.int32)
    quant_group = k // quant_wise
    # Per (expert, group) bit-width choice: 4 or 8.
    quant_flag = random.choices([4,8], k=experts_num * quant_group)
    # sum(quant_flag)//4 counts groups weighted 1 (w4) or 2 (w8);
    # each unit holds quant_wise//2 packed rows of n int8 values.
    b_count = (sum(quant_flag) // 4) * (quant_wise // 2) * n
    b = torch.randint(-128, 127, (b_count,), dtype=torch.int32, device="mlu").to(torch.int8)
    b_scale = torch.normal(0, 0.01, (quant_group, experts_num, n), device="mlu", dtype=torch.float32)
    a = torch.randint(-128, 127, (bs, k), device="mlu", dtype=torch.int32).to(torch.int8)
    # Expand activations and their per-row scales to the sorted token order.
    a = a[gather_idx]
    a_scale = torch.normal(0, 0.01, (bs,), device="mlu", dtype=torch.float32)
    a_scale = a_scale[gather_idx]
    c = torch.randn(token_topk, n, device="mlu", dtype=data_type)
    bias = torch.randn(experts_num, n, device="mlu", dtype=data_type) if has_bias else None
    return a, b, token_count, None, c, None, None, a_scale, b_scale, data_type, bs, bias, quant_flag
|
||||
|
||||
def gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias=False, quant_bit = 8, quant_group = 1):
    """Generate inputs for the grouped smooth-quant GEMM tests.

    idx_mode=True passes the gather index through to the op (activations stay
    ungathered); idx_mode=False pre-gathers the activations and passes None.
    quant_bit/quant_group control the weight quantization: 4-bit weights are
    pair-packed into int8, grouped scales are reshaped to
    (quant_group, experts_num, n)-compatible layout. Returns the positional
    argument tuple consumed by ops.smooth_quant_group_gemm / op_impl_base.
    """
    bs = batch * seq
    token_topk = bs * topk
    # Random expert per (token, topk) slot, sorted so rows for the same
    # expert are contiguous; gather_idx maps sorted rows back to tokens.
    expert_id = torch.randint(experts_num, (token_topk,), device="mlu")
    sorted_expert_id, indices = expert_id.sort()
    gather_idx = indices // topk
    gather_idx = gather_idx.to(torch.int32)
    token_count = torch.bincount(sorted_expert_id, minlength=experts_num).to(torch.int32)

    a = torch.randn(bs, k, device="mlu", dtype=data_type)
    if not idx_mode:
        # Pre-gather activations when the op will not receive the index.
        a = a[gather_idx]
    b = torch.randn(experts_num, n, k, device="mlu", dtype=data_type)
    c = torch.randn(token_topk, n, device="mlu", dtype=data_type)

    # Activations are always 8-bit per-row quantized (QuantByRow from common_utils).
    a, a_scale = QuantByRow(a, 8)
    if idx_mode:
        # Scales must follow the gathered row order even when `a` does not.
        a_scale = a_scale[gather_idx]
    b_shape = b.shape
    b, b_scale = QuantByRow(b.view(-1, b.shape[-1]), quant_bit, quant_group)
    b = b.view(b_shape)
    if quant_bit == 4:
        # Pack two 4-bit values per int8 element (PairlyPackInt8 from common_utils).
        b = PairlyPackInt8(b)
    # Grouped scales: (experts, n*quant_group) -> (quant_group, experts, n)-style layout.
    b_scale = b_scale.view(experts_num, -1) if quant_group == 1 else b_scale.view(experts_num, -1, quant_group).permute(2, 0, 1).contiguous()
    alpha = None
    beta = None
    bias = torch.randn(experts_num, n, device="mlu", dtype=data_type) if has_bias else None
    gather_idx_ = gather_idx if idx_mode else None
    quant_flag = None
    return a, b, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, data_type, bs, bias, quant_flag
|
||||
|
||||
class TestSmoothQuantGroupGemmOp(BtTestCase):
    """Tests for ops.smooth_quant_group_gemm against a per-expert torch reference."""

    def run_gen_case(self, dic):
        """Run one dumped case: either regenerate (dump_data) or rebuild each
        tensor from its serialized dict via create_tensor_from_dic."""
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            a = create_tensor_from_dic(dic['a'])
            b = create_tensor_from_dic(dic['b'])
            m_list = dic['m_list']['data']
            expand_idx = dic['expand_idx']['data']
            c = None if dic['c']['data'] is None else create_tensor_from_dic(dic['c'])
            alpha = None if dic['alpha']['data'] is None else create_tensor_from_dic(dic['alpha'])
            beta = None if dic['beta']['data'] is None else create_tensor_from_dic(dic['beta'])
            # Scales are regenerated with small magnitude (mean 0, std 0.01).
            a_scale = create_tensor_from_dic(dic['a_scale'], 0, 0.01)
            b_scale = create_tensor_from_dic(dic['b_scale'], 0, 0.01)
            dtype = dic['dtype']['data']
            max_m = dic['max_m']['data']
            bias = None if dic['bias']['data'] is None else create_tensor_from_dic(dic['bias'])
            quant_flag = dic['quant_flag']['data']
            self.launch(a, b, m_list, expand_idx, c, alpha, beta, a_scale, b_scale, dtype, max_m, bias, quant_flag)

    def launch(self, *args):
        """Run reference and MLU op on the same args and compare the valid rows."""
        # args[2] is m_list (tokens per expert); rows past total_m are padding.
        total_m = args[2].sum().item()
        torch_out = self.op_impl_base(*args)
        tmo_out = ops.smooth_quant_group_gemm(*args)
        self.assertTensorsEqual(tmo_out.cpu().float()[0:total_m], torch_out.cpu().float()[0:total_m], 0.006, use_MSE=True)

    def op_impl_base(self, *args):
        """Pure-torch reference: split rows per expert and GEMM each slice.

        Handles the plain int8 path, grouped b_scale (dim()==3), and the mixed
        w4/w8 path where quant_flag drives variable per-expert weight offsets.
        Rows of the returned tensor past total_m are uninitialized padding.
        """
        a, b, m_list, expand_idx, c, alpha, beta, a_scale, b_scale, dtype, max_m, bias, quant_flag = args
        a = a.reshape(-1, a.size(-1))
        if expand_idx is not None:
            # Gather activations into sorted-by-expert row order.
            a = a[expand_idx]
        total_m = m_list.sum().item()
        # One activation slice per expert, in expert order.
        a_list = a[:total_m].split(tuple(m_list))

        c_list = []
        if c is not None:
            c = c.reshape(-1, c.size(-1))
            c_list = c[:total_m].split(tuple(m_list))

        if a_scale is not None:
            a_scale_list = a_scale[:total_m].split(tuple(m_list))

        k = a.shape[1]
        # In the mixed w4/w8 path b is flat, so n comes from b_scale instead.
        n = b.size(1) if quant_flag is None else b_scale.shape[2]
        if b_scale is not None and b_scale.dim() == 3: # for quant_grouped
            # (quant_group, experts, n) -> (experts, quant_group, n)
            b_scale = b_scale.transpose(0, 1).contiguous()
        if quant_flag is not None:
            quant_group = b_scale.shape[1]
            group_wise = k // quant_group
            quant_flag = torch.tensor(quant_flag).view(-1, quant_group)
            # Cumulative byte offsets into the flat packed weight buffer:
            # mirrors b_count = (sum(flags)//4) * (group_wise//2) * n in gen_mix_w4w8_param.
            b_offset_cu = torch.cumsum(quant_flag.sum(dim=1), dim=0) // 4 * (group_wise // 2) * n
            b_offset_cu = torch.nn.functional.pad(b_offset_cu, (1,0), "constant", 0)

        output_list = []
        experts = b.size(0) if quant_flag is None else b_scale.size(0)
        for i in range(experts):
            # Experts with no assigned tokens are skipped entirely.
            if (a_list[i].size(0) > 0):
                if a_scale is not None and b_scale is not None:
                    if quant_flag is None:
                        gemm_out = smooth_quant_matmul(a_list[i], a_scale_list[i], b[i], b_scale[i], dtype)
                    else:
                        # Mixed-width path: slice this expert's packed weights by offset.
                        gemm_out = smooth_quant_matmul_w4w8_mixed(a_list[i], a_scale_list[i],
                                                                  b[b_offset_cu[i]:b_offset_cu[i+1]],
                                                                  b_scale[i], dtype, quant_flag = quant_flag[i])
                else:
                    # Unquantized fallback.
                    gemm_out = F.linear(a_list[i], b[i])
                if bias is not None:
                    gemm_out += bias[i]
                if alpha is not None:
                    gemm_out *= alpha[i]
                if beta is not None and c_list != []:
                    gemm_out += c_list[i] * beta[i]
                output_list.append(gemm_out)
        real_res = torch.cat(output_list, dim=0)
        # Pad to the full (possibly gathered) row count; tail rows are garbage.
        output = torch.empty(a.shape[0], n, device=real_res.device).to(real_res.dtype)
        output[:total_m] = real_res
        return output

    def test_smooth_quant_group_gemm(self):
        """Sweep shapes/dtypes/idx modes for the plain int8 grouped GEMM."""
        bs_list = [1, 3]
        seq_list = [5, 8]
        k_list = [512, 1024]
        n_list = [512, 768, 2048]
        expert_list = [8, 32]
        topk_list = [2, 5]
        dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
        # Index mode (op-side gather) is unsupported on MLU370.
        idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True]
        has_bias_list = [True, False]

        args = product(bs_list, seq_list, k_list, n_list, expert_list, topk_list, dtype_list, idx_list, has_bias_list)
        for batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias in args:
            print(f"bs: {batch}, seq_len: {seq}, k: {k}, n: {n}, experts_num: {experts_num}, topk: {topk}, \
                  dtype: {data_type}, idx_mode: {idx_mode}, has_bias: {has_bias} testing...", flush=True)
            param = gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias)
            torch_out = self.op_impl_base(*param)
            tmo_out = ops.smooth_quant_group_gemm(*param)
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)


    def test_sq_group_gemm_quant_group(self):
        """Sweep the group-quantized weight path (4/8-bit, grouped scales)."""
        bs_list = [1, 3]
        seq_list = [5, 8]
        k_list = [512, 1024]
        n_list = [512, 768, 2048]
        expert_list = [8, 32]
        topk_list = [2, 5]
        dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
        idx_list = [False]
        quant_bit_list = [4, 8]
        quant_group_size_list = [128, 256]
        has_bias_list = [True, False]

        args = product(bs_list, seq_list, k_list, n_list, expert_list, topk_list, dtype_list, idx_list, has_bias_list, quant_bit_list, quant_group_size_list)
        for batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias, quant_bit, quant_group_size in args:
            print(f"bs: {batch}, seq_len: {seq}, k: {k}, n: {n}, experts_num: {experts_num}, \
                  topk: {topk}, dtype: {data_type}, idx_mode: {idx_mode}, has_bias:{has_bias}, quant_bit: {quant_bit}, quant_group_size: {quant_group_size} testing...", flush=True)
            quant_group = k // quant_group_size
            param = gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias, quant_bit, quant_group)
            torch_out = self.op_impl_base(*param)
            tmo_out = ops.smooth_quant_group_gemm(*param)
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)

    def test_sq_group_gemm_w4w8_mixed(self):
        """Sweep the mixed 4-bit/8-bit weight path (per-group bit-width flags)."""
        bs_l = [1, 3]
        seq_l = [5, 8]
        k_l = [1024, 2048, 3072]
        n_l = [512, 768, 2048]
        expert_l = [8, 32]
        topk_l = [2, 5]
        dtype_l = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
        has_bias_l = [True, False]
        group_wise_l = [128, 256, 512]

        args = product(bs_l, seq_l, k_l, n_l, expert_l, topk_l, dtype_l, has_bias_l, group_wise_l)
        for bc, seq, k, n, experts, topk, data_type, has_bias, group_wise in args:
            print(f"bs: {bc}, seq_len: {seq}, k: {k}, n: {n}, experts: {experts}, \
                  topk: {topk}, dtype: {data_type}, has_bias: {has_bias}, group_wise: {group_wise}, testing...", flush=True)
            param = gen_mix_w4w8_param(bc, seq, k, n, experts, topk, data_type, has_bias, group_wise)
            torch_out = self.op_impl_base(*param)
            tmo_out = ops.smooth_quant_group_gemm(*param)
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)

    def test_inductor(self):
        """Run opcheck on the torch_mlu_ops.group_gemm custom-op registration."""
        batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5
        dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
        idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True]
        has_bias_list = [True, False]
        args = product( dtype_list, idx_list, has_bias_list)
        for data_type, idx_mode, has_bias in args:
            args = gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias)
            # Rearrange gen_tensor's tuple into the raw custom-op signature:
            # first 9 positional args, then dtype-as-string plus trailing fields.
            new_args = list(args)[:9]
            new_args.extend([args[-2], "half" if data_type == torch.half else "bfloat16", args[-1], None, args[-3]])
            self.base_opcheck(torch.ops.torch_mlu_ops.group_gemm, new_args)
|
||||
|
||||
if __name__ == '__main__':
    # Run the suite through the shared unittest wrapper from common_utils.
    exit(run_unittest(TestSmoothQuantGroupGemmOp))
|
||||
@@ -0,0 +1,101 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
|
||||
# Maps torch dtypes to the string names passed to the quant_matmul custom op.
dtype_dict = {
    torch.half: "half",
    torch.bfloat16: "bfloat16",
}
# Dtypes under test: fp16 always, bf16 only where the device supports it.
dtype_list = [torch.half]
if torch_mlu.mlu.is_bf16_supported():
    dtype_list.append(torch.bfloat16)
|
||||
|
||||
# M<=INT32_MAX, nk<=INT32_MAX, k>=16, n>=16
|
||||
# M<=INT32_MAX, nk<=INT32_MAX, k>=16, n>=16
class TestSmoothQuantMatmulOp(BtTestCase):
    """Tests for ops.smooth_quant_matmul against a dequantize-then-matmul reference."""

    def op_impl_base(self, *args):
        """Torch reference: dequantize a/b with their row scales, then
        matmul + bias, alpha scaling, beta-weighted residual and activation.
        use_hp_active is accepted for signature parity but unused here."""
        a, a_scale, b, b_scale, dtype, bias, c, act_mode, alpha, beta, use_hp_active = args
        if a_scale is not None:
            # Per-row (per-token) dequantization of the activation.
            a = torch.mul(a, a_scale.unsqueeze(-1)).to(dtype)
        if b_scale is not None:
            # Per-row (per-output-channel) dequantization of the weight.
            b = torch.mul(b, b_scale.unsqueeze(-1)).to(dtype)
        # b is stored (n, k), so transpose for (m, k) @ (k, n).
        output = torch.matmul(a, b.permute(1,0))
        if bias is not None:
            output += bias
        output = torch.mul(output, alpha)
        if c is not None:
            residual = torch.mul(c, beta)
            output = torch.add(output, residual)
        if act_mode != "none":
            # act_mode_dict (common_utils) maps mode name -> activation callable.
            act = act_mode_dict[act_mode]
            output = act(output)
        return output

    def test_smooth_quant_matmul(self):
        """Sweep shapes, bias/residual, activation and dtype combinations."""
        m_list = [32, 64, 128]
        n_list = [64, 128, 256]
        k_list = [128, 256, 512]
        has_bias_list = [True, False]
        has_c_list = [True, False]
        act_mode_list = ["none", "silu", "gelu"]
        use_hp_active_list = [True, False]

        args = product(m_list, n_list, k_list, has_bias_list, has_c_list, act_mode_list, dtype_list, use_hp_active_list)
        for m, n, k, has_bias, has_c, act_mode, dtype, use_hp_active in args:
            # Residual add combined with an activation is not exercised.
            if has_c and act_mode != "none":
                continue
            a = torch.randn(m, k, device="mlu", dtype=dtype)
            b = torch.randn(n, k, device="mlu", dtype=dtype)
            bias, c = None, None
            if has_bias:
                bias = torch.randn(n, device="mlu", dtype=dtype)
            if has_c:
                c = torch.randn(m, n, device="mlu", dtype=dtype)

            # Smooth-quant: scale activations up and weights down by the same
            # positive per-channel factor before int8 quantization.
            input_smooth = torch.randn(k, device="mlu", dtype=torch.float).abs() + 0.1
            quant_input, input_scale = QuantByRow(a * input_smooth, 8)
            quant_weight, weight_scale = QuantByRow(b / input_smooth, 8)

            torch_output = self.op_impl_base(quant_input, input_scale, quant_weight, weight_scale, dtype, bias, c,
                                             act_mode, 1.0, 1.0, use_hp_active)
            tmo_output = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, dtype, bias, c,
                                                 act_mode, 1.0, 1.0, use_hp_active)
            self.assertTensorsEqual(tmo_output.cpu().float(), torch_output.cpu().float(), 0.006, use_MSE=True)

    def test_inductor(self):
        """Run opcheck on the torch_mlu_ops.quant_matmul custom-op registration."""
        has_bias_list = [True, False]
        has_c_list = [True, False]
        act_mode_list = ["none", "silu", "gelu"]
        arg = product(has_c_list, has_bias_list, act_mode_list, dtype_list)
        for has_c, has_bias, act_mode, dtype in arg:
            if has_c and act_mode != "none":
                continue
            print(f"===has_c: {has_c}, has_bias: {has_bias}, act_mode: {act_mode}, dtype: {dtype}===")
            M, K, N = 2, 16, 32
            quant_bit_size, use_hp_active, act_coef, alpha, beta, trans_a, trans_b = 8, True, 1., 0.8, 0.3, False, True
            a = torch.randint(0, 10, (M, K), dtype=torch.int8).mlu()
            b = torch.randint(0, 10, (N, K), dtype=torch.int8).mlu()
            c = torch.randn(M, N, device="mlu", dtype=dtype) if has_c else None
            bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None
            a_scale = torch.randn(M, device="mlu", dtype=torch.float)
            b_scale = torch.randn(N, device="mlu", dtype=torch.float)
            # Zero-points and extra scales are unused for smooth_quant.
            a_zero, b_zero, c_zero = None, None, None
            c_scale, gemm_output_scale, gemm_output_zero = None, None, None
            quant_algo, a_quant_layout, b_quant_layout = "smooth_quant", "quantize_per_token", "quantize_per_channel"
            str_dtype = dtype_dict[dtype]

            args = [a, a_scale, a_zero,
                    b, b_scale, b_zero,
                    bias, c, c_scale, c_zero,
                    gemm_output_scale, gemm_output_zero,
                    str_dtype, None, quant_algo,
                    a_quant_layout, b_quant_layout,
                    quant_bit_size, act_mode, use_hp_active, act_coef,
                    alpha, beta, trans_a, trans_b,]
            self.base_opcheck(torch.ops.torch_mlu_ops.quant_matmul, args)
|
||||
|
||||
if __name__ == '__main__':
    # Fixed seeds keep the randomized quantization cases reproducible.
    random.seed(0)
    torch.manual_seed(0)
    exit(run_unittest(TestSmoothQuantMatmulOp))
|
||||
88
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_swap_blocks.py
Executable file
88
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_swap_blocks.py
Executable file
@@ -0,0 +1,88 @@
|
||||
import random
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
import os
|
||||
|
||||
def gen_args(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair):
    """Build (dst, src, src_to_dst) inputs for the swap_blocks tests.

    Creates two random caches of shape
    (num_blocks, num_heads, block_size, head_size) on the devices named by
    `cpy` ("mlu to mlu", "mlu to cpu" or "cpu to mlu") and a random one-to-one
    src->dst block mapping over the first `num_pair` block indices.
    Any other `cpy` value prints an error and exits the process with code 1
    (behavior kept from the original; a ValueError would be more conventional).

    Fixes/cleanup vs. the previous version: the integer and float branches
    shared ~20 duplicated lines differing only in the random generator
    (now factored into `_rand`), and the error message typo "unkown" is
    corrected. Tensor creation order (src before dst) is preserved so RNG
    consumption is unchanged.
    """
    shape = (num_blocks, num_heads, block_size, head_size)
    if dtype in {torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}:
        info = torch.iinfo(dtype)

        def _rand():
            # Near-full-range integer noise (torch.randint's high is exclusive,
            # so info.max itself is never drawn — kept identical to the original).
            return torch.randint(info.min, info.max, size=shape, dtype=dtype)
    else:

        def _rand():
            return torch.randn(size=shape, dtype=dtype)

    # src is created before dst in every branch to keep RNG order stable.
    if cpy == "mlu to mlu":
        src = _rand().mlu()
        dst = _rand().mlu()
    elif cpy == "mlu to cpu":
        src = _rand().mlu()
        dst = _rand().cpu()
    elif cpy == "cpu to mlu":
        src = _rand().cpu()
        dst = _rand().mlu()
    else:
        print("unknown copy direction.", flush=True)
        exit(1)

    # Random permutation over the first num_pair blocks:
    # src block `key` is copied into dst block `values[key]`.
    values = list(range(num_pair))
    random.shuffle(values)
    src_to_dst = {key: value for key, value in zip(range(num_pair), values)}
    return dst, src, src_to_dst
|
||||
|
||||
class TestSwapBlocksOp(BtTestCase):
    """Tests for ops.swap_blocks (KV-cache block copies between devices)."""

    def op_impl_base(self, *args):
        """Reference: for each (key, value) pair, copy src block `key` into
        dst block `value`. Mutates dst in place and returns it."""
        dst, src, block_mapping = args
        for key, value in block_mapping.items():
            dst[value] = src[key]
        return dst

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test swap blocks due to ASan issues")
    def test_swap_blocks(self):
        """Sweep shapes, dtypes, copy directions and mapping sizes; compare
        both src (must be untouched) and dst against the reference."""
        num_blocks_list = [3600]
        num_heads_list = [8]
        head_size_list = [64,128]
        block_size_list = [16]
        num_pairs_list = [6,512]

        types = [torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            types.append(torch.bfloat16)
        cpys = ["mlu to mlu", "mlu to cpu", "cpu to mlu"]
        args = product(num_blocks_list, num_heads_list, block_size_list, head_size_list, types, cpys, num_pairs_list)
        for num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair in args:
            print("num_blocks: {}, num_heads: {}, block_size: {}, head_size: {}, dtype: {}, dir: {}, num_pairs: {} testing..."
                  .format(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair), flush=True)

            dst, src, src_to_dst = gen_args(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair)

            ref_src, ref_dst = src.clone(), dst.clone()
            # cpu
            self.op_impl_base(ref_dst, ref_src, src_to_dst)
            # mlu
            ops.swap_blocks(dst, src, src_to_dst)
            # diff (exact match expected: swap is a pure copy, tolerance 0)
            self.assertTensorsEqual(src.cpu().float(), ref_src.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
            self.assertTensorsEqual(dst.cpu().float(), ref_dst.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)

    # NOTE(review): the skip reason says "test_prevent" but this is
    # test_inductor — likely copied from another file; message text unchanged.
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_inductor(self):
        """Run opcheck on the torch_mlu_ops.swap_blocks custom-op registration."""
        num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair = 3600, 8, 16, 64, torch.half, "mlu to mlu", 512
        dst, src, block_mapping = gen_args(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair)
        self.base_opcheck(torch.ops.torch_mlu_ops.swap_blocks, (dst, src, block_mapping))
|
||||
|
||||
|
||||
# Run the swap_blocks suite and propagate its exit status to the shell.
if __name__ == '__main__':
    exit(run_unittest(TestSwapBlocksOp))
|
||||
150
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_update_out_and_lse.py
Executable file
150
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_update_out_and_lse.py
Executable file
@@ -0,0 +1,150 @@
|
||||
import torch
|
||||
from torch_mlu import mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
from torch.nn import functional as F
|
||||
import time
|
||||
|
||||
def gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack):
    """Build the inputs for update_out_and_lse tests.

    Returns (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs).

    pack=False: padded layout — out is [batch, max_seq_len, head_num, head_size],
    cu_seqs and block_cu_seqs are None.
    pack=True: varlen layout — out is [sum(seq_lens), head_num, head_size] with
    int32 cumulative-length prefix tensors for both full and block sequences.

    NOTE: the misspelled name (`tenosr`) is kept because callers use it.
    """
    if not pack:
        out = torch.randn(batch, max_seq_len, head_num, head_size, device="mlu", dtype=dtype)
        lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
        block_out = torch.randn(batch, block_seq_len, head_num, head_size, device="mlu", dtype=dtype)
        block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
        seq_offset = torch.randint(low=0, high=(max_seq_len - block_seq_len + 1), size=(batch, ), dtype=torch.int32, device="mlu")
        return (out, lse, block_out, block_lse, seq_offset, None, None)

    # Per-batch real lengths; each block length is clamped to its sequence length.
    seq_lens = torch.randint(low=1, high=(max_seq_len + 1), size=(batch, ), dtype=torch.int32)
    block_seq_lens = torch.minimum(
        seq_lens, torch.randint(low=1, high=(block_seq_len + 1), size=(batch, ), dtype=torch.int32))

    # Random valid start offset for each batch's block inside its sequence.
    offsets = [torch.randint(low=0, high=int(seq_lens[i] - block_seq_lens[i]) + 1, size=(1,), dtype=torch.int32)
               for i in range(batch)]
    seq_offset = torch.cat(offsets).mlu()

    zero = torch.tensor([0])
    cu_seqs = torch.cat((zero, torch.cumsum(seq_lens, dim=0))).to(torch.int32).mlu()
    block_cu_seqs = torch.cat((zero, torch.cumsum(block_seq_lens, dim=0))).to(torch.int32).mlu()
    total_seqs = int(torch.sum(seq_lens))
    block_total_seqs = int(torch.sum(block_seq_lens))

    out = torch.randn(total_seqs, head_num, head_size, device="mlu", dtype=dtype)
    lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
    block_out = torch.randn(block_total_seqs, head_num, head_size, device="mlu", dtype=dtype)
    block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
    return (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
|
||||
|
||||
class TestUpdateOutAndLse(BtTestCase):
    """Tests for ops.update_out_and_lse against the torch reference
    update_out_and_lse_torch (imported from common_utils)."""

    def op_impl_base(self, *args):
        """Pure-torch reference; returns the expected (out, lse) pair."""
        out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs = args
        return update_out_and_lse_torch(out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs)

    def test_update_out_and_lse(self):
        """Fuzz random shapes/dtypes/layouts and compare op vs reference.

        Fixes vs the previous version:
        * bfloat16 was appended to dtype_choice on EVERY loop iteration,
          so random.choice grew increasingly biased toward bfloat16; it is
          now added once, before the loop.
        * random.randint bounds are inclusive, so the old `randint(1, 16 + 1)`
          could produce 17 (and 513 / 2049 for the other dims); the spurious
          `+ 1` offsets are removed.
        * the log prefix was missing its opening bracket.
        """
        test_case_num = 10
        dtype_choice = [torch.float16, torch.float32]
        bf16_supported = torch_mlu.mlu.is_bf16_supported()
        if bf16_supported:
            dtype_choice.append(torch.bfloat16)
        random.seed(time.time())

        count = test_case_num
        while count > 0:
            batch = random.randint(1, 16)
            head_num = random.randint(1, 16)
            head_size = random.randint(1, 512)
            block_seq_len = random.choice([1, random.randint(2, 2048)])
            max_seq_len = 1 if block_seq_len == 1 else max(random.randint(2, 2048), block_seq_len)
            pack = random.choice([True, False])

            # Skip configurations that would exhaust MLU device memory.
            elems = batch * head_num * head_size * max_seq_len
            if elems > 10 * 1024 * 1024 * 1024:
                continue
            # Non-bf16 devices get a tighter memory budget.
            if not bf16_supported and elems > 1 * 1024 * 1024 * 1024:
                continue
            dtype = random.choice(dtype_choice)

            print(f"[test_update_out_and_lse] {batch}, {head_num}, {head_size}, {block_seq_len}, {max_seq_len}, {dtype}, {pack}")

            out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
                gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)

            out_torch, lse_torch = self.op_impl_base(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
            # The device op updates out/lse in place.
            ops.update_out_and_lse(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
            self.assertTensorsEqual(lse.cpu().float(), lse_torch.cpu().float(), 0.005, use_MSE=True)
            self.assertTensorsEqual(out.cpu().float(), out_torch.cpu().float(), 0.005, use_MSE=True)
            count -= 1
        print(f"[test_update_out_and_lse] {test_case_num} cases test pass")

    def test_combine_ring_attn(self):
        """Ring-attention style combine: one long packed sequence per batch."""
        batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack =\
            1, 8, 128, 8192, 8192, torch.bfloat16, True

        if not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.float16

        print(f"[test_combine_ring_attn] {batch}, {head_num}, {head_size}, {block_seq_len}, {max_seq_len}, {dtype}, {pack}")

        out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
            gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
        # total_seqs / block_total_seqs are the sums of the real sequence
        # lengths over all batches.  Shapes for this configuration:
        # out: [sum(out_seqs), 64, 128] block_out [sum(block_out_seqs), 64, 128]
        # lse: [1, 8, 8192] block_lse [1, 8, 8192]
        # seq_offset: [128]
        # cu_seqs block_cu_seqs: [128 + 1]

        out_torch, lse_torch = self.op_impl_base(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)

        ops.update_out_and_lse(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)

        self.assertTensorsEqual(lse.cpu().float(), lse_torch.cpu().float(), 0.005, use_MSE=True)
        self.assertTensorsEqual(out.cpu().float(), out_torch.cpu().float(), 0.005, use_MSE=True)
        print("[test_combine_ring_attn] pass")

    def test_combine_decoder_attn(self):
        """Decoder-attention combine: seq_len == 1, padded (non-packed) layout."""
        batch_list = [16, 128]
        head_num_list = [64]
        head_size_list = [128]
        block_seq_len_list = [1]
        max_seq_len_list = [1]
        dtype_list = [torch.float16]
        pack_list = [False]

        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)

        args = product(batch_list, head_num_list, head_size_list, block_seq_len_list, max_seq_len_list,
                       dtype_list, pack_list)

        for batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack in args:
            print(f"[test_combine_decoder_attn] {batch}, {head_num}, {head_size}, {block_seq_len}, {max_seq_len}, {dtype}, {pack}")

            out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
                gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
            # out: [128, 1, 64, 128] block_out [128, 1, 64, 128]
            # lse: [128, 64, 1] block_lse [128, 1]
            # seq_offset cu_seqs block_cu_seqs: None

            out_torch, lse_torch = self.op_impl_base(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)

            ops.update_out_and_lse(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)

            self.assertTensorsEqual(lse.cpu().float(), lse_torch.cpu().float(), 0.005, use_MSE=True)
            self.assertTensorsEqual(out.cpu().float(), out_torch.cpu().float(), 0.005, use_MSE=True)
        print("[test_combine_decoder_attn] pass")

    def test_inductor(self):
        """Opcheck for the update_out_and_lse custom op, packed and padded."""
        pack_list = [True, False]
        for pack in pack_list:
            batch, head_num, head_size, dtype = 16, 8, 128, torch.float16
            if pack:
                block_seq_len, max_seq_len = 1024, 2048
            else:
                block_seq_len, max_seq_len = 1, 1
            out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
                gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
            self.base_opcheck(torch.ops.torch_mlu_ops.update_out_and_lse, (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs))
        print("[test_update_out_and_lse] test_inductor check pass")
|
||||
|
||||
# Run the update_out_and_lse suite and propagate its exit status to the shell.
if __name__ == '__main__':
    exit(run_unittest(TestUpdateOutAndLse))
|
||||
113
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_weight_only_quant_matmul.py
Executable file
113
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_weight_only_quant_matmul.py
Executable file
@@ -0,0 +1,113 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
from itertools import product
|
||||
|
||||
# Maps torch dtypes to the string names the quant_matmul custom op expects.
dtype_dict = {
    torch.half: "half",
    torch.bfloat16: "bfloat16",
}
# Dtypes to sweep in the tests below; bfloat16 only where the device supports it.
dtype_list = [torch.half]
if torch_mlu.mlu.is_bf16_supported():
    dtype_list.append(torch.bfloat16)
|
||||
|
||||
# M<=INT32_MAX, nk<=INT32_MAX, k>=16, n>=16
|
||||
class TestWeightOnlyQuantMatmulOp(BtTestCase):
    """Tests for ops.weight_only_quant_matmul against a dequantize+matmul
    torch reference."""

    def op_impl_base(self, *args):
        """Reference: dequantize b (int8, or int4 unpacked via UnpackInt4),
        matmul with a, then apply bias, alpha, beta*c residual and the
        activation, in that order.

        NOTE(review): zero and use_hp_active are accepted but unused here;
        only the device op interprets them.
        """
        a, b, scale, zero, bias, c, act_mode, quant_bit_size, alpha, beta, use_hp_active = args
        if quant_bit_size == 4:
            n = b.shape[0]
            b = UnpackInt4(b).view(n, -1)
        if scale is not None:
            if scale.dim() == 2:
                # Group-wise scales: broadcast each group's scale across its columns.
                group_size = b.size(1) // scale.size(1)
                scale_bd = scale.unsqueeze(-1).repeat(1, 1, group_size).reshape(b.shape)
            else:
                # Per-channel scale: one scale per output row of b.
                scale_bd = scale.unsqueeze(-1)
            b = torch.mul(b, scale_bd).to(a.dtype)
        output = torch.matmul(a, b.permute(1, 0))
        if bias is not None:
            output += bias
        output = torch.mul(output, alpha)
        if c is not None:
            residual = torch.mul(c, beta)
            output = torch.add(output, residual)
        if act_mode != "none":
            act = act_mode_dict[act_mode]
            output = act(output)
        return output

    def test_weight_only_quant_matmul(self):
        """Sweep shapes / bit widths / epilogues and compare op vs reference.

        Fixes vs the previous version: the per-case argument tuple no longer
        rebinds the name `args` (which shadowed the live case iterator), and
        unsupported combinations are skipped before any tensors are allocated
        or quantized.
        """
        m_list = [32, 64, 128]
        n_list = [128, 256, 512]
        group_list = [1, 2]
        k_list = [128, 256, 512]
        quant_bit_list = [8, 4]
        has_bias_list = [True, False]
        has_c_list = [True, False]
        act_mode_list = ["none", "silu", "gelu"]
        use_hp_active_list = [True, False]

        cases = product(m_list, n_list, k_list, group_list, quant_bit_list, has_bias_list, has_c_list, act_mode_list, dtype_list, use_hp_active_list)
        for m, n, k, group, quant_bit, has_bias, has_c, act_mode, dtype, use_hp_active in cases:
            # Unsupported combinations — skip before allocating anything.
            if has_c and act_mode != "none":
                continue
            if group != 1 and act_mode != "none":
                continue
            a = torch.randn(m, k, device="mlu", dtype=dtype)
            b = torch.randn(n, k, device="mlu", dtype=dtype)
            bias, c = None, None
            zero = None  # zero-point input is always None in this test
            if has_bias:
                bias = torch.randn(n, device="mlu", dtype=dtype)
            if has_c:
                c = torch.randn(m, n, device="mlu", dtype=dtype)
            quant_weight_int8, weight_scale = QuantByRow(b, quant_bit, group)
            if group != 1:
                weight_scale = weight_scale.to(a.dtype)
            quant_weight = PairlyPackInt8(quant_weight_int8) if quant_bit == 4 else quant_weight_int8

            call_args = (a, quant_weight, weight_scale, zero, bias, c,
                         act_mode, quant_bit, 1.0, 1.0, use_hp_active)
            torch_output = self.op_impl_base(*call_args)
            tmo_output = ops.weight_only_quant_matmul(*call_args)
            self.assertTensorsEqual(tmo_output.cpu().float(), torch_output.cpu().float(), 0.004, use_MSE=True)

    def test_inductor(self):
        """Opcheck for the underlying quant_matmul custom op (weight-only path),
        with and without a residual, per-channel and group-wise scales."""
        M, K, N, group_num = 2, 256, 32, 4
        quant_bit_size, act_mode, use_hp_active, act_coef, alpha, beta, trans_a, trans_b = 8, 'none', True, 1., 0.8, 0.3, False, True
        has_res_list = [True, False]
        group_quant_list = [True, False]
        cases = product(has_res_list, group_quant_list, dtype_list)
        for has_res, group_quant, dtype in cases:
            print(f"==has_res: {has_res}, group_quant: {group_quant}, dtype: {dtype}==")
            a = torch.randn((M, K), dtype=dtype).mlu()
            b = torch.randint(-128, 127, (N, K), dtype=torch.int8).mlu()
            c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None
            bias = torch.randn(N, device="mlu", dtype=dtype)
            group_wise_scale = torch.randn((N, group_num), device="mlu", dtype=dtype)
            b_quant_layout = "quantize_group_wise" if group_quant else "quantize_per_channel"
            b_scale = group_wise_scale if group_quant else None
            gemm_output_scale = None if group_quant else torch.randn(N, device="mlu", dtype=torch.float)
            a_scale, a_zero, b_zero, c_zero = None, None, None, None
            c_scale, gemm_output_zero = None, None
            quant_algo, a_quant_layout = "weight_only", "quantize_none"
            dtype_str = dtype_dict[dtype]

            call_args = [a, a_scale, a_zero,
                         b, b_scale, b_zero,
                         bias, c, c_scale, c_zero,
                         gemm_output_scale, gemm_output_zero,
                         dtype_str, None, quant_algo,
                         a_quant_layout, b_quant_layout,
                         quant_bit_size, act_mode, use_hp_active, act_coef,
                         alpha, beta, trans_a, trans_b]
            self.base_opcheck(torch.ops.torch_mlu_ops.quant_matmul, call_args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Seed both RNGs so the quantized-matmul sweep is reproducible.
    random.seed(0)
    torch.manual_seed(0)
    exit(run_unittest(TestWeightOnlyQuantMatmulOp))
|
||||
Reference in New Issue
Block a user