Files
2026-02-04 17:39:32 +08:00

735 lines
32 KiB
Python

import sys
sys_args = sys.argv
sys.argv = [sys_args.pop(0)] # prevent unittest printing help info
import os
import torch
import torch_mlu
from torch.testing._internal.common_utils import TestCase
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from typing import List, Tuple, Optional
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group as get_default_group, all_reduce, ReduceOp
import torch.testing._internal.optests as optests
import random
import argparse
from abc import abstractmethod, ABC
import unittest
import torch_mlu_ops as tmo
import os
# Map activation-mode names to their torch.nn.functional implementations.
act_mode_dict = {
    name: getattr(torch.nn.functional, name)
    for name in ("relu", "gelu", "silu")
}
class BtTestCase(TestCase, ABC):
    """Base test case for torch_mlu_ops tests.

    Disables the TF32 matmul override so reference results are computed in
    full precision, and provides shared assertion helpers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Force full-precision matmul so reference comparisons are stable.
        os.environ['TORCH_ALLOW_TF32_CNMATMUL_OVERRIDE'] = '0'

    @abstractmethod
    def op_impl_base(self, *args):
        """Reference (non-fused) implementation of the op under test."""
        pass

    @abstractmethod
    def test_inductor(self):
        """Inductor-path test; must be provided by concrete subclasses."""
        pass

    @staticmethod
    def _torch_version_at_least(major, minor):
        """True when torch.__version__ is >= (major, minor).

        Fix: the previous plain string comparison wrongly reports e.g.
        '2.10.0' < '2.3.0'.
        """
        parts = torch.__version__.split('.')
        try:
            return (int(parts[0]), int(parts[1])) >= (major, minor)
        except (ValueError, IndexError):
            # Unparsable version string: conservatively treat as too old.
            return False

    def base_opcheck(self, interface_overload, args):
        """Run optests.opcheck on `interface_overload` with `args`.

        Always checks schema and autograd registration; the fake-tensor
        check is added on torch >= 2.3 where it is available.
        """
        target_check = ["test_schema", "test_autograd_registration"]
        if self._torch_version_at_least(2, 3):
            target_check.append("test_faketensor")
        target_status = {key: "SUCCESS" for key in target_check}
        result = optests.opcheck(interface_overload, args, test_utils=target_check)
        self.assertEqual(result, target_status)

    def assertException(self, error_msg, func, *args, **kwinputs):
        """Assert that func(*args, **kwinputs) raises an exception.

        When `error_msg` is non-empty, the raised exception's message must
        match it exactly.

        Fix: the old implementation's `assertTrue(False)` (for the no-raise
        case) was itself swallowed by the `except` block, so a function that
        raised nothing could still pass; now the test fails explicitly.
        """
        try:
            func(*args, **kwinputs)
        except Exception as e:
            if error_msg:
                self.assertEqual(error_msg, str(e))
            return
        self.fail("expected an exception but none was raised")

    def assertTensorsEqual(self,
                           a,
                           b,
                           prec=None,
                           message='',
                           allow_inf=False,
                           use_MSE=False,
                           use_RAE=False,
                           use_RMA=False):
        """Compare tensors `a` and `b` within tolerance `prec`.

        NaNs (and infs when allow_inf) must appear at identical locations and
        are zeroed before the diff is computed. Error metric:
          - use_MSE: relative root-mean-square error w.r.t. `a`
          - use_RAE: relative absolute error
          - use_RMA: relative difference of mean absolute values
          - default: maximum elementwise difference
        """
        if a.dtype == torch.bool:
            a = a.float()
        if b.dtype == torch.bool:
            b = b.float()
        epsilon = 1.0 / 16384
        self.assertEqual(a.size(), b.size(), message)
        assert (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)), "a and b are need be torch tensor."
        if a.numel() > 0:
            # check that NaNs are in the same locations
            nan_mask = a != a
            self.assertTrue(torch.equal(nan_mask, b != b), message)
            diff = a - b
            diff[nan_mask] = 0
            a = a.clone()
            b = b.clone()
            a[nan_mask] = 0
            b[nan_mask] = 0
            # inf check if allow_inf=True
            if allow_inf:
                inf_mask = (a == float("inf")) | (a == float("-inf"))
                self.assertTrue(torch.equal(inf_mask,
                                            (b == float("inf")) | (b == float("-inf"))),
                                message)
                diff[inf_mask] = 0
                a[inf_mask] = 0
                b[inf_mask] = 0
            # TODO: implement abs on CharTensor
            if diff.is_signed() and 'CharTensor' not in diff.type():
                diff = diff.abs()
            if use_MSE:
                diff = diff.abs().pow(2).sum()
                a_pow_sum = a.pow(2).sum()
                if diff <= (2 * epsilon) * (2 * epsilon):
                    diff = 0.0
                if a_pow_sum <= epsilon:
                    a_pow_sum = a_pow_sum + epsilon
                diff = torch.div(diff, (a_pow_sum * 1.0))
                self.assertLessEqual(diff.sqrt(), prec, message)
            elif use_RAE:
                diff = diff.abs().sum()
                a_sum = a.abs().sum()
                if a_sum == 0:
                    self.assertEqual(a, b, message)
                else:
                    diff = torch.div(diff, a_sum)
                    self.assertLessEqual(diff, prec, message)
            elif use_RMA:
                a_mean = a.abs().mean()
                b_mean = b.abs().mean()
                if a_mean == 0:
                    self.assertEqual(a, b, message)
                else:
                    diff = torch.div((a_mean - b_mean).abs(), a_mean)
                    self.assertLessEqual(diff, prec, message)
            else:
                max_err = diff.max()
                self.assertLessEqual(max_err, prec, message)
def run_unittest(case) -> int:
    """Run `case` with unittest, honoring an optional `-k` name filter.

    Returns 0 when every test passed and 1 otherwise (shell exit-code style).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-k', nargs='+', type=str, default="", help='specify case to run')
    parsed = parser.parse_args(sys_args)
    loader = unittest.TestLoader()
    if parsed.k != "":
        suite = loader.loadTestsFromNames(parsed.k, case)
    else:
        suite = loader.loadTestsFromTestCase(case)
    outcome = unittest.TextTestRunner().run(suite)
    return not outcome.wasSuccessful()
class TMOTimer:
    """Context manager timing MLU work via hardware events.

    After the with-block exits, `average_hardware_time` holds the measured
    hardware time divided by `repeat`.
    """

    def __init__(self, repeat: int = 1):
        # Number of iterations executed inside the with-block; used to average.
        self.repeat = repeat

    def __enter__(self):
        self.notify_start = torch.mlu.Event(enable_timing=True)
        self.notify_end = torch.mlu.Event(enable_timing=True)
        self.notify_start.record()
        # Fix: return self so `with TMOTimer(n) as t:` lets callers read
        # t.average_hardware_time (previously __enter__ returned None).
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.notify_end.record()
        self.notify_end.synchronize()
        total_hardware_time = self.notify_start.hardware_time(self.notify_end)
        self.average_hardware_time = total_hardware_time / self.repeat
def QuantByRow(input: torch.Tensor, quant_bit: int, group_num: int=1):
    """Symmetric per-row (optionally per-group) quantization.

    Returns (quantized int8 tensor with the original shape, scales) where
    scale = row_abs_max / int_max, shaped per row (or per row x group).
    """
    original_shape = input.shape
    if input.dim() > 2:
        input = input.view(-1, original_shape[-1])
    if input.dim() == 1:
        input = input.unsqueeze(0)
    assert input.dim() == 2, "input must be 2-D tensor."
    assert quant_bit in (4, 8), "quant_bit must be 4 or 8."
    assert group_num >= 1, "group_num >= 1."
    int_max = float(2 ** (quant_bit - 1) - 1)
    int_min = -float(2 ** (quant_bit - 1))
    group_size = input.size(-1) // group_num
    # Split each row into `group_num` contiguous groups when requested.
    grouped = input.view(input.size(0), group_num, group_size) if group_num > 1 else input
    abs_max, _ = grouped.abs().max(dim=-1, keepdim=True)
    scale = abs_max.to(torch.float) / int_max
    quantized = (grouped / scale).round().clamp(int_min, int_max).to(torch.int8).view(input.size())
    return quantized.view(original_shape), scale.squeeze(-1)
def QuantByTensor(input: torch.Tensor, quant_bit: int):
    """Symmetric whole-tensor quantization.

    Returns (int8 tensor, scale) where scale = int_max / abs_max, i.e. the
    multiplier applied to `input` before rounding (note: this is the inverse
    of the per-row convention used by QuantByRow).
    """
    int_max = float(2 ** (quant_bit - 1) - 1)
    int_min = -float(2 ** (quant_bit - 1))
    abs_max = input.abs().max()
    input_scale = int_max / abs_max
    input_int = (input * input_scale).round().clamp(int_min, int_max).to(torch.int8)
    return input_int, input_scale
def PairlyPackInt8(input):
    """Pack consecutive int8 pairs into single bytes.

    The even-indexed element of each pair becomes the low nibble, the
    odd-indexed one the high nibble; the last dimension shrinks by half.
    """
    assert input.dtype == torch.int8, "dtype of input must be int8."
    assert input.dim() == 2 or input.dim() == 3, "input must be 2-D or 3-D tensor."
    assert input.size(-1) % 2 == 0, "size(-1) of input must be even."
    packed_shape = list(input.shape)
    packed_shape[-1] //= 2
    flat = input.flatten().to(torch.uint8)
    low = flat[0::2] & 0x0F
    high = flat[1::2] << 4
    return (high + low).to(torch.int8).reshape(packed_shape)
def UnpackInt4(input):
    """Unpack each int8 byte into two signed 4-bit values.

    Even output slots receive the (sign-extended) low nibble, odd slots the
    high nibble; the flat output is twice as long as the flattened input.
    """
    assert input.dtype == torch.int8, "dtype of input must be int8."
    flat = input.flatten()
    out = torch.zeros(flat.size(0) * 2, dtype=torch.int8, device=input.device)
    # Arithmetic shifts on int8 sign-extend, recovering signed 4-bit values.
    out[0::2] = (flat << 4) >> 4
    out[1::2] = flat >> 4
    return out
def smooth_quant_matmul(a, a_scale, b, b_scale, out_dtype, bias=None):
    """Dequantizing matmul: out[m, n] = sum over groups of (a @ b.T) * scales.

    a: (m, k) integer activations with per-row scales a_scale (m,).
    b: (n, k) integer weights — or int4-packed (n, k/2) — with scales
       b_scale shaped (n,) or (quant_group, n).
    """
    assert a.dim() == 2 and b.dim() == 2, "a.dim() == 2 and b.dim() == 2"
    assert a_scale.dim() == 1, "a_scale.dim() == 1"
    assert a.size(0) == a_scale.size(0), "a.size(0) == a_scale.size(0)"
    assert b.size(0) == b_scale.size(-1), "b.size(0) == b_scale.size(-1)"
    m, a_k = a.shape
    n, b_k = b.shape
    if b_scale.dim() == 1:
        b_scale = b_scale.unsqueeze(0)
    quant_group = b_scale.size(0)
    a = a.view(m, quant_group, -1).transpose(0, 1).contiguous()
    if a_k == b_k * 2:
        # Weights arrive int4-packed two-per-byte; expand back to int8 first.
        b = UnpackInt4(b)
    b = b.view(n, quant_group, -1).transpose(0, 1).contiguous()
    out = torch.zeros(m, n, dtype=torch.float, device=a.device)
    for g in range(quant_group):
        # Outer product of scales: (m, 1) x (1, n) -> (m, n)
        scale_mn = torch.matmul(a_scale.unsqueeze(1), b_scale[g].unsqueeze(0))
        out += torch.einsum('mk,nk->mn', a[g].to(torch.float), b[g].to(torch.float)) * scale_mn
    out = out.to(out_dtype)
    if bias is not None:
        out += bias
    return out
def smooth_quant_matmul_w4w8_mixed(a, a_scale, b, b_scale, out_dtype, bias=None, quant_flag=None):
    """Mixed w4/w8 smooth-quant matmul.

    quant_flag[i] gives the storage width (4 or 8 bits) of weight group i;
    int4 groups are unpacked to int8 before the per-group matmuls. The
    k dimension of `a` is split into b_scale.shape[0] equal groups.
    """
    m, k = a.shape[0], a.shape[1]
    quant_group, n = b_scale.shape[0], b_scale.shape[1]
    group_wise = k // quant_group
    b = b.view(n, -1)
    a = a.view(m, quant_group, -1).transpose(0, 1).contiguous()
    expanded_groups = []
    cursor = 0
    for g in range(quant_group):
        # int4 groups occupy half the bytes of int8 groups.
        width = group_wise // 2 if quant_flag[g] == 4 else group_wise
        chunk = b[:, cursor:cursor + width]
        cursor += width
        if quant_flag[g] == 4:
            chunk = UnpackInt4(chunk).view(n, -1)
        expanded_groups.append(chunk)
    b = torch.cat(expanded_groups, 1).view(n, quant_group, -1).transpose(0, 1).contiguous()
    out = torch.zeros(m, n, dtype=torch.float, device=a.device)
    for g in range(quant_group):
        out += smooth_quant_matmul(a[g], a_scale, b[g], b_scale[g], out_dtype)
    out = out.to(out_dtype)
    if bias is not None:
        out += bias
    return out
def weight_only_quant_matmul(a, b, scale, bias=None):
    """Weight-only-quant matmul: out = a @ (b * scale).T (+ bias).

    scale is per-output-channel (n,) or per-group (n, groups); group scales
    are broadcast across contiguous column blocks of b.
    """
    assert a.dim() == 2 and b.dim() == 2, "a.dim() == 2 and b.dim() == 2"
    assert scale.dim() == 1 or scale.dim() == 2, "scale.dim() == 1 or scale.dim() == 2"
    assert b.size(0) == scale.size(0), "b.size(0) == b_scale.size(0)"
    assert a.size(1) == b.size(1), "a.size(1) == b.size(1)"
    if scale.dim() == 2:
        # Expand each group scale across its group_size columns of b.
        group_size = b.size(1) // scale.size(1)
        expanded_scale = scale.unsqueeze(-1).repeat(1, 1, group_size).reshape(b.shape)
    else:
        expanded_scale = scale.unsqueeze(-1)
    dequant_b = b * expanded_scale
    out = torch.einsum('mk,nk->mn', a.to(torch.float), dequant_b.to(torch.float)).to(a.dtype)
    if bias is not None:
        out += bias
    return out
def single_query_cached_kv_attn(q, k_cache, v_cache, block_tables, context_lens, k_cache_quant_scale,
                                v_cache_quant_scale, alibi_slopes, window_size_left, window_size_right, softmax_scale, return_lse):
    """Reference (float32, un-fused) paged-KV-cache attention.

    Shapes, as established by the unpacking below:
      q:       (bs, seq_q, num_heads, head_size)
      k_cache: (num_blocks, num_kv_heads, block_size, head_size)
      v_cache: like k_cache but with last dim head_size_v
      block_tables[bs_id]: cache block ids making up that sequence
      context_lens: per-batch valid KV length

    Optional quant scales (2-D per-channel or 3-D per-token) dequantize the
    caches in place before attention. Returns float16 output of shape
    (bs, seq_q, num_heads, head_size_v), plus float32 lse (bs, num_heads,
    seq_q) when return_lse is true.
    """
    q = q.float()
    k_cache = k_cache.float()
    v_cache = v_cache.float()

    def masked_attention(query, key, value, alibi_slope, context_len, window_size_left, window_size_right, qk_scale) -> torch.Tensor:
        """Single-sequence attention with optional ALiBi, causal and window masks."""
        # (num_heads, seq_q, seq_k)
        qk = torch.einsum('qhd,hkd->hqk', query, key)
        qk = qk * qk_scale
        if alibi_slope is not None:
            # ALiBi: add a per-head linear position bias over the key axis.
            alibi_dist = torch.arange(0, context_len, dtype=torch.float32).mlu()
            alibi = alibi_slope[:, None] * alibi_dist
            qk = qk + alibi[:, None, :]
        _, seq_q, seq_k = qk.size()
        if seq_q > 1:  # causal mask
            # Queries align with the END of the key sequence: the first
            # seq_k - seq_q keys are fully visible, the tail is triangular.
            ml = torch.zeros((seq_q, seq_k - seq_q), dtype=qk.dtype).mlu()
            ones = torch.ones((seq_q, seq_q), dtype=qk.dtype).mlu() * -torch.inf
            mr = torch.triu(ones, diagonal=1)
            mask = torch.cat((ml, mr), dim=-1)
            qk = qk + mask
        if window_size_left != -1 or window_size_right != -1:
            # Sliding window mask; -1 means unbounded on that side.
            mask_w = torch.full((seq_q, seq_k), -torch.inf, dtype=torch.float, device="mlu")
            for qi in range(seq_q):
                left = max(seq_k - seq_q + qi - window_size_left, 0) if window_size_left != -1 else 0
                right = min(max(seq_k - seq_q + qi + window_size_right + 1, 0), seq_k) if window_size_right != -1 else seq_k
                mask_w[qi, left:right] = 0
            qk += mask_w
        attention = torch.softmax(qk, dim=-1, dtype=qk.dtype)
        qkv = torch.einsum('hqk,hkd->qhd', attention, value)
        # Also return pre-softmax scores so the caller can compute the LSE.
        return qkv, qk

    if k_cache_quant_scale is not None and v_cache_quant_scale is not None:
        # Reshape scales so they broadcast against the 4-D caches.
        if k_cache_quant_scale.dim() == 2:  # per_channel: [kv_head_num, head_size]
            k_cache_quant_scale = k_cache_quant_scale.reshape(1, k_cache_quant_scale.shape[0], 1, k_cache_quant_scale.shape[1])
            v_cache_quant_scale = v_cache_quant_scale.reshape(1, v_cache_quant_scale.shape[0], 1, v_cache_quant_scale.shape[1])
        elif k_cache_quant_scale.dim() == 3:  # per_token: [num_blocks, k_head_num, block_size]
            k_cache_quant_scale = k_cache_quant_scale.reshape(*k_cache_quant_scale.shape, 1)
            v_cache_quant_scale = v_cache_quant_scale.reshape(*v_cache_quant_scale.shape, 1)
        k_cache *= k_cache_quant_scale
        v_cache *= v_cache_quant_scale
    bs, seq_q, num_heads, head_size = q.size()
    head_size_v = v_cache.size(-1)
    num_blocks, num_kv_heads, block_size, _ = k_cache.size()
    output = torch.zeros((bs, seq_q, num_heads, head_size_v), dtype=torch.float16)
    lse = torch.zeros((bs, num_heads, seq_q), dtype=torch.float)
    assert (num_heads % num_kv_heads == 0)
    # Grouped-query attention: each KV head serves num_heads/num_kv_heads query heads.
    head_repeats = num_heads // num_kv_heads
    for bs_id in range(bs):
        q_bs = q[bs_id]
        context_len = int(context_lens[bs_id])
        if context_len == 0:
            # Empty context: zero output and LSE = log(sum of nothing) = -inf.
            output[bs_id] = torch.zeros((seq_q, num_heads, head_size_v), device=q.device, dtype=output.dtype)
            lse[bs_id] = lse[bs_id].fill_(-float('inf'))
        else:
            # Gather this sequence's cache blocks, then trim to context_len.
            block_table = block_tables[bs_id]
            table_end = (context_len + block_size - 1) // block_size
            block_ids = block_table[0:table_end]
            keys, values = k_cache[block_ids], v_cache[block_ids]
            keys = torch.repeat_interleave(keys, head_repeats, dim=1)
            keys = keys.transpose(1, 0).contiguous().view(num_heads, -1, head_size)
            keys = keys[:, 0:context_len, :]
            values = torch.repeat_interleave(values, head_repeats, dim=1)
            values = values.transpose(1, 0).contiguous().view(num_heads, -1, head_size_v)
            values = values[:, 0:context_len, :]
            alibi_slope = alibi_slopes[bs_id] if alibi_slopes is not None else None
            qkv, qk = masked_attention(q_bs, keys, values, alibi_slope, context_len, window_size_left, window_size_right, softmax_scale)
            output[bs_id] = qkv
            lse[bs_id] = torch.logsumexp(qk, dim=-1)
    return (output, lse) if return_lse else output
def update_out_and_lse_torch(out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs):
    """Merge a partial-attention block (block_out/block_lse) into (out, lse).

    Numerically stable log-sum-exp merge:
        new_out = out - sigmoid(block_lse - lse) * (out - block_out)
        new_lse = lse - logsigmoid(lse - block_lse)
    Packed layout (out.dim() == 3) addresses sequences via cu_seqs /
    block_cu_seqs; padded layout addresses per-batch rows with seq_offsets.
    Returns (new_out, new_lse) with lse back in (batch, heads, seq) layout.
    """
    packed = out.dim() == 3
    new_out, new_lse = out.clone(), lse.clone()
    batch = lse.shape[0]
    block_seq_len = block_lse.shape[-1]
    # (b, h, s) -> (b, s, h, 1) so lse broadcasts against (s, h, d) slices.
    lse_bsh = lse.transpose(-2, -1).unsqueeze(dim=-1)
    # NOTE: this transpose is a view of new_lse, so writes below update new_lse.
    new_lse_bsh = new_lse.transpose(-2, -1).unsqueeze(dim=-1)
    block_lse_bsh = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    if not packed:
        for b in range(batch):
            off = 0 if seq_offsets is None else seq_offsets[b]
            cur_out = out[b, off:off + block_seq_len]
            cur_lse = lse_bsh[b, off:off + block_seq_len]
            blk_out = block_out[b, :]
            blk_lse = block_lse_bsh[b, :]
            new_out[b, off:off + block_seq_len] = cur_out - F.sigmoid(blk_lse - cur_lse) * (cur_out - blk_out)
            new_lse_bsh[b, off:off + block_seq_len] = (cur_lse - F.logsigmoid(cur_lse - blk_lse))
    else:
        for b in range(batch):
            blk_begin = block_cu_seqs[b]
            blk_end = block_cu_seqs[b + 1]
            blk_len = blk_end - blk_begin
            off = seq_offsets[b]
            out_begin = cu_seqs[b] + off
            blk_out = block_out[blk_begin:blk_end]
            blk_lse = block_lse_bsh[b, 0:blk_len]
            cur_out = out[out_begin:out_begin + blk_len]
            cur_lse = lse_bsh[b, off:off + blk_len]
            new_out[out_begin:out_begin + blk_len] = cur_out - F.sigmoid(blk_lse - cur_lse) * (cur_out - blk_out)
            new_lse_bsh[b, off:off + blk_len] = (cur_lse - F.logsigmoid(cur_lse - blk_lse))
    return (new_out, new_lse_bsh.squeeze(dim=-1).transpose(-2, -1))
class QuantMatmul(torch.nn.Module):
    """Reference linear layer with optional dequant scales, residual and activation.

    forward computes:
        output = act((input @ weight.T + bias) * scales * alpha + residual * beta)
    where each of input_scale / weight_scale / gemm_output_scale is applied
    only when provided.
    """

    def __init__(self, weight, bias, residual, input_scale, weight_scale, gemm_output_scale, dtype,
                 alpha: float = 1.0, beta: float = 1.0, act_mode: str = 'none') -> None:
        super().__init__()
        self.dtype = dtype
        self.weight = Parameter(weight.type(dtype))
        self.input_scale = input_scale
        self.weight_scale = weight_scale
        self.gemm_output_scale = gemm_output_scale
        self.bias = Parameter(bias) if bias is not None else None
        self.residual = Parameter(residual) if residual is not None else None
        self.alpha = alpha
        self.beta = beta
        self.act = None if act_mode == 'none' else act_mode_dict[act_mode]

    # d = (a * b + bias) * alpha + c * beta
    # output = (input * weight + bias) * alpha + residual * beta
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = F.linear(input.type(self.dtype), self.weight, self.bias)
        if self.input_scale is not None:
            # Per-row input scale broadcast across the output columns.
            row_scale = self.input_scale.expand(self.weight_scale.shape[0], -1).transpose(0, 1)
            output = torch.mul(output, row_scale)
        if self.weight_scale is not None:
            output = torch.mul(output, self.weight_scale)
        if self.gemm_output_scale is not None:
            output = torch.mul(output, self.gemm_output_scale)
        output = torch.mul(output, self.alpha)
        if self.residual is not None:
            output = torch.add(output, torch.mul(self.residual, self.beta))
        if self.act is not None:
            output = self.act(output)
        return output
# for multiprocessing
def assertTensorsEqual( a,
b,
prec=None,
message='',
allow_inf=False,
use_MSE=False,
use_RAE=False,
use_RMA=False):
tc = TestCase()
if a.dtype == torch.bool:
a = a.float()
if b.dtype == torch.bool:
b = b.float()
epsilon = 1.0 / 16384
tc.assertEqual(a.size(), b.size(), message)
assert (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)), "a and b are need be torch tensor."
if a.numel() > 0:
# check that NaNs are in the same locations
nan_mask = a != a
tc.assertTrue(torch.equal(nan_mask, b != b), message)
diff = a - b
diff[nan_mask] = 0
a = a.clone()
b = b.clone()
a[nan_mask] = 0
b[nan_mask] = 0
# inf check if allow_inf=True
if allow_inf:
inf_mask = (a == float("inf")) | (a == float("-inf"))
tc.assertTrue(torch.equal(inf_mask,
(b == float("inf")) | (b == float("-inf"))),
message)
diff[inf_mask] = 0
a[inf_mask] = 0
b[inf_mask] = 0
# TODO: implement abs on CharTensor
if diff.is_signed() and 'CharTensor' not in diff.type():
diff = diff.abs()
if use_MSE:
diff = diff.abs().pow(2).sum()
a_pow_sum = a.pow(2).sum()
if diff <= (2 * epsilon) * (2 * epsilon):
diff = 0.0
if a_pow_sum <= epsilon:
a_pow_sum = a_pow_sum + epsilon
diff = torch.div(diff, (a_pow_sum * 1.0))
tc.assertLessEqual(diff.sqrt(), prec, message)
elif use_RAE:
diff = diff.abs().sum()
a_sum = a.abs().sum()
if a_sum == 0:
tc.assertEqual(a, b, message)
else:
diff = torch.div(diff, a_sum)
tc.assertLessEqual(diff, prec, message)
elif use_RMA:
a_mean = a.abs().mean()
b_mean = b.abs().mean()
if a_mean == 0:
tc.assertEqual(a, b, message)
else:
diff = torch.div((a_mean - b_mean).abs(), a_mean)
tc.assertLessEqual(diff, prec, message)
else:
max_err = diff.max()
tc.assertLessEqual(max_err, prec, message)
def setup(rank, world_size, backend='cncl', master_addr='localhost', master_port='3458'):
    """Initialize the default process group and bind this rank to its MLU.

    Backward compatible with the previous fixed-address version; master_addr
    and master_port are now overridable so concurrent suites can avoid
    rendezvous port clashes.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    torch_mlu.mlu.set_device(rank)
def cleanup():
    """Synchronize all ranks, then tear down the default process group."""
    # Barrier first so no rank destroys the group while peers are still using it.
    dist.barrier()
    dist.destroy_process_group()
def generate_token_count(num_expert,
                         total_token_count):
    """Randomly split total_token_count tokens across num_expert experts.

    Returns (cusum_token_count, token_count):
      - cusum_token_count: int32 tensor of length num_expert + 1, monotone
        non-decreasing, starting at 0 and ending exactly at total_token_count.
      - token_count: per-expert counts (adjacent differences of the cusum).
    """
    token_count = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32)
    # Rescale the random draw so the float counts sum to ~total_token_count.
    # (Renamed from `sum` to avoid shadowing the Python builtin.)
    total = torch.sum(token_count, dim=-1) * 1.0
    token_count *= total_token_count / total.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    # float->int truncation can overshoot; clamp, then pin the final entry so
    # the split always accounts for exactly total_token_count tokens.
    cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
    cusum_token_count[-1] = total_token_count
    return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1]
def generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype,
                             quant_mode=False, offline=False, invalid_batch_size=0):
    """Generate randomized MLU-resident inputs for KV-cache update tests.

    Builds key/value tensors (packed [total_seq, heads, head_size] when
    packed > 0, else padded [batch, max_seq + offset, heads, head_size]),
    a [2, max_bs, heads, cache_memory_len, head_size] cache pair, per-batch
    offset/index tensors, optional int8 quantization scales, and a
    slot_mapping for paged layouts. Returns the argument list in the order
    the cache op under test expects.
    """
    q_heads = 1
    # Heads are laid out contiguously as [q | k | v] inside `context`.
    total_heads = q_heads + num_heads * 2
    max_bs = batch_size + 1
    context_lens = torch.randint(size=(batch_size, ), low=1,
                                 high=cache_memory_len // 2,
                                 dtype=torch.int32, device='mlu')
    max_context_len = context_lens.max().item()
    max_seq_offset = max_context_len // 3 + 1
    # Random permutation mapping batch index -> cache row; -1 marks invalid rows.
    cache_bs_id = random.sample([*range(0, batch_size)], batch_size)
    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
    if invalid_batch_size > 0:
        cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch_size)] = -1
    context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset,
                                        dtype=torch.int32, device='mlu')
    cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1,
                                      high=(cache_memory_len - max_context_len) // 3 + 1,
                                      dtype=torch.int32, device='mlu')
    cu_context_lens = torch.cumsum(context_lens, dim=-1)
    cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1, 0), "constant", 0).to(torch.int32)
    total_seqlen = cu_context_lens[-1]
    if packed > 0:
        # Packed layout: sequences concatenated, addressed via cu_context_lens.
        context = torch.randn((total_seqlen, total_heads, head_size),
                              dtype=torch.float, device='mlu')
        context_seq_offsets = None
    else:
        # Padded layout: per-batch rows with room for a leading offset;
        # lengths are passed directly instead of as a cumulative sum.
        context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size),
                              dtype=torch.float, device='mlu')
        cu_context_lens = context_lens
    context = context.to(dtype)
    key = context[..., q_heads:q_heads + num_heads, :]
    value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :]
    cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu')
    cache_scale = None
    if quant_mode:
        # int8 cache; scales are per-channel when offline, else per-token.
        cache = (cache - 0.5) * 256
        cache = cache.to(torch.int8)
        if offline:
            cache_scale = torch.randn((2, cache.shape[2], cache.shape[4]), dtype=torch.float, device='mlu')
        else:
            cache_scale = torch.randn((2, max_bs, num_heads, cache_memory_len), dtype=torch.float, device='mlu')
    else:
        cache = cache.to(dtype)
    # NOTE(review): block size looks device-generation dependent — presumably
    # MLU3 uses one block per max-length sequence while others use 16; confirm.
    block_size = 16 if "MLU3" not in torch.mlu.get_device_name() else max_context_len
    min_blocks = (total_seqlen + block_size - 1) // block_size
    num_blocks = min(min_blocks + 10, 2 * min_blocks)
    num_slots = num_blocks * block_size
    # Scatter tokens into random distinct slots; the last entry is forced to -1
    # (presumably a "skip this token" sentinel for the op — verify with caller).
    slot_mapping = random.sample(range(num_slots), total_seqlen.item())
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
    slot_mapping[-1] = -1
    return [key, value, cache[0], cache[1], cu_context_lens, max_context_len,
            packed > 0, context_seq_offsets, cache_bs_id, cache_seq_offsets,
            cache_scale, slot_mapping]
def fused_moe(hidden_states: torch.Tensor,
              gating_output: torch.Tensor,
              w1: torch.Tensor,
              w2: torch.Tensor,
              bias1: Optional[torch.Tensor],
              bias2: Optional[torch.Tensor],
              residual: Optional[torch.Tensor],
              input_smooth: Optional[torch.Tensor],
              act_smooth: Optional[torch.Tensor],
              w1_scale: Optional[torch.Tensor],
              w2_scale: Optional[torch.Tensor],
              topk: int,
              renormalized: bool,
              gated: bool,
              act_mode: str,
              start_expert_id: int = 0,
              block_n: int = 0,
              cncl_comm: int = 0,
              w1_quant_flag: Optional[List] = None,
              w2_quant_flag: Optional[List] = None):
    """MoE forward composed from torch_mlu_ops (tmo) primitives.

    Pipeline: softmax-topk routing -> index generation -> (optional per-token
    smooth-quant) -> grouped GEMM 1 -> bias + activation -> (optional second
    quantize) -> grouped GEMM 2 -> weighted combine with optional residual
    and bias2. The four quant tensors (input_smooth, act_smooth, w1_scale,
    w2_scale) must be all present (quantized path) or all absent (float path).
    Returns a tensor with the original hidden_states shape.
    """
    dtype = hidden_states.dtype
    ori_input_shape = hidden_states.shape
    hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
    tokens = hidden_states.size(0)
    gating_output = gating_output.reshape(-1, gating_output.size(-1))
    residual = residual.reshape(-1, residual.size(-1)) if residual is not None else None
    expert_num = gating_output.size(-1)
    # Local expert count: taken from w1_scale when quant flags are given
    # (NOTE(review): presumably because w1 is packed in that case — confirm).
    expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1)
    per_token_sq = False
    # check quant
    check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
    if all(x is not None for x in check_list):
        per_token_sq = True
    if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
        raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present "
                         "and absent at the same time.")
    # softmax_topk
    reduce_weight, expert_id = tmo.moe_softmax_topk(gating_output, topk, renormalized)
    # gen_idx
    expand_idx, combine_idx, token_count, cusum_token_count = tmo.moe_gen_idx(expert_id, expert_num)
    if per_token_sq:
        if torch.mlu.get_device_name() == 'MLU370':
            # MLU370: expand first, then quantize the expanded activations.
            expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx,
                                                        cusum_token_count, start_expert_id, expert_size)
            quant_input, input_scale = tmo.moe_quantize(expand_hidden_states,
                                                        input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size])
        else:
            # Other devices: fused expand + quantize.
            quant_input, input_scale = tmo.moe_quantize(hidden_states,
                                                        input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx,
                                                        cusum_token_count[start_expert_id].unsqueeze(0))
    else:
        expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx,
                                                    cusum_token_count, start_expert_id, expert_size)
    # group gemm
    if per_token_sq:
        gemm1_out = tmo.smooth_quant_group_gemm(quant_input,
                                                w1,
                                                token_count[start_expert_id:start_expert_id+expert_size],
                                                None, None, None, None,
                                                input_scale, w1_scale, dtype, tokens, quant_flag=w1_quant_flag)
    else:
        gemm1_out = tmo.group_gemm(expand_hidden_states,
                                   w1,
                                   token_count[start_expert_id:start_expert_id+expert_size],
                                   None,
                                   None,
                                   None,
                                   None, tokens)
    # add_bias_active
    act_out = tmo.moe_active(gemm1_out, act_mode, gated, None, bias1, cusum_token_count, start_expert_id, expert_size)
    if per_token_sq:
        # Re-quantize the activation output for the second grouped GEMM.
        quant_input, input_scale = tmo.moe_quantize(act_out, act_smooth, None,
                                                    token_count[start_expert_id:start_expert_id+expert_size])
    if cncl_comm > 0:
        raise ValueError("not support communication and computing fusion currently.")
    else:
        if per_token_sq:
            gemm2_out = tmo.smooth_quant_group_gemm(quant_input,
                                                    w2, token_count[start_expert_id:start_expert_id+expert_size],
                                                    None, None, None, None, input_scale, w2_scale, dtype, tokens, quant_flag=w2_quant_flag)
        else:
            gemm2_out = tmo.group_gemm(act_out,
                                       w2,
                                       token_count[start_expert_id:start_expert_id+expert_size],
                                       None, None, None, None, tokens)
        # Scatter expert outputs back to token order, weighted by routing probs.
        output = tmo.moe_combine_result(gemm2_out, reduce_weight, combine_idx,
                                        residual, cusum_token_count, start_expert_id,
                                        expert_size, bias2)
    return output.reshape(ori_input_shape)
def min_mem_size(shape, stride):
    """Smallest backing-storage element count for a tensor of shape/stride.

    With stride=None the layout is contiguous, so the answer is the element
    count (shape must support .numel(), e.g. torch.Size); otherwise it is
    1 + sum((dim - 1) * stride) — the largest linear offset touched by
    as_strided, plus one.
    """
    if stride is None:
        return shape.numel()
    return 1 + sum((extent - 1) * step for extent, step in zip(shape, stride))
def create_tensor(shape, dtype, is_contiguous, device, stride = None, mean=0, var=1, is_uniform=False, low=0, high=1):
    """Build a random test tensor, optionally with an explicit strided layout.

    Integer dtypes draw from randint(-128, 127); floating dtypes draw from
    uniform_(low, high) when is_uniform else normal(mean, var). For a
    non-contiguous result a flat buffer of the minimal backing size is
    created and viewed through as_strided(shape, stride).
    """
    def _random_fill(size):
        # Random payload of the requested flavor for a buffer of `size`.
        if dtype in (torch.int8, torch.uint8):
            return torch.randint(-128, 127, size, device=device).to(dtype)
        if is_uniform:
            return torch.empty(size, dtype=dtype, device=device).uniform_(low, high)
        return torch.normal(mean, var, size, dtype=dtype, device=device)

    if is_contiguous:
        return _random_fill(shape)
    buffer = _random_fill((min_mem_size(shape, stride),))
    return buffer.as_strided(shape, stride)
def create_tensor_from_dic(dic:dict, mean=0, var=1, is_uniform=False, low=0, high=1):
    """Materialize a random tensor from a spec dict, or None when data is None.

    Expected keys: data, shape, dtype, is_contiguous, device, stride.
    """
    if dic['data'] is None:
        return None
    return create_tensor(dic['shape'], dic['dtype'], dic['is_contiguous'],
                         dic['device'], dic['stride'], mean, var, is_uniform, low, high)
def create_op_param(dic: dict):
    """Recursively build an op parameter from its spec dict.

    Lists/tuples recurse per element only when flagged has_compound; dict
    specs recurse per value; tensor specs with integer dtypes pass their
    data through unchanged (e.g. index tensors), other tensor specs are
    freshly randomized via create_tensor; all other types return data as-is.
    """
    kind = dic['type']
    if kind in (list, tuple):
        if dic['has_compound']:
            return [create_op_param(elem) for elem in dic['data']]
        return dic['data']
    if kind is dict:
        return {key: create_op_param(value) for key, value in dic['data'].items()}
    if kind is torch.Tensor:
        if dic['data'] is None:
            return None
        if dic['dtype'] in (torch.int16, torch.int32, torch.int64):
            return dic['data']
        return create_tensor(dic['shape'], dic['dtype'], dic['is_contiguous'], dic['device'], dic['stride'])
    return dic['data']