"""Test utilities for torch_mlu_ops (tmo).

Provides: a unittest base class with tensor-tolerance assertions, an MLU
event timer, reference (pure-torch) implementations of quantized matmuls,
paged-KV attention, MoE, and helpers that build randomized test inputs.

NOTE(review): this module mutates ``sys.argv`` at import time (below) so
that unittest's own CLI parsing does not see the script's flags.
"""
import sys
# Keep the user-supplied CLI args in `sys_args` (pop(0) removes the program
# name from the shared list), and leave unittest only the program name.
sys_args = sys.argv
sys.argv = [sys_args.pop(0)] # prevent unittest printing help info
import os
import torch
import torch_mlu
from torch.testing._internal.common_utils import TestCase
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from typing import List, Tuple, Optional
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group as get_default_group, all_reduce, ReduceOp
import torch.testing._internal.optests as optests
import random
import argparse
from abc import abstractmethod, ABC
import unittest
import torch_mlu_ops as tmo
import os  # NOTE(review): duplicate of the `import os` above; harmless but removable.

# Map of activation-name strings to their torch functional implementations;
# used by QuantMatmul (act_mode parameter).
act_mode_dict = {"relu": torch.nn.functional.relu, "gelu": torch.nn.functional.gelu, "silu": torch.nn.functional.silu}


class BtTestCase(TestCase, ABC):
    """Abstract base test case for tmo operators.

    Subclasses must implement `op_impl_base` (the reference computation)
    and `test_inductor`. Disables the TF32 matmul override so comparisons
    run at full precision.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Force full-precision matmul on MLU for reproducible comparisons.
        os.environ['TORCH_ALLOW_TF32_CNMATMUL_OVERRIDE'] = '0'

    @abstractmethod
    def op_impl_base(self, *args):
        pass

    @abstractmethod
    def test_inductor(self):
        pass

    def base_opcheck(self, interface_overload, args):
        """Run torch's opcheck on an op overload and assert all checks pass."""
        target_check = ["test_schema", "test_autograd_registration"]
        # test_faketensor only exists in newer torch releases.
        # NOTE(review): string comparison of versions ('2.10' < '2.3.0') is
        # fragile — confirm against the supported torch range.
        if torch.__version__ >= '2.3.0':
            target_check.append("test_faketensor")
        target_status = {key: "SUCCESS" for key in target_check}
        result = optests.opcheck(interface_overload, args, test_utils=target_check)
        self.assertEqual(result, target_status,)

    def assertException(self, error_msg, func, *args, **kwinputs):
        """Assert that `func(*args, **kwinputs)` raises; optionally check the
        exception message equals `error_msg` (any exception accepted if
        `error_msg` is falsy)."""
        try:
            func(*args, **kwinputs)
            # Reached only if no exception was raised -> fail.
            self.assertTrue(False)
        except Exception as e:
            if error_msg:
                self.assertTrue(error_msg == str(e))
            else:
                self.assertTrue(True)

    def assertTensorsEqual(self, a, b, prec=None, message='', allow_inf=False, use_MSE=False, use_RAE=False, use_RMA=False):
        """Assert two tensors are equal within tolerance `prec`.

        Exactly one error metric is applied, in priority order:
        use_MSE (relative RMS error), use_RAE (relative absolute error),
        use_RMA (relative mean-abs difference), else max absolute error.
        NaNs must appear at the same positions in both tensors; infs too
        when allow_inf=True.
        """
        if a.dtype == torch.bool:
            a = a.float()
        if b.dtype == torch.bool:
            b = b.float()
        epsilon = 1.0 / 16384
        self.assertEqual(a.size(), b.size(), message)
        assert (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)), "a and b are need be torch tensor."
        if a.numel() > 0:
            # check that NaNs are in the same locations
            nan_mask = a != a
            self.assertTrue(torch.equal(nan_mask, b != b), message)
            diff = a - b
            diff[nan_mask] = 0
            # Clone before zeroing NaN/inf entries so callers' tensors are untouched.
            a = a.clone()
            b = b.clone()
            a[nan_mask] = 0
            b[nan_mask] = 0
            # inf check if allow_inf=True
            if allow_inf:
                inf_mask = (a == float("inf")) | (a == float("-inf"))
                self.assertTrue(torch.equal(inf_mask, (b == float("inf")) | (b == float("-inf"))), message)
                diff[inf_mask] = 0
                a[inf_mask] = 0
                b[inf_mask] = 0
            # TODO: implement abs on CharTensor
            if diff.is_signed() and 'CharTensor' not in diff.type():
                diff = diff.abs()
            if use_MSE:
                # Relative RMS: sqrt(sum(diff^2) / sum(a^2)), with epsilon guards.
                diff = diff.abs().pow(2).sum()
                a_pow_sum = a.pow(2).sum()
                if diff <= (2 * epsilon) * (2 * epsilon):
                    diff = 0.0
                if a_pow_sum <= epsilon:
                    a_pow_sum = a_pow_sum + epsilon
                diff = torch.div(diff, (a_pow_sum * 1.0))
                self.assertLessEqual(diff.sqrt(), prec, message)
            elif use_RAE:
                # Relative absolute error: sum|diff| / sum|a|.
                diff = diff.abs().sum()
                a_sum = a.abs().sum()
                if a_sum == 0:
                    self.assertEqual(a, b, message)
                else:
                    diff = torch.div(diff, a_sum)
                    self.assertLessEqual(diff, prec, message)
            elif use_RMA:
                # Relative difference of mean magnitudes.
                a_mean = a.abs().mean()
                b_mean = b.abs().mean()
                if a_mean == 0:
                    self.assertEqual(a, b, message)
                else:
                    diff = torch.div((a_mean - b_mean).abs(), a_mean)
                    self.assertLessEqual(diff, prec, message)
            else:
                # Default: max absolute elementwise error.
                max_err = diff.max()
                self.assertLessEqual(max_err, prec, message)


def run_unittest(case) -> int:
    """Run a TestCase class, honoring the captured `-k name...` CLI filter.

    Returns 0 on success, 1 on failure (suitable as a process exit code).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-k', nargs='+', type=str, default="", help='specify case to run')
    # Parse the args captured at import time (sys.argv itself was emptied).
    args = parser.parse_args(sys_args)
    if args.k != "":
        ret = unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromNames(args.k, case))
    else:
        ret = unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(case))
    return not ret.wasSuccessful()


class TMOTimer:
    """Context manager measuring MLU hardware time of the enclosed region.

    After exit, `average_hardware_time` holds total hardware time divided
    by `repeat` (set repeat to the number of iterations run inside the
    `with` block).
    """

    def __init__(self, repeat: int = 1):
        self.repeat = repeat

    def __enter__(self):
        self.notify_start = torch.mlu.Event(enable_timing=True)
        self.notify_end = torch.mlu.Event(enable_timing=True)
        self.notify_start.record()

    def __exit__(self, exc_type, exc_value, traceback):
        self.notify_end.record()
        # Block until the end event completes so the timing is valid.
        self.notify_end.synchronize()
        total_hardware_time = self.notify_start.hardware_time(self.notify_end)
        self.average_hardware_time = total_hardware_time / self.repeat


def QuantByRow(input: torch.Tensor, quant_bit: int, group_num: int=1):
    """Symmetric per-row (optionally per-group) quantization to int8 storage.

    Returns (quantized tensor with input's original shape, per-row/group
    float scales). quant_bit must be 4 or 8; values are clamped to the
    signed range of that bit width but stored as int8.
    """
    input_shape = input.shape
    if input.dim() > 2:
        input = input.view(-1, input_shape[-1])
    if input.dim() == 1:
        input = input.unsqueeze(0)
    assert input.dim() == 2, "input must be 2-D tensor."
    assert quant_bit == 4 or quant_bit == 8, "quant_bit must be 4 or 8."
    assert group_num >= 1, "group_num >= 1."
    int_max = float(2 ** (quant_bit - 1) - 1)
    int_min = -float(2 ** (quant_bit - 1))
    group_size = input.size(-1) // group_num
    input_v = input.view(input.size(0), group_num, group_size) if group_num > 1 else input
    # NOTE(review): `max` shadows the builtin here; local-only, but renaming
    # would be cleaner.
    max, _ = input_v.abs().max(dim=-1, keepdim=True)
    scale = max.to(torch.float) / int_max
    quant_input = (input_v / scale).round().clamp(int_min, int_max).to(torch.int8).view(input.size())
    return quant_input.view(input_shape), scale.squeeze(-1)


def QuantByTensor(input: torch.Tensor, quant_bit: int):
    """Symmetric whole-tensor quantization; returns (int8 tensor, scalar scale).

    Note the returned scale multiplies the float input to get ints
    (int = input * scale), the inverse convention of QuantByRow.
    """
    int_max = float(2 ** (quant_bit - 1) - 1)
    int_min = -float(2 ** (quant_bit - 1))
    input_max = torch.max(torch.abs(input))
    input_scale = int_max / input_max
    input_int = torch.mul(input, input_scale).round().clamp(int_min, int_max).to(torch.int8)
    return input_int, input_scale


def PairlyPackInt8(input):
    """Pack adjacent int4-valued int8 pairs into single bytes.

    Element 2i goes to the low nibble, element 2i+1 to the high nibble;
    the last dimension halves.
    """
    assert input.dtype == torch.int8, "dtype of input must be int8."
    assert input.dim() == 2 or input.dim() == 3, "input must be 2-D or 3-D tensor."
    assert input.size(-1) % 2 == 0, "size(-1) of input must be even."
    input_shape = list(input.shape)
    input_flat = input.flatten()
    # Work in uint8 so the shift does not sign-extend.
    d0 = input_flat[0::2].to(torch.uint8)
    d1 = input_flat[1::2].to(torch.uint8)
    dp = (d1 << 4) + (d0 & 0x0F)
    input_shape[-1] = input_shape[-1] // 2
    return dp.to(torch.int8).reshape(input_shape)


def UnpackInt4(input):
    """Inverse of PairlyPackInt8: expand each byte into two signed int4 values.

    Output is flat with 2x the element count; low nibble first. The
    left-then-right shift on int8 sign-extends the low nibble.
    """
    assert input.dtype == torch.int8, "dtype of input must be int8."
    input_flat = input.flatten()
    n = input_flat.size(0)
    output = torch.zeros(n * 2, dtype=torch.int8, device=input.device)
    high = input_flat >> 4
    low = input_flat << 4
    low = low >> 4
    output[0::2] = low
    output[1::2] = high
    return output


def smooth_quant_matmul(a, a_scale, b, b_scale, out_dtype, bias=None):
    """Reference smooth-quant GEMM: out = (a @ b.T) * (a_scale x b_scale) (+ bias).

    a: (m, k) int activations with per-row scales a_scale (m,).
    b: (n, k) int weights; b_scale is (n,) or (quant_group, n) for
    group-wise quantization. If a's k is twice b's k, b is int4-packed
    and is unpacked first. Accumulation is in float32, cast to out_dtype.
    """
    assert a.dim() == 2 and b.dim() == 2, "a.dim() == 2 and b.dim() == 2"
    assert a_scale.dim() == 1, "a_scale.dim() == 1"
    assert a.size(0) == a_scale.size(0), "a.size(0) == a_scale.size(0)"
    assert b.size(0) == b_scale.size(-1), "b.size(0) == b_scale.size(-1)"
    m = a.size(0)
    n = b.size(0)
    a_k = a.size(1)
    b_k = b.size(1)
    if b_scale.dim() == 1:
        b_scale = b_scale.unsqueeze(0)
    quant_group = b_scale.size(0)
    # Split the k dimension into quant_group chunks: (group, m, k/group).
    a = a.view(m, quant_group, -1).transpose(0, 1).contiguous()
    if a_k == b_k * 2:
        # b holds two int4 values per byte -- unpack to match a's k.
        b = UnpackInt4(b)
    b = b.view(n, quant_group, -1).transpose(0, 1).contiguous()
    out = torch.zeros(m, n, dtype=torch.float, device=a.device)
    for i in range(quant_group):
        scale_mn = torch.matmul(a_scale.unsqueeze(1), b_scale[i].unsqueeze(0)) # (m, 1) x (1, n) = (m, n)
        out += torch.einsum('mk,nk->mn', a[i].to(torch.float), b[i].to(torch.float)) * scale_mn
        # out += smooth_quant_matmul(a[i], a_scale, b[i], b_scale[i], out_dtype)
    out = out.to(out_dtype)
    if bias is not None:
        out += bias
    return out


def smooth_quant_matmul_w4w8_mixed(a, a_scale, b, b_scale, out_dtype, bias=None, quant_flag=None):
    """Reference GEMM with mixed int4/int8 weight groups.

    quant_flag[i] == 4 marks group i as int4-packed (half the bytes);
    those groups are unpacked, the groups are re-concatenated, and each
    group's product is delegated to smooth_quant_matmul and summed.
    """
    m = a.shape[0]
    k = a.shape[1]
    quant_group = b_scale.shape[0]
    group_wise = k // quant_group
    n = b_scale.shape[1]
    b = b.view(n, -1)
    a = a.view(m, quant_group, -1).transpose(0, 1).contiguous()
    new_b = []
    start = 0
    end = 0
    # Walk b's byte columns group by group; int4 groups occupy half the bytes.
    for i in range(quant_group):
        if quant_flag[i] == 4:
            end += group_wise // 2
            new_b.append(UnpackInt4(b[:, start:end]).view(n, -1))
        else:
            end += group_wise
            new_b.append((b[:, start:end]))
        start = end
    new_b = torch.cat(new_b, 1)
    b = new_b.view(n, quant_group, -1).transpose(0, 1).contiguous()
    out = torch.zeros(m, n, dtype=torch.float, device=a.device)
    for i in range(quant_group):
        # NOTE(review): the inner call already casts to out_dtype, then the
        # sum accumulates into a float32 buffer -- confirm this matches the
        # kernel's accumulation order.
        out += smooth_quant_matmul(a[i], a_scale, b[i], b_scale[i], out_dtype)
    out = out.to(out_dtype)
    if bias is not None:
        out += bias
    return out


def weight_only_quant_matmul(a, b, scale, bias=None):
    """Reference weight-only-quant GEMM: dequantize b by `scale`, then a @ b.T.

    scale is (n,) per-channel, or (n, groups) group-wise (each scale is
    broadcast over its k-group). Compute in float32, cast back to a.dtype.
    """
    assert a.dim() == 2 and b.dim() == 2, "a.dim() == 2 and b.dim() == 2"
    assert scale.dim() == 1 or scale.dim() == 2, "scale.dim() == 1 or scale.dim() == 2"
    assert b.size(0) == scale.size(0), "b.size(0) == b_scale.size(0)"
    assert a.size(1) == b.size(1), "a.size(1) == b.size(1)"
    if scale.dim() == 2:
        # Expand each group scale across its contiguous k-slice.
        group_size = b.size(1) // scale.size(1)
        scale_bd = scale.unsqueeze(-1).repeat(1, 1, group_size).reshape(b.shape)
    else:
        scale_bd = scale.unsqueeze(-1)
    b1 = b * scale_bd
    out = torch.einsum('mk,nk->mn', a.to(torch.float), b1.to(torch.float)).to(a.dtype)
    if bias is not None:
        out += bias
    return out


def single_query_cached_kv_attn(q, k_cache, v_cache, block_tables, context_lens, k_cache_quant_scale, v_cache_quant_scale, alibi_slopes, window_size_left, window_size_right, softmax_scale, return_lse):
    """Reference paged-KV attention (float32 math).

    q: (bs, seq_q, num_heads, head_size); k_cache/v_cache:
    (num_blocks, num_kv_heads, block_size, head_size[_v]), gathered per
    batch via block_tables and truncated to context_lens. Supports
    optional cache dequantization (per-channel or per-token scales),
    ALiBi, causal masking when seq_q > 1, and a sliding window.
    Returns (output, lse) if return_lse else output; output is float16.
    """
    q = q.float()
    k_cache = k_cache.float()
    v_cache = v_cache.float()

    def masked_attention(query, key, value, alibi_slope, context_len, window_size_left, window_size_right, qk_scale) -> torch.Tensor:
        # (num_heads, seq_q, seq_k)
        qk = torch.einsum('qhd,hkd->hqk', query, key)
        qk = qk * qk_scale
        if alibi_slope is not None:
            # ALiBi: add a per-head linear bias over key positions.
            alibi_dist = torch.arange(0, context_len, dtype=torch.float32).mlu()
            alibi = alibi_slope[:, None] * alibi_dist
            qk = qk + alibi[:, None, :]
        _, seq_q, seq_k = qk.size()
        if seq_q > 1:
            #causal mask
            ml = torch.zeros((seq_q, seq_k - seq_q), dtype=qk.dtype).mlu()
            ones = torch.ones((seq_q, seq_q), dtype=qk.dtype).mlu() * -torch.inf
            mr = torch.triu(ones, diagonal=1)
            mask = torch.cat((ml, mr), dim=-1)
            qk = qk + mask
        if window_size_left != -1 or window_size_right != -1:
            # Sliding-window mask: only keys within [left, right] of each
            # query's absolute position stay visible (-1 disables a side).
            mask_w = torch.full((seq_q, seq_k), -torch.inf, dtype=torch.float, device="mlu")
            for qi in range(seq_q):
                left = max(seq_k - seq_q + qi - window_size_left, 0) if window_size_left != -1 else 0
                right = min(max(seq_k - seq_q + qi + window_size_right + 1, 0), seq_k) if window_size_right != -1 else seq_k
                mask_w[qi, left:right] = 0
            qk += mask_w
        attention = torch.softmax(qk, dim = -1, dtype=qk.dtype)
        qkv = torch.einsum('hqk,hkd->qhd', attention, value)
        # Also return pre-softmax logits so the caller can compute the LSE.
        return qkv, qk

    if k_cache_quant_scale is not None and v_cache_quant_scale is not None:
        if k_cache_quant_scale.dim() == 2:
            # per_channel: [kv_head_num, head_size]
            k_cache_quant_scale = k_cache_quant_scale.reshape(1, k_cache_quant_scale.shape[0], 1, k_cache_quant_scale.shape[1])
            v_cache_quant_scale = v_cache_quant_scale.reshape(1, v_cache_quant_scale.shape[0], 1, v_cache_quant_scale.shape[1])
        elif k_cache_quant_scale.dim() == 3:
            # per_token: [num_blocks, k_head_num, block_size]
            k_cache_quant_scale = k_cache_quant_scale.reshape(*k_cache_quant_scale.shape, 1)
            v_cache_quant_scale = v_cache_quant_scale.reshape(*v_cache_quant_scale.shape, 1)
        # Dequantize caches in place (they were cast to float above).
        k_cache *= k_cache_quant_scale
        v_cache *= v_cache_quant_scale
    bs, seq_q, num_heads, head_size = q.size()
    head_size_v = v_cache.size(-1)
    num_blocks, num_kv_heads, block_size, _ = k_cache.size()
    # NOTE(review): output dtype is hard-coded float16 regardless of q.dtype
    # -- confirm against the kernel under test.
    output = torch.zeros((bs, seq_q, num_heads, head_size_v), dtype=torch.float16)
    lse = torch.zeros((bs, num_heads, seq_q), dtype=torch.float)
    assert (num_heads % num_kv_heads == 0)
    head_repeats = num_heads // num_kv_heads
    for bs_id in range(bs):
        q_bs = q[bs_id]
        context_len = int(context_lens[bs_id])
        if context_len == 0:
            # Empty context: zero output, LSE of an empty sum is -inf.
            output[bs_id] = torch.zeros((seq_q, num_heads, head_size_v), device = q.device, dtype=output.dtype)
            lse[bs_id] = lse[bs_id].fill_(-float('inf'))
        else :
            # Gather this sequence's cache blocks and flatten to
            # (num_heads, context_len, head_size), repeating KV heads for GQA.
            block_table = block_tables[bs_id]
            table_end = (context_len + block_size - 1) // block_size
            block_ids = block_table[0 : table_end]
            keys, values = k_cache[block_ids], v_cache[block_ids]
            keys = torch.repeat_interleave(keys, head_repeats, dim=1)
            keys = keys.transpose(1, 0).contiguous().view(num_heads, -1, head_size)
            keys = keys[:, 0:context_len, :]
            values = torch.repeat_interleave(values, head_repeats, dim=1)
            values = values.transpose(1, 0).contiguous().view(num_heads, -1, head_size_v)
            values = values[:, 0:context_len, :]
            alibi_slope = alibi_slopes[bs_id] if alibi_slopes is not None else None
            qkv, qk= masked_attention(q_bs, keys, values, alibi_slope, context_len, window_size_left, window_size_right, softmax_scale)
            output[bs_id] = qkv
            lse[bs_id] = torch.logsumexp(qk, dim = -1)
    return (output, lse) if return_lse else output


def update_out_and_lse_torch(out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs):
    """Merge a new attention block's (out, lse) into running accumulators.

    Implements the numerically-stable online-softmax update:
    new_out = out - sigmoid(block_lse - lse) * (out - block_out),
    new_lse = lse - logsigmoid(lse - block_lse).
    out.dim() == 3 selects the packed (varlen) layout driven by cu_seqs /
    block_cu_seqs; otherwise the padded per-batch layout is used.
    Returns (new_out, new_lse) with lse back in (batch, heads, seq) order.
    """
    # only pad
    is_pack = out.dim() == 3
    new_out, new_lse = out.clone(), lse.clone()
    batch, max_seq_len, block_seq_len = lse.shape[0], lse.shape[-1], block_lse.shape[-1]
    # Move lse to (batch, seq, heads, 1) so it broadcasts against out slices.
    lse_bsh = lse.transpose(-2, -1).unsqueeze(dim=-1)
    new_lse_bsh = new_lse.transpose(-2, -1).unsqueeze(dim=-1)
    block_lse_bsh = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    if not is_pack:
        for i in range(batch):
            out_seq_offset = 0 if seq_offsets is None else seq_offsets[i]
            out_i = out[i, out_seq_offset : out_seq_offset + block_seq_len]
            lse_i = lse_bsh[i, out_seq_offset : out_seq_offset + block_seq_len]
            block_out_i = block_out[i, :]
            block_lse_i = block_lse_bsh[i, :]
            new_out[i, out_seq_offset : out_seq_offset + block_seq_len] = out_i - F.sigmoid(block_lse_i - lse_i) * (out_i - block_out_i)
            new_lse_bsh[i, out_seq_offset : out_seq_offset + block_seq_len] = (lse_i - F.logsigmoid(lse_i - block_lse_i))
    else:
        for i in range(batch):
            # Packed layout: locate this batch's span inside the flat buffers.
            block_i_begin = block_cu_seqs[i]
            block_i_end = block_cu_seqs[i + 1]
            block_i_lens = block_i_end - block_i_begin
            out_i_begin = cu_seqs[i]
            out_seq_offset = seq_offsets[i]
            block_out_i = block_out[block_i_begin : block_i_end]
            block_lse_i = block_lse_bsh[i, 0 : block_i_lens]
            out_i = out[out_i_begin + out_seq_offset: out_i_begin + out_seq_offset + block_i_lens]
            lse_i = lse_bsh[i, out_seq_offset: out_seq_offset + block_i_lens]
            new_out_i = out_i - F.sigmoid(block_lse_i - lse_i) * (out_i - block_out_i)
            new_lse_i = (lse_i - F.logsigmoid(lse_i - block_lse_i))
            new_out[out_i_begin + out_seq_offset: out_i_begin + out_seq_offset + block_i_lens] = new_out_i
            new_lse_bsh[i, out_seq_offset: out_seq_offset + block_i_lens] = new_lse_i
    return (new_out, new_lse_bsh.squeeze(dim=-1).transpose(-2, -1))


class QuantMatmul(torch.nn.Module):
    """Reference module for a scaled linear op:
    output = act((input @ weight.T + bias) * scales * alpha + residual * beta).
    Scale factors (input_scale, weight_scale, gemm_output_scale) are each
    optional and applied multiplicatively when present.
    """

    def __init__(self, weight, bias, residual, input_scale, weight_scale, gemm_output_scale, dtype, alpha:float = 1.0, beta:float = 1.0, act_mode:str = 'none') -> None:
        super().__init__()
        self.dtype = dtype
        self.weight = Parameter(weight.type(dtype))
        self.input_scale = input_scale
        self.weight_scale = weight_scale
        self.gemm_output_scale = gemm_output_scale
        if bias is not None:
            self.bias = Parameter(bias)
        else:
            self.bias = None
        if residual is not None:
            self.residual = Parameter(residual)
        else:
            self.residual = None
        self.alpha = alpha
        self.beta = beta
        if act_mode == 'none':
            self.act = None
        else:
            # Resolve the activation function from the module-level table.
            self.act = act_mode_dict[act_mode]

    # d = (a * b + bias) * alpha + c * beta
    # output = (input * weight + bias) * alpha + residual * beta
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = F.linear(input.type(self.dtype), self.weight, self.bias)
        if self.input_scale is not None:
            # Broadcast per-row input scales across the output columns.
            # NOTE(review): this path requires weight_scale to be set too.
            i_scale = self.input_scale.expand(self.weight_scale.shape[0], -1).transpose(0, 1)
            output = torch.mul(output, i_scale)
        if self.weight_scale is not None:
            output = torch.mul(output, self.weight_scale)
        if self.gemm_output_scale is not None:
            output = torch.mul(output, self.gemm_output_scale)
        output = torch.mul(output, self.alpha)
        if self.residual is not None:
            residual = torch.mul(self.residual, self.beta)
            output = torch.add(output, residual)
        if self.act is not None:
            output = self.act(output)
        return output


# for multiprocessing
def assertTensorsEqual( a, b, prec=None, message='', allow_inf=False, use_MSE=False, use_RAE=False, use_RMA=False):
    """Module-level clone of BtTestCase.assertTensorsEqual for use in worker
    processes (where no TestCase instance exists). See the method for the
    metric definitions; the logic below mirrors it line for line.
    """
    tc = TestCase()
    if a.dtype == torch.bool:
        a = a.float()
    if b.dtype == torch.bool:
        b = b.float()
    epsilon = 1.0 / 16384
    tc.assertEqual(a.size(), b.size(), message)
    assert (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)), "a and b are need be torch tensor."
    if a.numel() > 0:
        # check that NaNs are in the same locations
        nan_mask = a != a
        tc.assertTrue(torch.equal(nan_mask, b != b), message)
        diff = a - b
        diff[nan_mask] = 0
        a = a.clone()
        b = b.clone()
        a[nan_mask] = 0
        b[nan_mask] = 0
        # inf check if allow_inf=True
        if allow_inf:
            inf_mask = (a == float("inf")) | (a == float("-inf"))
            tc.assertTrue(torch.equal(inf_mask, (b == float("inf")) | (b == float("-inf"))), message)
            diff[inf_mask] = 0
            a[inf_mask] = 0
            b[inf_mask] = 0
        # TODO: implement abs on CharTensor
        if diff.is_signed() and 'CharTensor' not in diff.type():
            diff = diff.abs()
        if use_MSE:
            diff = diff.abs().pow(2).sum()
            a_pow_sum = a.pow(2).sum()
            if diff <= (2 * epsilon) * (2 * epsilon):
                diff = 0.0
            if a_pow_sum <= epsilon:
                a_pow_sum = a_pow_sum + epsilon
            diff = torch.div(diff, (a_pow_sum * 1.0))
            tc.assertLessEqual(diff.sqrt(), prec, message)
        elif use_RAE:
            diff = diff.abs().sum()
            a_sum = a.abs().sum()
            if a_sum == 0:
                tc.assertEqual(a, b, message)
            else:
                diff = torch.div(diff, a_sum)
                tc.assertLessEqual(diff, prec, message)
        elif use_RMA:
            a_mean = a.abs().mean()
            b_mean = b.abs().mean()
            if a_mean == 0:
                tc.assertEqual(a, b, message)
            else:
                diff = torch.div((a_mean - b_mean).abs(), a_mean)
                tc.assertLessEqual(diff, prec, message)
        else:
            max_err = diff.max()
            tc.assertLessEqual(max_err, prec, message)


def setup(rank, world_size, backend='cncl'):
    """Initialize a single-node process group and bind this rank's MLU device."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '3458'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    torch_mlu.mlu.set_device(rank)


def cleanup():
    """Synchronize all ranks, then tear the process group down."""
    dist.barrier()
    dist.destroy_process_group()


def generate_token_count(num_expert, total_token_count):
    """Randomly split `total_token_count` tokens across `num_expert` experts.

    Returns (cumulative counts of length num_expert + 1, per-expert counts).
    The cumulative sum is clamped so it never exceeds the total and its
    last entry is forced to exactly total_token_count.
    """
    token_count = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32)
    # NOTE(review): `sum` shadows the builtin; local-only.
    sum = torch.sum(token_count, dim=-1) * 1.0
    # Rescale the random draws so they total (approximately) the target.
    token_count *= total_token_count / sum.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
    cusum_token_count[-1] = total_token_count
    return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1]


def generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, quant_mode=False, offline=False, invalid_batch_size=0):
    """Build randomized arguments for KV-cache update tests on MLU.

    Generates fused QKV context (packed varlen or padded layout), K/V
    caches (optionally int8-quantized with offline per-channel or online
    per-token scales), batch-id / offset index tensors, and a paged-cache
    slot mapping. Returns the argument list in the order the cache op
    under test expects.
    """
    q_heads = 1
    total_heads = q_heads + num_heads * 2
    max_bs = batch_size + 1
    context_lens = torch.randint(size=(batch_size, ), low=1, high=cache_memory_len // 2, dtype=torch.int32, device='mlu')
    max_context_len = context_lens.max().item()
    max_seq_offset = max_context_len // 3 + 1
    # Random permutation of batch ids; some entries may be invalidated (-1).
    cache_bs_id = random.sample([*range(0, batch_size)], batch_size)
    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
    if invalid_batch_size > 0:
        cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch_size)] = -1
    context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset, dtype=torch.int32, device='mlu')
    cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1, high=(cache_memory_len - max_context_len) // 3 + 1, dtype=torch.int32, device='mlu')
    cu_context_lens = torch.cumsum(context_lens, dim=-1)
    cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
    total_seqlen = cu_context_lens[-1]
    if packed > 0:
        # Varlen layout: all sequences concatenated along dim 0.
        context = torch.randn((total_seqlen, total_heads, head_size), dtype=torch.float, device='mlu')
        context_seq_offsets = None
    else:
        # Padded layout: per-sequence lengths are passed instead of cumsums.
        context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size), dtype=torch.float, device='mlu')
        cu_context_lens = context_lens
    context = context.to(dtype)
    # Slice K and V heads out of the fused QKV head dimension.
    key = context[..., q_heads:q_heads + num_heads, :]
    value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :]
    cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu')
    cache_scale = None
    if quant_mode:
        # Spread values over roughly the int8 range before truncation.
        cache = (cache - 0.5) * 256
        cache = cache.to(torch.int8)
        if offline:
            # Offline: one scale per (head, channel).
            cache_scale = torch.randn((2, cache.shape[2], cache.shape[4]), dtype=torch.float, device='mlu')
        else:
            # Online: one scale per cached token.
            cache_scale = torch.randn((2, max_bs, num_heads, cache_memory_len), dtype=torch.float, device='mlu')
    else:
        cache = cache.to(dtype)
    block_size = 16 if "MLU3" not in torch.mlu.get_device_name() else max_context_len
    min_blocks = (total_seqlen + block_size - 1) // block_size
    num_blocks = min(min_blocks + 10, 2 * min_blocks)
    num_slots = num_blocks * block_size
    # Unique random slot per token; last slot forced to -1 to exercise the
    # "skip this token" path.
    slot_mapping = random.sample(range(num_slots), total_seqlen.item())
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
    slot_mapping[-1] = -1
    return [key, value, cache[0], cache[1], cu_context_lens, max_context_len, packed > 0, context_seq_offsets, cache_bs_id, cache_seq_offsets, cache_scale, slot_mapping]


def fused_moe(hidden_states: torch.Tensor, gating_output: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, bias1: Optional[torch.Tensor], bias2: Optional[torch.Tensor], residual: Optional[torch.Tensor], input_smooth: Optional[torch.Tensor], act_smooth: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], topk: int, renormalized: bool, gated: bool, act_mode: str, start_expert_id: int = 0, block_n: int = 0, cncl_comm: int = 0, w1_quant_flag: Optional[List] = None, w2_quant_flag: Optional[List] = None):
    """Mixture-of-experts forward built from tmo primitive kernels.

    Pipeline: softmax-topk routing -> index generation -> (optional
    per-token smooth-quant) -> grouped GEMM 1 -> bias + activation ->
    grouped GEMM 2 -> weighted combine (+ residual / bias2). The four
    quant tensors (input_smooth, act_smooth, w1_scale, w2_scale) must be
    all present or all absent. Raises ValueError on partial quant args or
    when cncl_comm > 0 (comm/compute fusion unsupported here).
    """
    dtype = hidden_states.dtype
    ori_input_shape = hidden_states.shape
    hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
    tokens = hidden_states.size(0)
    gating_output = gating_output.reshape(-1, gating_output.size(-1))
    residual = residual.reshape(-1, residual.size(-1)) if residual is not None else None
    expert_num = gating_output.size(-1)
    expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1)
    per_token_sq = False
    # check quant
    check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
    if all(x is not None for x in check_list):
        per_token_sq = True
    if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
        raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present " "and absent at the same time.")
    # softmax_topk
    reduce_weight, expert_id = tmo.moe_softmax_topk(gating_output, topk, renormalized)
    # gen_idx
    expand_idx, combine_idx, token_count, cusum_token_count = tmo.moe_gen_idx(expert_id, expert_num)
    if per_token_sq:
        if torch.mlu.get_device_name() == 'MLU370':
            # MLU370 path: expand first, then quantize the expanded tokens.
            expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx, cusum_token_count, start_expert_id, expert_size)
            quant_input, input_scale = tmo.moe_quantize(expand_hidden_states, input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size])
        else:
            # Other devices: moe_quantize performs the expansion itself.
            quant_input, input_scale = tmo.moe_quantize(hidden_states, input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx, cusum_token_count[start_expert_id].unsqueeze(0))
    else:
        expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx, cusum_token_count, start_expert_id, expert_size)
    # group gemm
    if per_token_sq:
        gemm1_out = tmo.smooth_quant_group_gemm(quant_input, w1, token_count[start_expert_id:start_expert_id+expert_size], None, None, None, None, input_scale, w1_scale, dtype, tokens, quant_flag = w1_quant_flag)
    else:
        gemm1_out = tmo.group_gemm(expand_hidden_states, w1, token_count[start_expert_id:start_expert_id+expert_size], None, None, None, None, tokens)
    # add_bias_active
    act_out = tmo.moe_active(gemm1_out, act_mode, gated, None, bias1, cusum_token_count, start_expert_id, expert_size)
    if per_token_sq:
        # Re-quantize the activation output for the second GEMM.
        quant_input, input_scale = tmo.moe_quantize(act_out, act_smooth, None, token_count[start_expert_id:start_expert_id+expert_size])
    if cncl_comm > 0:
        raise ValueError("not support communication and computing fusion currently.")
    else:
        if per_token_sq:
            gemm2_out = tmo.smooth_quant_group_gemm(quant_input, w2, token_count[start_expert_id:start_expert_id+expert_size], None, None, None, None, input_scale, w2_scale, dtype, tokens, quant_flag = w2_quant_flag)
        else:
            gemm2_out = tmo.group_gemm(act_out, w2, token_count[start_expert_id:start_expert_id+expert_size], None, None, None, None, tokens)
    # Scatter expert outputs back to token order with routing weights.
    output = tmo.moe_combine_result(gemm2_out, reduce_weight, combine_idx, residual, cusum_token_count, start_expert_id, expert_size, bias2)
    return output.reshape(ori_input_shape)


def min_mem_size(shape, stride):
    """Minimum number of elements a buffer needs to back a view of `shape`
    with `stride` (contiguous when stride is None)."""
    if stride is None:
        mem_size = 0
        mem_size += shape.numel()
    else:
        # Highest reachable index + 1: 1 + sum((dim-1) * stride).
        mem_size = 1
        for k,v in zip(shape, stride):
            mem_size += (k - 1) * v
    return mem_size


def create_tensor(shape, dtype, is_contiguous, device, stride = None, mean=0, var=1, is_uniform=False, low=0, high=1):
    """Create a random test tensor, optionally non-contiguous via as_strided.

    int8/uint8 get random integers in [-128, 127); other dtypes are
    uniform in [low, high) when is_uniform else normal(mean, var).
    """
    if is_contiguous:
        if dtype in (torch.int8, torch.uint8):
            t = torch.randint(-128, 127, shape, device=device).to(dtype)
        else:
            if is_uniform:
                t = torch.empty(shape, dtype=dtype, device=device).uniform_(low, high)
            else:
                t = torch.normal(mean, var, shape, dtype=dtype, device=device)
    else:
        # Allocate a flat buffer just large enough, then view it strided.
        mem_size = min_mem_size(shape, stride)
        if dtype in (torch.int8, torch.uint8):
            t = torch.randint(-128, 127, (mem_size,), device=device).to(dtype)
        else:
            if is_uniform:
                t = torch.empty((mem_size,), dtype=dtype, device=device).uniform_(low, high)
            else:
                t = torch.normal(mean, var, (mem_size,), dtype=dtype, device=device)
        t = t.as_strided(shape, stride)
    return t


def create_tensor_from_dic(dic:dict, mean=0, var=1, is_uniform=False, low=0, high=1):
    """Create a tensor from a spec dict with keys: data, shape, dtype,
    is_contiguous, device, stride. Returns None when dic['data'] is None."""
    if dic['data'] is None:
        return None
    shape = dic['shape']
    dtype = dic['dtype']
    is_contiguous = dic['is_contiguous']
    device = dic['device']
    stride = dic['stride']
    return create_tensor(shape, dtype, is_contiguous, device, stride, mean, var, is_uniform, low, high)


def create_op_param(dic: dict):
    """Recursively materialize an op parameter from its spec dict.

    Lists/tuples recurse only when 'has_compound' is set; dict specs
    recurse per value; tensor specs generate random tensors (integer
    dtypes pass 'data' through verbatim); everything else returns 'data'.
    """
    if dic['type'] in (list, tuple):
        return [create_op_param(elem) for elem in dic['data']] if dic['has_compound'] else dic['data']
    elif dic['type'] is dict:
        return {k:create_op_param(v) for k,v in dic['data'].items()}
    elif dic['type'] is torch.Tensor:
        if dic['data'] is None:
            return None
        else:
            if dic['dtype'] in (torch.int16, torch.int32, torch.int64):
                # Integer tensors (e.g. index tensors) are used as given.
                return dic['data']
            else:
                return create_tensor(dic['shape'], dic['dtype'], dic['is_contiguous'], dic['device'], dic['stride'])
    else:
        return dic['data']