# Source listing header (converted to comments — the raw repo-UI residue was not
# valid Python):
#   enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/torch_mlu_ops/abstract.py
#   2026-02-04 17:39:32 +08:00
#   628 lines, 22 KiB, Python
from typing import Tuple, List, Dict, Optional
import torch
from torch import Tensor
import torch._custom_ops
@torch._custom_ops.impl_abstract("torch_mlu_ops::attention_project")
def attention_project_abstract(
input: Tensor,
q_weight: Tensor,
q_bias: Optional[Tensor],
k_weight: Optional[Tensor],
k_bias: Optional[Tensor],
v_weight: Optional[Tensor],
v_bias: Optional[Tensor],
norm_weight: Optional[Tensor],
norm_bias: Optional[Tensor],
residual: Optional[Tensor],
out_layout: str,
head_size: int,
eps: float,
alpha: float,
beta: float,
norm_out: bool,
) -> List[Tensor]:
input_view = input
if input.dim() == 2:
input_view = input.unsqueeze(0)
n = input_view.size(0)
t = input_view.size(1)
hidden_size_q = q_weight.size(0)
hidden_size_k = k_weight.size(0) if k_weight is not None else 0
hidden_size_v = v_weight.size(0) if v_weight is not None else 0
head_num_q = hidden_size_q // head_size
head_num_k = hidden_size_k // head_size
head_num_v = hidden_size_v // head_size
out_q = torch.empty(n, t, hidden_size_q, dtype=input_view.dtype, device=input_view.device)
out_k = torch.empty(n, t, hidden_size_k, dtype=input_view.dtype, device=input_view.device) if hidden_size_k > 0 else None
out_v = torch.empty(n, t, hidden_size_v, dtype=input_view.dtype, device=input_view.device) if hidden_size_v > 0 else None
if out_layout == "nhtc":
out_q = torch.empty(n, head_num_q, t, head_size, dtype=input_view.dtype, device=input_view.device)
out_k = torch.empty(n, head_num_k, t, head_size, dtype=input_view.dtype, device=input_view.device) if hidden_size_k > 0 else None
out_v = torch.empty(n, head_num_v, t, head_size, dtype=input_view.dtype, device=input_view.device) if hidden_size_v > 0 else None
out_ln = torch.empty_like(input_view) if norm_out else None
res = [out_q]
if k_weight is not None:
res.append(out_k)
if v_weight is not None:
res.append(out_v)
if norm_out:
res.append(out_ln)
return res
@torch._custom_ops.impl_abstract("torch_mlu_ops::ffn")
def ffn_abstract(
input: Tensor,
up_fc_weight: Tensor,
up_fc_bias: Optional[Tensor],
down_proj_weight: Tensor,
down_proj_bias: Optional[Tensor],
gate_up_proj_weight: Optional[Tensor],
gate_up_proj_bias: Optional[Tensor],
layernorm_weight: Optional[Tensor],
layernorm_bias: Optional[Tensor],
act_mode: str,
residual_is: str,
eps: float,
alpha: float,
beta: float,
) -> Tensor:
return torch.empty_like(input)
@torch._custom_ops.impl_abstract("torch_mlu_ops::flash_attention")
def flash_attention_abstract(
q: Tensor,
k: Tensor,
v: Tensor,
out: Tensor,
output_lse: Optional[Tensor],
cu_seq_lens_q: Optional[Tensor],
cu_seq_lens_kv: Optional[Tensor],
alibi_slope: Optional[Tensor],
attn_bias: Optional[Tensor],
k_quant_scale: Optional[Tensor],
v_quant_scale: Optional[Tensor],
block_tables: Optional[Tensor],
max_seq_len_q: int,
max_seq_len_kv: int,
softmax_scale: float,
is_causal: bool,
window_size_left: int,
window_size_right: int,
compute_dtype: str,
return_lse: bool,
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::single_query_cached_kv_attn")
def single_query_cached_kv_attn_abstract(
q_ori: Tensor,
k_cache: Tensor,
v_cache: Tensor,
output: Tensor,
block_tables: Tensor,
context_lens: Tensor,
output_lse: Optional[Tensor],
k_cache_quant_scale: Optional[Tensor],
v_cache_quant_scale: Optional[Tensor],
alibi_slopes: Optional[Tensor],
max_contxt_len: int,
windows_size_left: int,
windows_size_right: int,
softmax_scale: float,
return_lse: bool,
kv_cache_quant_bit_size: int
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::apply_rotary")
def apply_rotary_abstract(
input: Tensor,
sin_cache: Tensor,
cos_cache: Tensor,
position_ids: Optional[Tensor],
cu_seqlens: Optional[Tensor],
interleaved: bool,
discrete: bool,
dynamic_ntk: bool,
max_seqlen: int,
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::reshape_linear_cache")
def reshape_linear_cache_abstract(
key: Tensor,
value: Optional[Tensor],
key_cache: Tensor,
value_cache: Optional[Tensor],
context_lengths: Tensor,
max_context_len: int,
packed: bool,
context_seq_offset: Optional[Tensor],
cache_bs_id: Optional[Tensor],
cache_seqlen_offset: Optional[Tensor],
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::reshape_paged_cache")
def reshape_paged_cache_abstract(
k: Tensor,
v: Optional[Tensor],
k_cache: Tensor,
v_cache: Optional[Tensor],
slot_mapping: Tensor
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_to_paged_cache")
def quant_to_paged_cache_abstract(
k: Tensor,
v: Optional[Tensor],
k_cache: Tensor,
v_cache: Optional[Tensor],
k_cache_scale: Tensor,
v_cache_scale: Optional[Tensor],
slot_mapping: Tensor,
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::offline_quant_to_paged_cache")
def offline_quant_to_paged_cache_abstract(
k: Tensor,
v: Optional[Tensor],
k_cache_scale: Tensor,
v_cache_scale: Optional[Tensor],
slot_mapping: Tensor,
k_cache: Tensor,
v_cache: Optional[Tensor],
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_to_linear_cache")
def quant_to_linear_cache_abstract(
key: Tensor,
value: Optional[Tensor],
key_cache: Tensor,
value_cache: Optional[Tensor],
key_cache_scale: Tensor,
value_cache_scale: Optional[Tensor],
context_lengths: Tensor,
max_context_len: int,
packed: bool,
context_seq_offset: Optional[Tensor],
cache_bs_id: Optional[Tensor],
cache_seqlen_offset: Optional[Tensor],
quant_bit: int = 8,
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::offline_quant_to_linear_cache")
def offline_quant_to_linear_cache_abstract(
key: Tensor,
value: Optional[Tensor],
key_cache: Tensor,
value_cache: Optional[Tensor],
key_cache_scale: Tensor,
value_cache_scale: Optional[Tensor],
context_lengths: Tensor,
max_context_len: int,
quant_mode: int,
packed: bool,
context_seq_offset: Optional[Tensor],
cache_bs_id: Optional[Tensor],
cache_seqlen_offset: Optional[Tensor],
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::swap_blocks")
def swap_blocks_abstract(
dst: Tensor, src: Tensor, block_mapping: Dict[int, int]
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::copy_blocks")
def copy_blocks_abstract(
k_caches: List[Tensor], v_caches: List[Tensor], block_mapping: Dict[int, List[int]]
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::copy_blocks_out_of_place")
def copy_blocks_out_of_place_abstract(k_caches: List[Tensor], v_caches: List[Tensor], block_mapping: Dict[int, List[int]]) -> (List[Tensor], List[Tensor]):
return ([torch.empty_like(k) for k in k_caches], [torch.empty_like(v) for v in v_caches])
@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_matmul")
def quant_matmul_abstract(
a_tensor: Tensor,
a_scale: Optional[Tensor],
a_zero: Optional[Tensor],
b_tensor: Tensor,
b_scale: Optional[Tensor],
b_zero: Optional[Tensor],
bias: Optional[Tensor],
c_tensor: Optional[Tensor],
c_scale: Optional[Tensor],
c_zero: Optional[Tensor],
gemm_output_scale: Optional[Tensor],
gemm_output_zero: Optional[Tensor],
data_type: Optional[str],
d: Optional[Tensor],
quant_algo: str,
a_quant_layout: str,
b_quant_layout: str,
quant_bit_size: int = 8,
act_mode: str = "none",
use_hp_active: bool = False,
act_coef: float = 1.0,
alpha: float = 1.0,
beta: float = 1.0,
trans_a: bool = False,
trans_b: bool = True,
) -> Tensor:
if data_type is None:
output_type = a_tensor.dtype
elif data_type == "float":
output_type = torch.float32
elif data_type == "bfloat16":
output_type = torch.bfloat16
else:
output_type = torch.float16
return torch.empty(a_tensor.size(0), b_tensor.size(0), dtype=output_type, device=a_tensor.device)
@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_matmul_allreduce")
def quant_matmul_allreduce_abstract(
cncl_comm,
a_tensor: torch.Tensor,
a_scale: Optional[torch.Tensor],
a_zero: Optional[torch.Tensor],
b_tensor: torch.Tensor,
b_scale: Optional[torch.Tensor],
b_zero: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
c_tensor: Optional[torch.Tensor],
c_scale: Optional[torch.Tensor],
c_zero: Optional[torch.Tensor],
gemm_output_scale: Optional[torch.Tensor],
gemm_output_zero: Optional[torch.Tensor],
data_type: Optional[str],
d: Optional[torch.Tensor],
quant_algo: str,
a_quant_layout: str,
b_quant_layout: str,
quant_bit_size: int = 8,
alpha: float = 1.0,
beta: float = 1.0,
trans_a: bool = False,
trans_b: bool = True,
block_m: int = 0
) -> torch.Tensor:
output_type = torch.float16
if data_type == "float":
output_type = torch.float32
elif data_type == "bfloat16":
output_type = torch.bfloat16
return torch.empty(a_tensor.size(0), b_tensor.size(0), dtype=output_type, device=a_tensor.device)
@torch._custom_ops.impl_abstract("torch_mlu_ops::active")
def active_abstract(input: Tensor,
output: Tensor,
bias: Optional[Tensor],
cusum_token_count: Optional[Tensor],
act_mode: str,
is_gated: bool,
start_expert_id: int = 0,
expert_size: int = 0,
active_coef: float = 1.0) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::smooth_quant")
def smooth_quant_abstract(
input: Tensor,
input_scale: Tensor,
output: Tensor,
output_scale: Tensor,
input_zero: Optional[Tensor],
token_count: Optional[Tensor],
gather_index: Optional[Tensor],
gather_index_start_position: Optional[Tensor],
quant_mode: str,
dynamic_quant: bool
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_layernorm")
def fused_layernorm_abstract(
input: Tensor,
output: Tensor,
residual: Optional[Tensor],
beta: Optional[Tensor],
gamma: Optional[Tensor],
bias: Optional[Tensor],
quant_scale: Optional[Tensor],
residual_out: Optional[Tensor],
smooth_quant_scale: Optional[Tensor],
norm_mode: str,
eps: float,
store_output_before_norm: bool,
dynamic_quant: bool,
)-> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_moe")
def fused_moe_abstract(
hidden_states: Tensor,
gating_output: Tensor,
w1: Tensor,
w2: Tensor,
bias1: Optional[Tensor],
bias2: Optional[Tensor],
residual: Optional[Tensor],
input_smooth: Optional[Tensor],
act_smooth: Optional[Tensor],
w1_scale: Optional[Tensor],
w2_scale: Optional[Tensor],
topk: int,
renormalize: bool,
gated: bool,
act_mode: str,
start_expert_id: int,
block_n: int,
cncl_comm: int,
w1_quant_flag: Optional[List],
w2_quant_flag: Optional[List]
) -> Tensor:
return torch.empty_like(hidden_states)
@torch._custom_ops.impl_abstract("torch_mlu_ops::matmul")
def matmul_abstract(
a: Tensor,
b: Tensor,
d: Optional[Tensor],
bias: Optional[Tensor],
c: Optional[Tensor],
data_type: Optional[str],
act_mode: str,
alpha: float,
beta: float,
fast_act: bool,
approximate: bool,
a_scale: float,
b_scale: float,
trans_a: bool,
trans_b: bool
) -> Tensor:
m = a.size(1) if trans_a else a.size(0)
n = b.size(0) if trans_b else b.size(1)
if data_type is None:
output_type = a.dtype
elif data_type == "float":
output_type = torch.float32
elif data_type == "bfloat16":
output_type = torch.bfloat16
else:
output_type = torch.half
return torch.empty(m, n, dtype=output_type, device=a.device)
@torch._custom_ops.impl_abstract("torch_mlu_ops::batch_matmul")
def batch_matmul_abstract(
a: Tensor,
b: Tensor,
c: Tensor,
alpha: float,
beta: float,
a_scale: float,
b_scale: float,
trans_a: bool,
trans_b: bool
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::matmul_allreduce")
def matmul_allreduce_abstract(
cncl_comm,
a: torch.Tensor,
b: torch.Tensor,
bias: Optional[torch.Tensor] = None,
c: Optional[torch.Tensor] = None,
d: Optional[torch.Tensor] = None,
alpha: float = 1.0,
beta: float = .0,
block_m: int = 0
) -> Tensor:
return torch.empty(a.size(0), b.size(0), dtype=a.dtype, device=a.device)
@torch._custom_ops.impl_abstract("torch_mlu_ops::group_gemm")
def group_gemm_abstract(
a: Tensor,
b: Tensor,
m_list: Tensor,
expand_idx: Optional[Tensor],
c: Optional[Tensor],
alpha: Optional[Tensor],
beta: Optional[Tensor],
a_scale: Optional[Tensor],
b_scale: Optional[Tensor],
bias: Optional[Tensor],
data_type: Optional[str],
quant_flag: Optional[List],
b_offset: Optional[Tensor],
max_m: int
) -> Tensor:
if data_type is None:
output_type = a.dtype
elif data_type == "float":
output_type = torch.float32
elif data_type == "bfloat16":
output_type = torch.bfloat16
else:
output_type = torch.half
total_m = a.size(0) if expand_idx is None else expand_idx.size(0)
return torch.empty(total_m, b.size(1), dtype=output_type, device=a.device)
@torch._custom_ops.impl_abstract("torch_mlu_ops::preload")
def preload_abstract(
weight: Tensor,
size: int
) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::flash_attn_sq_mm_allreduce")
def flash_attn_sq_mm_allreduce_abstract(cncl_comm,
q,
k,
v,
cu_seq_lens_q,
cu_seq_lens_k,
alibi_slope,
attn_bias,
smooth,
weight,
weight_scale,
bias,
max_seq_len_q,
max_seq_len_kv,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
compute_dtype,
block_seq) -> torch.Tensor:
res_q = q.unsqueeze(0) if cu_seq_lens_q is not None else q
res_q = res_q.flatten(-2, -1).flatten(0, 1)
return torch.empty(res_q.size(0), weight.size(0), dtype=q.dtype, device=q.device)
@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_softmax_topk")
def moe_softmax_topk_abstract(input,
topk,
num_expert_group,
topk_group,
normalize,
mask: Optional[torch.Tensor] = None,
normed_by: str = "topk_logit") -> Tuple[torch.Tensor, torch.Tensor]:
out_shape = list(input.size())[:-1] + [topk]
reduce_weight = torch.empty(out_shape, dtype=torch.float32, device=input.device)
expert_id = torch.empty(out_shape, dtype=torch.int, device=input.device)
return (reduce_weight, expert_id)
@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_expand_input")
def moe_expand_input_abstract(input: Tensor,
gather_idx: Tensor,
cusum_token_count: Optional[Tensor] = None,
start_expert_id: int = 0,
expert_size: int = 0) -> Tensor:
return torch.empty(gather_idx.size(0), input.size(-1), dtype=input.dtype, device=input.device)
@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_gen_idx")
def moe_gen_idx_abstract(expert_id: Tensor,
expert_num: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
token_num, topk = expert_id.size(0), expert_id.size(1)
expand_idx = torch.empty((token_num * topk), dtype=torch.int32, device=expert_id.device)
combine_idx = torch.empty((token_num * topk), dtype=torch.int32, device=expert_id.device)
token_count = torch.empty((expert_num,), dtype=torch.int32, device=expert_id.device)
cusum_token_count = torch.empty((expert_num + 1,), dtype=torch.int32, device=expert_id.device)
return (expand_idx, combine_idx, token_count, cusum_token_count)
@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_combine_result")
def moe_combine_result_abstract(input: torch.Tensor,
reduce_weight: torch.Tensor,
gather_ids: torch.Tensor,
residual: Optional[torch.Tensor],
cusum_token_count: Optional[torch.Tensor],
start_expert_id: int,
expert_size: int,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
num_tokens, hidden_size, topk = input.size(0), input.size(1), reduce_weight.size(1)
num_token = num_tokens // topk
return torch.empty(num_token, hidden_size, dtype=input.dtype, device=input.device)
@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_rope")
def fused_rope_abstract(qkv: torch.Tensor,
key_cache_hp: torch.Tensor,
value_cache_hp: torch.Tensor,
key_cache_lp: Optional[torch.Tensor],
value_cache_lp: Optional[torch.Tensor],
sin_table: torch.Tensor,
cos_table: torch.Tensor,
position_ids: torch.Tensor,
gamma: torch.Tensor,
beta: torch.Tensor,
key_scale_hp: Optional[torch.Tensor],
value_scale_hp: Optional[torch.Tensor],
key_scale_lp: Optional[torch.Tensor],
value_scale_lp: Optional[torch.Tensor],
cache_bs_id_hp: Optional[torch.Tensor],
cache_seq_offsets_hp: Optional[torch.Tensor],
cache_bs_id_lp: Optional[torch.Tensor],
cache_seq_offsets_lp: Optional[torch.Tensor],
slot_mapping_hp: Optional[torch.Tensor],
slot_mapping_lp: Optional[torch.Tensor],
eps: float) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_cast_gating")
def moe_cast_gating_abstract(input: torch.Tensor,
weight: torch.Tensor) ->Tensor:
output_shape = input.shape[:-1] + (weight.shape[0],)
output = torch.empty(output_shape, dtype=torch.float, device="mlu")
return output
@torch._custom_ops.impl_abstract("torch_mlu_ops::update_out_and_lse")
def update_out_and_lse_abstract(out: torch.Tensor,
lse: torch.Tensor,
block_out: torch.Tensor,
block_lse: torch.Tensor,
seq_offsets: Optional[torch.Tensor] = None,
cu_seqs: Optional[torch.Tensor] = None,
block_cu_seqs: Optional[torch.Tensor] = None) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::dequant_from_linear_cache")
def dequant_from_linear_cache_abstract(key: torch.Tensor,
value: Optional[torch.Tensor],
key_cache: torch.Tensor,
value_cache: Optional[torch.Tensor],
key_cache_quant_scale: torch.Tensor,
value_cache_quant_scale: Optional[torch.Tensor],
context_lengths: torch.Tensor,
max_context_len: int,
context_seq_offset: Optional[torch.Tensor],
cache_bs_id: Optional[torch.Tensor],
cache_seq_offset: Optional[torch.Tensor],
quant_mode: int = 0,
quant_bit: int = 8) -> None:
return None
@torch._custom_ops.impl_abstract("torch_mlu_ops::dequant_from_paged_cache")
def dequant_from_paged_cache_abstract(key: torch.Tensor,
value: Optional[torch.Tensor],
key_cache: torch.Tensor,
value_cache: Optional[torch.Tensor],
key_cache_quant_scale: torch.Tensor,
value_cache_quant_scale: Optional[torch.Tensor],
context_lengths: torch.Tensor,
max_context_len: int,
context_seq_offset: Optional[torch.Tensor],
block_tables: torch.Tensor,
quant_mode: int = 0,
quant_bit: int = 8) -> None:
return None