add ops
This commit is contained in:
6
torch_mlu_ops-v1.3.2/torch_mlu_ops/__init__.py
Normal file
6
torch_mlu_ops-v1.3.2/torch_mlu_ops/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import torch

from ._ops import *

# Gate the abstract (fake/meta) kernel registrations on the torch version.
# Compare numerically, not lexicographically: as strings, "2.10.0" < "2.3.0",
# so the original `torch.__version__ >= '2.3.0'` check breaks on torch 2.10+.
_torch_version = tuple(
    int(part)
    for part in torch.__version__.split('+')[0].split('.')[:2]
    if part.isdigit()
)
if _torch_version >= (2, 3):
    from .abstract import *

from ._version import __version__
|
||||
2704
torch_mlu_ops-v1.3.2/torch_mlu_ops/_ops.py
Normal file
2704
torch_mlu_ops-v1.3.2/torch_mlu_ops/_ops.py
Normal file
File diff suppressed because it is too large
Load Diff
121
torch_mlu_ops-v1.3.2/torch_mlu_ops/_utils.py
Normal file
121
torch_mlu_ops-v1.3.2/torch_mlu_ops/_utils.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import os
|
||||
import importlib.machinery
|
||||
import inspect
|
||||
from datetime import datetime
|
||||
import sys
|
||||
# sys.setrecursionlimit(10000)
|
||||
|
||||
def enable_gen_case():
    """Read the TMO_GEN_CASE* environment variables controlling case dumping.

    Environment variables:
        TMO_GEN_CASE           -- master switch (ON/1/YES/TRUE/Y, case-insensitive)
        TMO_GEN_CASE_DUMP_DATA -- also dump raw tensor data, same truthy spellings
        TMO_GEN_CASE_OP_NAME   -- ';'-separated op names to dump (empty = all ops)
        TMO_GEN_CASE_PATH      -- output root directory (defaults to the cwd)

    Returns:
        tuple: (enabled: bool, op_names: list[str], case_path: str, dump_data: bool)

    Fixes: the original bound a local named `enable_gen_case`, shadowing the
    function itself, and misspelled `enalbe_dump_data`; both were locals, so
    renaming them is safe for callers.
    """
    def check_env_flag(name, default=''):
        # Accept the usual truthy spellings, case-insensitively.
        return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']

    enabled = check_env_flag('TMO_GEN_CASE')
    dump_data = check_env_flag('TMO_GEN_CASE_DUMP_DATA')
    op_name_env = os.getenv('TMO_GEN_CASE_OP_NAME')
    op_names = op_name_env.split(';') if op_name_env else []
    case_path = os.getenv('TMO_GEN_CASE_PATH', os.getcwd())
    return enabled, op_names, case_path, dump_data
|
||||
|
||||
ENABLE_GEN_CASE, GEN_CASE_OP_NAMES, GEN_CASE_PATH, ENALBE_DUMP_DATA = enable_gen_case()
|
||||
|
||||
def gen_case_dump(**kwargs):
    """Persist one generated case: serialize kwargs['map'] to kwargs['path'].

    Expects keyword arguments `map` (the payload dict) and `path` (target file).
    """
    payload = kwargs['map']
    target = kwargs['path']
    torch.save(payload, target)
|
||||
|
||||
def dump_info(value):
    """Recursively summarize `value` as a plain dict for case dumping.

    - list/tuple: if any element is exactly a Tensor/list/tuple, each element
      is summarized recursively ('has_compound': True); otherwise the sequence
      is stored verbatim.
    - dict: every value is summarized recursively.
    - Tensor: metadata (shape/dtype/device/contiguity/stride); raw data is kept
      only for integer dtypes (int16/32/64 — typically index tensors), '' otherwise.
    - anything else: stored verbatim.

    Exact `__class__` checks are deliberate (subclasses such as
    torch.nn.Parameter fall through to the verbatim branch) — do not replace
    them with isinstance.

    Returns:
        dict: summary with at least 'type' and 'data' keys.
    """
    cls = value.__class__
    if cls in (list, tuple):
        elem_classes = {elem.__class__ for elem in value}
        if {torch.Tensor, list, tuple} & elem_classes:
            return {'type': cls, 'has_compound': True,
                    'data': [dump_info(elem) for elem in value]}
        return {'type': cls, 'has_compound': False, 'data': value}
    if cls is dict:
        return {'type': cls, 'data': {k: dump_info(v) for k, v in value.items()}}
    if cls is torch.Tensor:
        # The original `if value is None` branch here was dead code:
        # value.__class__ is torch.Tensor guarantees value is a real tensor.
        return {'type': cls,
                'shape': value.shape,
                'dtype': value.dtype,
                'device': value.device,
                'is_contiguous': value.is_contiguous(),
                'stride': value.stride(),
                'data': value if value.dtype in (torch.int16, torch.int32, torch.int64) else ''}
    return {'type': cls, 'data': value}
|
||||
|
||||
def dump_pt_case(func, *args, **kwargs):
    """Serialize one invocation of `func` into a timestamped .pt case file.

    The file lands under GEN_CASE_PATH/tmo_gen_case/<op>/. When
    ENALBE_DUMP_DATA is set the raw argument values are stored; otherwise only
    dump_info() metadata summaries are stored (and unknown kwargs are skipped).
    """
    op_name = func.__name__
    case_dir = os.path.join(GEN_CASE_PATH, 'tmo_gen_case', op_name)
    os.makedirs(case_dir, exist_ok=True)
    stamp = int(datetime.now().timestamp() * 1e6)
    case_path = os.path.join(case_dir, op_name + '_' + str(stamp) + '.pt')
    print("[torch_mlu_ops] dump case ====> ", case_path)

    signature = inspect.signature(func)
    # signature.parameters is an ordered mapping, so positional args line up.
    param_objs = list(signature.parameters.values())
    positional = param_objs[:len(args)]
    defaulted = param_objs[len(args):]

    params_map = {"op": op_name, "dump_data": ENALBE_DUMP_DATA}
    if ENALBE_DUMP_DATA:
        # Full dump: raw values, tensors included; kwargs override verbatim.
        for obj, value in zip(positional, args):
            params_map[obj.name] = value
        for obj in defaulted:
            params_map[obj.name] = obj.default
        params_map = {**params_map, **kwargs}
    else:
        # Metadata-only dump: store dump_info() summaries instead of raw data.
        for obj, value in zip(positional, args):
            params_map[obj.name] = dump_info(value)
        for obj in defaulted:
            params_map[obj.name] = dump_info(obj.default)
        for key, value in kwargs.items():
            if key in signature.parameters.keys():
                params_map[key] = dump_info(value)
    gen_case_dump(map=params_map, path=case_path)
|
||||
|
||||
def dump_case(func):
    """Decorator: dump each call of `func` to a .pt case file before running it.

    Returns the dumping wrapper only when ENABLE_GEN_CASE is set and `func` is
    selected by GEN_CASE_OP_NAMES (an empty list selects every op); otherwise
    a transparent pass-through wrapper is returned.

    Fix: wrap with functools.wraps so the exported op keeps its original
    __name__/__doc__/signature metadata (the bare wrappers previously exposed
    'wrapper_dump'/'wrapper_nothing' as the function name).
    """
    import functools  # local import: keeps the module's top-level imports untouched

    @functools.wraps(func)
    def wrapper_dump(*args, **kwargs):
        dump_pt_case(func, *args, **kwargs)
        return func(*args, **kwargs)

    @functools.wraps(func)
    def wrapper_nothing(*args, **kwargs):
        return func(*args, **kwargs)

    selected = GEN_CASE_OP_NAMES == [] or func.__name__ in GEN_CASE_OP_NAMES
    return wrapper_dump if (ENABLE_GEN_CASE and selected) else wrapper_nothing
|
||||
|
||||
def add_gen_case_decorator(frame):
    """Rebind selected public functions of `frame`'s module to dump_case wrappers.

    No-op unless ENABLE_GEN_CASE is set. Targets are GEN_CASE_OP_NAMES, or the
    module's entire __all__ when no explicit op names were configured.
    """
    if not ENABLE_GEN_CASE:
        return
    module = inspect.getmodule(frame)
    members = dict(inspect.getmembers(module))
    # Empty GEN_CASE_OP_NAMES means "wrap everything the module exports".
    target_names = GEN_CASE_OP_NAMES or members['__all__']
    for name in target_names:
        setattr(module, name, dump_case(members[name]))
|
||||
|
||||
# Supported torch dtype -> canonical dtype string used by the MLU kernels.
# torch.half and torch.float16 are the same object; both spellings are listed
# for readability.
dtype2str_map = {torch.half: "half",
                 torch.float16: "half",
                 torch.float: "float",
                 torch.bfloat16: "bfloat16",
                 torch.int32: "int32",
                 torch.int8: "int8"
                 }


def torchDtype2Str(torch_dtype: torch.dtype) -> str:
    """Map a torch dtype to its MLU kernel dtype string.

    Raises:
        ValueError: if the dtype has no mapping. (Previously an `assert`,
            which is silently stripped under `python -O`.)
    """
    try:
        return dtype2str_map[torch_dtype]
    except KeyError:
        raise ValueError("unrecognized torch type: {}".format(torch_dtype)) from None
|
||||
|
||||
|
||||
def get_custom_op_library_path():
    """Locate the compiled `_C` extension module sitting next to this file.

    Returns:
        str: filesystem path (spec.origin) of the extension.
    """
    loader_details = (importlib.machinery.ExtensionFileLoader,
                      importlib.machinery.EXTENSION_SUFFIXES)
    finder = importlib.machinery.FileFinder(os.path.dirname(__file__), loader_details)
    ext_specs = finder.find_spec("_C")
    assert ext_specs is not None, "Could not find custom op library"
    return ext_specs.origin
|
||||
627
torch_mlu_ops-v1.3.2/torch_mlu_ops/abstract.py
Normal file
627
torch_mlu_ops-v1.3.2/torch_mlu_ops/abstract.py
Normal file
@@ -0,0 +1,627 @@
|
||||
from typing import Tuple, List, Dict, Optional
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch._custom_ops
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::attention_project")
def attention_project_abstract(
    input: Tensor,
    q_weight: Tensor,
    q_bias: Optional[Tensor],
    k_weight: Optional[Tensor],
    k_bias: Optional[Tensor],
    v_weight: Optional[Tensor],
    v_bias: Optional[Tensor],
    norm_weight: Optional[Tensor],
    norm_bias: Optional[Tensor],
    residual: Optional[Tensor],
    out_layout: str,
    head_size: int,
    eps: float,
    alpha: float,
    beta: float,
    norm_out: bool,
) -> List[Tensor]:
    """Fake (meta) kernel for attention_project: allocates the output tensors
    the real op would produce, computing nothing.

    A 2-D input is treated as a single batch. Q is always produced; K/V only
    when their weights are given; the pre/post-norm tensor only when
    `norm_out`. Layout "nhtc" yields (n, head_num, t, head_size) outputs,
    anything else (n, t, hidden).

    Fix: the original unconditionally allocated (n, t, hidden) tensors and
    then re-allocated them in the "nhtc" branch, discarding the first set;
    now each output is allocated exactly once.
    """
    input_view = input
    if input.dim() == 2:
        input_view = input.unsqueeze(0)
    n, t = input_view.size(0), input_view.size(1)

    hidden_size_q = q_weight.size(0)
    hidden_size_k = k_weight.size(0) if k_weight is not None else 0
    hidden_size_v = v_weight.size(0) if v_weight is not None else 0

    def _alloc(hidden_size):
        # One allocation per output, shaped according to the requested layout.
        if out_layout == "nhtc":
            shape = (n, hidden_size // head_size, t, head_size)
        else:
            shape = (n, t, hidden_size)
        return torch.empty(*shape, dtype=input_view.dtype, device=input_view.device)

    res = [_alloc(hidden_size_q)]
    if k_weight is not None:
        # Preserve the original degenerate case: a size-0 weight yields None.
        res.append(_alloc(hidden_size_k) if hidden_size_k > 0 else None)
    if v_weight is not None:
        res.append(_alloc(hidden_size_v) if hidden_size_v > 0 else None)
    if norm_out:
        res.append(torch.empty_like(input_view))
    return res
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::ffn")
def ffn_abstract(
    input: Tensor,
    up_fc_weight: Tensor, up_fc_bias: Optional[Tensor],
    down_proj_weight: Tensor, down_proj_bias: Optional[Tensor],
    gate_up_proj_weight: Optional[Tensor], gate_up_proj_bias: Optional[Tensor],
    layernorm_weight: Optional[Tensor], layernorm_bias: Optional[Tensor],
    act_mode: str, residual_is: str,
    eps: float, alpha: float, beta: float,
) -> Tensor:
    """Meta stub for ffn: the output mirrors the input's shape/dtype/device."""
    return torch.empty_like(input)
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::flash_attention")
def flash_attention_abstract(
    q: Tensor, k: Tensor, v: Tensor, out: Tensor,
    output_lse: Optional[Tensor],
    cu_seq_lens_q: Optional[Tensor], cu_seq_lens_kv: Optional[Tensor],
    alibi_slope: Optional[Tensor], attn_bias: Optional[Tensor],
    k_quant_scale: Optional[Tensor], v_quant_scale: Optional[Tensor],
    block_tables: Optional[Tensor],
    max_seq_len_q: int, max_seq_len_kv: int,
    softmax_scale: float, is_causal: bool,
    window_size_left: int, window_size_right: int,
    compute_dtype: str, return_lse: bool,
) -> None:
    """Meta stub: the op takes preallocated outputs (`out`, `output_lse`), so
    there is nothing to allocate here."""
    return None
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::single_query_cached_kv_attn")
def single_query_cached_kv_attn_abstract(
    q_ori: Tensor, k_cache: Tensor, v_cache: Tensor, output: Tensor,
    block_tables: Tensor, context_lens: Tensor,
    output_lse: Optional[Tensor],
    k_cache_quant_scale: Optional[Tensor], v_cache_quant_scale: Optional[Tensor],
    alibi_slopes: Optional[Tensor],
    max_contxt_len: int,
    windows_size_left: int, windows_size_right: int,
    softmax_scale: float, return_lse: bool,
    kv_cache_quant_bit_size: int
) -> None:
    """Meta stub: results go into the preallocated `output` (and `output_lse`)."""
    return None
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::apply_rotary")
def apply_rotary_abstract(
    input: Tensor,
    sin_cache: Tensor, cos_cache: Tensor,
    position_ids: Optional[Tensor], cu_seqlens: Optional[Tensor],
    interleaved: bool, discrete: bool, dynamic_ntk: bool,
    max_seqlen: int,
) -> None:
    """Meta stub: rotary embedding is applied in place, nothing to allocate."""
    return None
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::reshape_linear_cache")
def reshape_linear_cache_abstract(
    key: Tensor, value: Optional[Tensor],
    key_cache: Tensor, value_cache: Optional[Tensor],
    context_lengths: Tensor, max_context_len: int, packed: bool,
    context_seq_offset: Optional[Tensor],
    cache_bs_id: Optional[Tensor], cache_seqlen_offset: Optional[Tensor],
) -> None:
    """Meta stub: writes into the caller-provided caches, nothing to allocate."""
    return None
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::reshape_paged_cache")
def reshape_paged_cache_abstract(
    k: Tensor, v: Optional[Tensor],
    k_cache: Tensor, v_cache: Optional[Tensor],
    slot_mapping: Tensor
) -> None:
    """Meta stub: writes into the caller-provided paged caches."""
    return None
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_to_paged_cache")
def quant_to_paged_cache_abstract(
    k: Tensor, v: Optional[Tensor],
    k_cache: Tensor, v_cache: Optional[Tensor],
    k_cache_scale: Tensor, v_cache_scale: Optional[Tensor],
    slot_mapping: Tensor,
) -> None:
    """Meta stub: quantized values land in the caller-provided caches/scales."""
    return None
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::offline_quant_to_paged_cache")
def offline_quant_to_paged_cache_abstract(
    k: Tensor, v: Optional[Tensor],
    k_cache_scale: Tensor, v_cache_scale: Optional[Tensor],
    slot_mapping: Tensor,
    k_cache: Tensor, v_cache: Optional[Tensor],
) -> None:
    """Meta stub: offline-quantized values land in the caller-provided caches."""
    return None
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_to_linear_cache")
def quant_to_linear_cache_abstract(
    key: Tensor, value: Optional[Tensor],
    key_cache: Tensor, value_cache: Optional[Tensor],
    key_cache_scale: Tensor, value_cache_scale: Optional[Tensor],
    context_lengths: Tensor, max_context_len: int, packed: bool,
    context_seq_offset: Optional[Tensor],
    cache_bs_id: Optional[Tensor], cache_seqlen_offset: Optional[Tensor],
    quant_bit: int = 8,
) -> None:
    """Meta stub: quantized values land in the caller-provided linear caches."""
    return None
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::offline_quant_to_linear_cache")
def offline_quant_to_linear_cache_abstract(
    key: Tensor, value: Optional[Tensor],
    key_cache: Tensor, value_cache: Optional[Tensor],
    key_cache_scale: Tensor, value_cache_scale: Optional[Tensor],
    context_lengths: Tensor, max_context_len: int,
    quant_mode: int, packed: bool,
    context_seq_offset: Optional[Tensor],
    cache_bs_id: Optional[Tensor], cache_seqlen_offset: Optional[Tensor],
) -> None:
    """Meta stub: offline-quantized values land in the caller-provided caches."""
    return None
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::swap_blocks")
def swap_blocks_abstract(
    dst: Tensor, src: Tensor, block_mapping: Dict[int, int]
) -> None:
    """Meta stub: block swapping mutates `dst` in place."""
    return None
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::copy_blocks")
def copy_blocks_abstract(
    k_caches: List[Tensor], v_caches: List[Tensor], block_mapping: Dict[int, List[int]]
) -> None:
    """Meta stub: block copies mutate the given caches in place."""
    return None
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::copy_blocks_out_of_place")
def copy_blocks_out_of_place_abstract(
    k_caches: List[Tensor], v_caches: List[Tensor],
    block_mapping: Dict[int, List[int]],
) -> Tuple[List[Tensor], List[Tensor]]:
    """Meta stub: allocate fresh k/v cache lists mirroring the inputs.

    Fix: the return annotation was written as `(List[Tensor], List[Tensor])`,
    a runtime tuple of typing objects rather than a valid type annotation;
    `Tuple[...]` is the correct spelling.
    """
    return ([torch.empty_like(k) for k in k_caches],
            [torch.empty_like(v) for v in v_caches])
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_matmul")
def quant_matmul_abstract(
    a_tensor: Tensor, a_scale: Optional[Tensor], a_zero: Optional[Tensor],
    b_tensor: Tensor, b_scale: Optional[Tensor], b_zero: Optional[Tensor],
    bias: Optional[Tensor],
    c_tensor: Optional[Tensor], c_scale: Optional[Tensor], c_zero: Optional[Tensor],
    gemm_output_scale: Optional[Tensor], gemm_output_zero: Optional[Tensor],
    data_type: Optional[str],
    d: Optional[Tensor],
    quant_algo: str, a_quant_layout: str, b_quant_layout: str,
    quant_bit_size: int = 8,
    act_mode: str = "none",
    use_hp_active: bool = False,
    act_coef: float = 1.0,
    alpha: float = 1.0,
    beta: float = 1.0,
    trans_a: bool = False,
    trans_b: bool = True,
) -> Tensor:
    """Meta stub for quant_matmul: (M, N) output where M = a.size(0),
    N = b.size(0); dtype follows `data_type` ("float"/"bfloat16", default
    float16, or `a_tensor`'s dtype when None)."""
    if data_type is None:
        out_dtype = a_tensor.dtype
    else:
        out_dtype = {"float": torch.float32,
                     "bfloat16": torch.bfloat16}.get(data_type, torch.float16)
    return torch.empty(a_tensor.size(0), b_tensor.size(0),
                       dtype=out_dtype, device=a_tensor.device)
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_matmul_allreduce")
def quant_matmul_allreduce_abstract(
    cncl_comm,
    a_tensor: torch.Tensor, a_scale: Optional[torch.Tensor], a_zero: Optional[torch.Tensor],
    b_tensor: torch.Tensor, b_scale: Optional[torch.Tensor], b_zero: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    c_tensor: Optional[torch.Tensor], c_scale: Optional[torch.Tensor], c_zero: Optional[torch.Tensor],
    gemm_output_scale: Optional[torch.Tensor], gemm_output_zero: Optional[torch.Tensor],
    data_type: Optional[str],
    d: Optional[torch.Tensor],
    quant_algo: str, a_quant_layout: str, b_quant_layout: str,
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0,
    trans_a: bool = False,
    trans_b: bool = True,
    block_m: int = 0
) -> torch.Tensor:
    """Meta stub for quant_matmul_allreduce: (M, N) output; dtype follows
    `data_type` ("float"/"bfloat16"), defaulting to float16 (also for None)."""
    out_dtype = {"float": torch.float32,
                 "bfloat16": torch.bfloat16}.get(data_type, torch.float16)
    return torch.empty(a_tensor.size(0), b_tensor.size(0),
                       dtype=out_dtype, device=a_tensor.device)
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::active")
def active_abstract(input: Tensor,
                    output: Tensor,
                    bias: Optional[Tensor],
                    cusum_token_count: Optional[Tensor],
                    act_mode: str,
                    is_gated: bool,
                    start_expert_id: int = 0,
                    expert_size: int = 0,
                    active_coef: float = 1.0) -> None:
    """Meta stub: activation results go into the preallocated `output`."""
    return None
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::smooth_quant")
def smooth_quant_abstract(
    input: Tensor, input_scale: Tensor,
    output: Tensor, output_scale: Tensor,
    input_zero: Optional[Tensor],
    token_count: Optional[Tensor],
    gather_index: Optional[Tensor],
    gather_index_start_position: Optional[Tensor],
    quant_mode: str,
    dynamic_quant: bool
) -> None:
    """Meta stub: quantized result goes into `output`/`output_scale`."""
    return None
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_layernorm")
def fused_layernorm_abstract(
    input: Tensor, output: Tensor,
    residual: Optional[Tensor],
    beta: Optional[Tensor], gamma: Optional[Tensor],
    bias: Optional[Tensor],
    quant_scale: Optional[Tensor],
    residual_out: Optional[Tensor],
    smooth_quant_scale: Optional[Tensor],
    norm_mode: str,
    eps: float,
    store_output_before_norm: bool,
    dynamic_quant: bool,
) -> None:
    """Meta stub: normalized result goes into the preallocated `output`."""
    return None
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_moe")
def fused_moe_abstract(
    hidden_states: Tensor, gating_output: Tensor,
    w1: Tensor, w2: Tensor,
    bias1: Optional[Tensor], bias2: Optional[Tensor],
    residual: Optional[Tensor],
    input_smooth: Optional[Tensor], act_smooth: Optional[Tensor],
    w1_scale: Optional[Tensor], w2_scale: Optional[Tensor],
    topk: int, renormalize: bool, gated: bool,
    act_mode: str,
    start_expert_id: int, block_n: int, cncl_comm: int,
    w1_quant_flag: Optional[List], w2_quant_flag: Optional[List]
) -> Tensor:
    """Meta stub for fused_moe: output mirrors hidden_states' shape/dtype/device."""
    return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::matmul")
def matmul_abstract(
    a: Tensor, b: Tensor,
    d: Optional[Tensor], bias: Optional[Tensor], c: Optional[Tensor],
    data_type: Optional[str],
    act_mode: str,
    alpha: float, beta: float,
    fast_act: bool, approximate: bool,
    a_scale: float, b_scale: float,
    trans_a: bool, trans_b: bool
) -> Tensor:
    """Meta stub for matmul: (M, N) output honoring the transpose flags;
    dtype follows `data_type` ("float"/"bfloat16", default half, or `a`'s
    dtype when None)."""
    rows = a.size(1) if trans_a else a.size(0)
    cols = b.size(0) if trans_b else b.size(1)
    if data_type is None:
        out_dtype = a.dtype
    else:
        out_dtype = {"float": torch.float32,
                     "bfloat16": torch.bfloat16}.get(data_type, torch.half)
    return torch.empty(rows, cols, dtype=out_dtype, device=a.device)
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::batch_matmul")
def batch_matmul_abstract(
    a: Tensor, b: Tensor, c: Tensor,
    alpha: float, beta: float,
    a_scale: float, b_scale: float,
    trans_a: bool, trans_b: bool
) -> None:
    """Meta stub: the batched product is accumulated into the provided `c`."""
    return None
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::matmul_allreduce")
def matmul_allreduce_abstract(
    cncl_comm,
    a: torch.Tensor, b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    d: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = .0,
    block_m: int = 0
) -> Tensor:
    """Meta stub: (a.size(0), b.size(0)) output in `a`'s dtype/device."""
    return torch.empty(a.size(0), b.size(0), dtype=a.dtype, device=a.device)
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::group_gemm")
def group_gemm_abstract(
    a: Tensor, b: Tensor,
    m_list: Tensor,
    expand_idx: Optional[Tensor],
    c: Optional[Tensor],
    alpha: Optional[Tensor], beta: Optional[Tensor],
    a_scale: Optional[Tensor], b_scale: Optional[Tensor],
    bias: Optional[Tensor],
    data_type: Optional[str],
    quant_flag: Optional[List],
    b_offset: Optional[Tensor],
    max_m: int
) -> Tensor:
    """Meta stub for group_gemm: total row count comes from `expand_idx` when
    given, else from `a`; dtype follows `data_type` ("float"/"bfloat16",
    default half, or `a`'s dtype when None)."""
    if data_type is None:
        out_dtype = a.dtype
    else:
        out_dtype = {"float": torch.float32,
                     "bfloat16": torch.bfloat16}.get(data_type, torch.half)
    total_m = a.size(0) if expand_idx is None else expand_idx.size(0)
    return torch.empty(total_m, b.size(1), dtype=out_dtype, device=a.device)
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::preload")
def preload_abstract(
    weight: Tensor,
    size: int
) -> None:
    """Meta stub: preload has no tensor output."""
    return None
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::flash_attn_sq_mm_allreduce")
def flash_attn_sq_mm_allreduce_abstract(cncl_comm,
                                        q, k, v,
                                        cu_seq_lens_q, cu_seq_lens_k,
                                        alibi_slope, attn_bias,
                                        smooth, weight, weight_scale, bias,
                                        max_seq_len_q, max_seq_len_kv,
                                        softmax_scale, is_causal,
                                        window_size_left, window_size_right,
                                        compute_dtype,
                                        block_seq) -> torch.Tensor:
    """Meta stub: output rows follow q flattened to 2-D (a leading batch dim
    is added for the varlen/cu_seq_lens path); columns follow weight.size(0)."""
    batched_q = q.unsqueeze(0) if cu_seq_lens_q is not None else q
    flat_q = batched_q.flatten(-2, -1).flatten(0, 1)
    return torch.empty(flat_q.size(0), weight.size(0), dtype=q.dtype, device=q.device)
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_softmax_topk")
def moe_softmax_topk_abstract(input,
                              topk,
                              num_expert_group,
                              topk_group,
                              normalize,
                              mask: Optional[torch.Tensor] = None,
                              normed_by: str = "topk_logit") -> Tuple[torch.Tensor, torch.Tensor]:
    """Meta stub: per-token top-k weights (float32) and expert ids (int32),
    both shaped input.shape[:-1] + [topk]."""
    out_shape = [*input.size()[:-1], topk]
    reduce_weight = torch.empty(out_shape, dtype=torch.float32, device=input.device)
    expert_id = torch.empty(out_shape, dtype=torch.int, device=input.device)
    return (reduce_weight, expert_id)
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_expand_input")
def moe_expand_input_abstract(input: Tensor,
                              gather_idx: Tensor,
                              cusum_token_count: Optional[Tensor] = None,
                              start_expert_id: int = 0,
                              expert_size: int = 0) -> Tensor:
    """Meta stub: one expanded row per gather index, hidden dim from `input`."""
    expanded_rows = gather_idx.size(0)
    hidden = input.size(-1)
    return torch.empty(expanded_rows, hidden, dtype=input.dtype, device=input.device)
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_gen_idx")
def moe_gen_idx_abstract(expert_id: Tensor,
                         expert_num: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Meta stub: int32 routing indices — expand/combine indices sized
    token_num*topk, plus per-expert counts and their exclusive cumsum."""
    token_num, topk = expert_id.size(0), expert_id.size(1)
    flat = token_num * topk
    device = expert_id.device
    expand_idx = torch.empty((flat,), dtype=torch.int32, device=device)
    combine_idx = torch.empty((flat,), dtype=torch.int32, device=device)
    token_count = torch.empty((expert_num,), dtype=torch.int32, device=device)
    cusum_token_count = torch.empty((expert_num + 1,), dtype=torch.int32, device=device)
    return (expand_idx, combine_idx, token_count, cusum_token_count)
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_combine_result")
def moe_combine_result_abstract(input: torch.Tensor,
                                reduce_weight: torch.Tensor,
                                gather_ids: torch.Tensor,
                                residual: Optional[torch.Tensor],
                                cusum_token_count: Optional[torch.Tensor],
                                start_expert_id: int,
                                expert_size: int,
                                bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Meta stub: combined output has expanded_rows // topk rows and the same
    hidden size/dtype/device as `input`."""
    expanded_rows, hidden_size = input.size(0), input.size(1)
    topk = reduce_weight.size(1)
    return torch.empty(expanded_rows // topk, hidden_size,
                       dtype=input.dtype, device=input.device)
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_rope")
def fused_rope_abstract(qkv: torch.Tensor,
                        key_cache_hp: torch.Tensor,
                        value_cache_hp: torch.Tensor,
                        key_cache_lp: Optional[torch.Tensor],
                        value_cache_lp: Optional[torch.Tensor],
                        sin_table: torch.Tensor,
                        cos_table: torch.Tensor,
                        position_ids: torch.Tensor,
                        gamma: torch.Tensor,
                        beta: torch.Tensor,
                        key_scale_hp: Optional[torch.Tensor],
                        value_scale_hp: Optional[torch.Tensor],
                        key_scale_lp: Optional[torch.Tensor],
                        value_scale_lp: Optional[torch.Tensor],
                        cache_bs_id_hp: Optional[torch.Tensor],
                        cache_seq_offsets_hp: Optional[torch.Tensor],
                        cache_bs_id_lp: Optional[torch.Tensor],
                        cache_seq_offsets_lp: Optional[torch.Tensor],
                        slot_mapping_hp: Optional[torch.Tensor],
                        slot_mapping_lp: Optional[torch.Tensor],
                        eps: float) -> None:
    """Meta stub: fused RoPE updates qkv and the provided caches in place."""
    return None
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_cast_gating")
def moe_cast_gating_abstract(input: torch.Tensor,
                             weight: torch.Tensor) -> Tensor:
    """Meta stub for moe_cast_gating: float32 gating logits shaped
    input.shape[:-1] + (weight.shape[0],).

    Fix: the device now follows `input` — the hard-coded device="mlu" was
    inconsistent with every other abstract impl in this file (all of which
    derive the device from their inputs) and broke tracing off-device.
    """
    output_shape = input.shape[:-1] + (weight.shape[0],)
    return torch.empty(output_shape, dtype=torch.float, device=input.device)
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::update_out_and_lse")
def update_out_and_lse_abstract(out: torch.Tensor,
                                lse: torch.Tensor,
                                block_out: torch.Tensor,
                                block_lse: torch.Tensor,
                                seq_offsets: Optional[torch.Tensor] = None,
                                cu_seqs: Optional[torch.Tensor] = None,
                                block_cu_seqs: Optional[torch.Tensor] = None) -> None:
    """Meta stub: accumulates into `out`/`lse` in place."""
    return None
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::dequant_from_linear_cache")
def dequant_from_linear_cache_abstract(key: torch.Tensor,
                                       value: Optional[torch.Tensor],
                                       key_cache: torch.Tensor,
                                       value_cache: Optional[torch.Tensor],
                                       key_cache_quant_scale: torch.Tensor,
                                       value_cache_quant_scale: Optional[torch.Tensor],
                                       context_lengths: torch.Tensor,
                                       max_context_len: int,
                                       context_seq_offset: Optional[torch.Tensor],
                                       cache_bs_id: Optional[torch.Tensor],
                                       cache_seq_offset: Optional[torch.Tensor],
                                       quant_mode: int = 0,
                                       quant_bit: int = 8) -> None:
    """Meta stub: dequantized values go into the provided key/value tensors."""
    return None
|
||||
|
||||
@torch._custom_ops.impl_abstract("torch_mlu_ops::dequant_from_paged_cache")
def dequant_from_paged_cache_abstract(key: torch.Tensor,
                                      value: Optional[torch.Tensor],
                                      key_cache: torch.Tensor,
                                      value_cache: Optional[torch.Tensor],
                                      key_cache_quant_scale: torch.Tensor,
                                      value_cache_quant_scale: Optional[torch.Tensor],
                                      context_lengths: torch.Tensor,
                                      max_context_len: int,
                                      context_seq_offset: Optional[torch.Tensor],
                                      block_tables: torch.Tensor,
                                      quant_mode: int = 0,
                                      quant_bit: int = 8) -> None:
    """Meta stub: dequantized values go into the provided key/value tensors."""
    return None
|
||||
Reference in New Issue
Block a user