Files
2026-02-04 17:22:39 +08:00

779 lines
27 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List, Optional, Tuple
import torch
import math
import triton
import triton.language as tl
from vllm.logger import init_logger
logger = init_logger(__name__)
# torch_mlu_ops (TMO) supplies the fused MLU kernels wrapped below. Import
# failure is logged but not fatal, so this module can still be imported on
# hosts without the MLU stack (calls into `tmo` will then fail at use time).
try:
    import torch_mlu_ops as tmo
except ImportError as e:
    logger.warning("Failed to import from TMO OPS with %r", e)
def rotary_embedding(
    input: torch.Tensor,
    sin_cache: torch.Tensor,
    cos_cache: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    cu_seqlens: Optional[torch.Tensor],
    interleaved: bool,
    discrete: bool,
    dynamic_ntk: bool,
    max_seqlen: int,
) -> torch.Tensor:
    """Apply rotary positional embedding via the TMO ``apply_rotary`` kernel.

    Thin pass-through; see the TMO kernel for exact argument semantics.
    """
    return tmo.apply_rotary(input, sin_cache, cos_cache, position_ids,
                            cu_seqlens, interleaved, discrete, dynamic_ntk,
                            max_seqlen)
def fused_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    store_output_before_norm: bool,
    quant_scale: torch.Tensor = None,
    dynamic_quant: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Fused RMS normalization on MLU via TMO.

    NOTE(review): the extra ``None`` positional argument mirrors the
    original call site — confirm against ``tmo.fused_rms_norm``'s signature.
    """
    return tmo.fused_rms_norm(x, residual, gamma, beta, bias, eps,
                              store_output_before_norm, quant_scale, None,
                              dynamic_quant)
def fused_layer_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    store_output_before_norm: bool,
    quant_scale: torch.Tensor = None,
    dynamic_quant: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Fused layer normalization on MLU via TMO.

    NOTE(review): the extra ``None`` positional argument mirrors the
    original call site — confirm against ``tmo.fused_layer_norm``'s signature.
    """
    return tmo.fused_layer_norm(x, residual, gamma, beta, bias, eps,
                                store_output_before_norm, quant_scale, None,
                                dynamic_quant)
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    cu_seq_lens_q: torch.Tensor,
    cu_seq_lens_kv: torch.Tensor,
    alibi_slope: torch.Tensor,
    attn_bias: torch.Tensor,
    max_seq_len_q: int,
    max_seq_len_kv: int,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    compute_dtype: torch.dtype = torch.float,
    return_lse: bool = False,
    block_tables: torch.Tensor = None,
    k_cache_quant_scale: torch.Tensor = None,
    v_cache_quant_scale: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Variable-length flash attention on MLU via TMO.

    Thin pass-through to ``tmo.flash_attention``; see the TMO kernel for
    argument semantics.
    """
    return tmo.flash_attention(
        q, k, v, out, cu_seq_lens_q, cu_seq_lens_kv, alibi_slope, attn_bias,
        max_seq_len_q, max_seq_len_kv, softmax_scale, is_causal,
        window_size_left, window_size_right, compute_dtype, return_lse,
        block_tables, k_cache_quant_scale, v_cache_quant_scale)
def split_head_nums(q_head_num, kv_head_num, max_q_head_num):
    """Split q_head_num (and kv_head_num proportionally) into chunks.

    The resulting split satisfies:
      1. every q chunk is at most ``max_q_head_num``;
      2. ``kv_head_num`` is split into the same number of chunks;
      3. each q chunk is divisible by its paired kv chunk;
      4. a kv chunk that would be smaller than 1 is clamped to 1.

    Args:
        q_head_num: number of query heads to split (positive int).
        kv_head_num: number of key/value heads to split (positive int).
        max_q_head_num: maximum allowed size of one q chunk.

    Returns:
        (q_splits, kv_splits): parallel lists of chunk sizes.

    Raises:
        ValueError: if q_head_num or kv_head_num is not positive.
        RuntimeError: if no valid split can be found.
    """
    # Bug fix: the original *returned* an error string here, but every
    # caller unpacks the result as a (list, list) tuple — raise instead.
    if q_head_num <= 0 or kv_head_num <= 0:
        raise ValueError(
            f"q_head_num ({q_head_num}) and kv_head_num ({kv_head_num}) "
            "must be positive integers")
    q_splits = []
    kv_splits = []
    # Remaining head counts still to be assigned to chunks.
    remaining_q = q_head_num
    remaining_kv = kv_head_num
    while remaining_q > 0:
        # Greedily try the largest q chunk first, bounded by max_q_head_num.
        for q_part in range(min(max_q_head_num, remaining_q), 0, -1):
            # Only accept chunks that divide the remaining q heads evenly.
            if remaining_q % q_part == 0:
                # Proportional kv share for this chunk, clamped to >= 1.
                kv_part = max(remaining_kv // (remaining_q // q_part), 1)
                if q_part % kv_part == 0:  # q chunk must divide by kv chunk
                    q_splits.append(q_part)
                    kv_splits.append(kv_part)
                    remaining_q -= q_part
                    remaining_kv -= kv_part
                    break
        else:
            err_msg = f"Unable to find split method for q_head_num:{q_head_num}, kv_head_num:{kv_head_num}"
            raise RuntimeError(err_msg)
    return q_splits, kv_splits
def repeat_elements(input_list, n):
    """Return a new list with every element repeated ``n`` consecutive times.

    Args:
        input_list: the list whose elements are repeated.
        n: repetition count per element (int >= 0).

    Returns:
        A new list of length ``len(input_list) * n``.

    Raises:
        ValueError: if ``input_list`` is not a list or ``n`` is not a
            non-negative int.
    """
    if not isinstance(input_list, list) or not isinstance(n, int) or n < 0:
        raise ValueError("输入必须是一个列表,并且重复次数 n 必须是大于或等于 0 的整数。")
    repeated = []
    for item in input_list:
        repeated.extend([item] * n)
    return repeated
def single_query_cached_kv_attn(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    k_cache_quant_scale: Optional[torch.Tensor],
    v_cache_quant_scale: Optional[torch.Tensor],
    alibi_slopes: Optional[torch.Tensor],
    max_contxt_len: int,
    windows_size_left: int,
    windows_size_right: int,
    softmax_scale: float,
    q_head_dim: Optional[int] = 2,
    kv_head_dim: Optional[int] = 1,
    seq_q_dim: Optional[int] = 1,
    max_seq_q_mul_q_divide_kv: Optional[int] = 48,
) -> None:
    """Single-query (decode) attention over a cached KV via TMO, into ``out``.

    The TMO kernel currently requires ``seq_q * (q_heads // kv_heads)`` to be
    at most ``max_seq_q_mul_q_divide_kv``.  When the inputs exceed that, the
    q/kv heads are split into compatible chunks (see ``split_head_nums``) and
    the kernel is launched once per chunk.
    """
    # FIXME(chenxiaobing): TMO only support windows_size_right = -1 yet.
    windows_size_right = -1
    # singleQwithkvCache limits seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv now.
    # When the limitation is fixed, we should delete the split process.
    seq_q = q.shape[seq_q_dim]
    q_head_num = q.shape[q_head_dim]
    kv_head_num = k_cache.shape[kv_head_dim]
    q_divide_kv = q_head_num // kv_head_num
    if seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv:
        # Fast path: within the kernel limit, a single launch suffices.
        tmo.single_query_cached_kv_attn(
            q, k_cache, v_cache, out,
            block_tables, context_lens,
            k_cache_quant_scale, v_cache_quant_scale,
            alibi_slopes, max_contxt_len,
            windows_size_left, windows_size_right, softmax_scale)
    else:
        # Split the head counts so every chunk satisfies the kernel limit.
        max_q_head_num = max_seq_q_mul_q_divide_kv * kv_head_num // seq_q
        q_head_num_sizes, kv_head_num_sizes = split_head_nums(
            q_head_num, kv_head_num, max_q_head_num)
        parts_num = len(q_head_num_sizes)
        q_parts = torch.split(q, q_head_num_sizes, dim=q_head_dim)
        out_parts = torch.split(out, q_head_num_sizes, dim=q_head_dim)
        alibi_slopes_parts = [None] * parts_num
        # Bug fix: `if alibi_slopes:` evaluates tensor truthiness, which
        # raises RuntimeError for tensors with more than one element.
        # Compare against None explicitly (same fix applied to the quant
        # scale tensors below).
        if alibi_slopes is not None:
            alibi_slopes_parts = torch.split(alibi_slopes, q_head_num_sizes, dim=0)
        kv_parts_num = parts_num
        if parts_num > kv_head_num:
            # More q chunks than kv heads: each kv head serves several q chunks.
            assert parts_num % kv_head_num == 0, f"parts_num:{parts_num} need by divided by kv_head_num:{kv_head_num} when parts_num > kv_head_num"
            kv_parts_num = kv_head_num
            kv_head_num_sizes = kv_head_num_sizes[:kv_parts_num]
        if len(kv_head_num_sizes) > 1:
            k_cache_parts = torch.split(k_cache, kv_head_num_sizes, dim=kv_head_dim)
            v_cache_parts = torch.split(v_cache, kv_head_num_sizes, dim=kv_head_dim)
            k_cache_quant_scale_parts = [None] * kv_parts_num
            v_cache_quant_scale_parts = [None] * kv_parts_num
            if k_cache_quant_scale is not None:
                # NOTE(review): 2-D quant scales appear to be split along
                # dim 1 — confirm the expected scale layout with TMO.
                k_cache_quant_scale_dim = 1 if k_cache_quant_scale.dim() == 2 else kv_head_dim
                k_cache_quant_scale_parts = torch.split(
                    k_cache_quant_scale, kv_head_num_sizes, dim=k_cache_quant_scale_dim)
            if v_cache_quant_scale is not None:
                v_cache_quant_scale_dim = 1 if v_cache_quant_scale.dim() == 2 else kv_head_dim
                v_cache_quant_scale_parts = torch.split(
                    v_cache_quant_scale, kv_head_num_sizes, dim=v_cache_quant_scale_dim)
        else:
            k_cache_parts = [k_cache]
            v_cache_parts = [v_cache]
            k_cache_quant_scale_parts = [k_cache_quant_scale]
            v_cache_quant_scale_parts = [v_cache_quant_scale]
        if parts_num > kv_parts_num:
            # Reuse each kv chunk for all q chunks that map onto it.
            repeat_num = parts_num // kv_parts_num
            k_cache_parts = repeat_elements(k_cache_parts, repeat_num)
            v_cache_parts = repeat_elements(v_cache_parts, repeat_num)
            k_cache_quant_scale_parts = repeat_elements(k_cache_quant_scale_parts, repeat_num)
            v_cache_quant_scale_parts = repeat_elements(v_cache_quant_scale_parts, repeat_num)
        for (q_value, k_cache_value, v_cache_value, out_value,
             k_cache_quant_scale_value, v_cache_quant_scale_value,
             alibi_slopes_value) in zip(
                 q_parts, k_cache_parts, v_cache_parts, out_parts,
                 k_cache_quant_scale_parts, v_cache_quant_scale_parts,
                 alibi_slopes_parts):
            tmo.single_query_cached_kv_attn(
                q_value, k_cache_value.contiguous(), v_cache_value.contiguous(), out_value,
                block_tables, context_lens,
                k_cache_quant_scale_value, v_cache_quant_scale_value,
                alibi_slopes_value, max_contxt_len,
                windows_size_left, windows_size_right, softmax_scale)
def reshape_linear_cache(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    packed: bool,
    context_seq_offset: Optional[torch.Tensor],
    cache_bs_id: Optional[torch.Tensor],
    cache_seqlen_offset: Optional[torch.Tensor],
) -> None:
    """Write new key/value states into a linear (contiguous) cache via TMO."""
    tmo.reshape_linear_cache(key, value, key_cache, value_cache,
                             context_lengths, max_context_len, packed,
                             context_seq_offset, cache_bs_id,
                             cache_seqlen_offset)
def reshape_paged_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor
) -> None:
    """Scatter new k/v states into a paged KV cache via TMO.

    ``slot_mapping`` selects the destination slot for each token.
    """
    tmo.reshape_paged_cache(k, v, k_cache, v_cache, slot_mapping)
def swap_blocks(
    dst: torch.Tensor,
    src: torch.Tensor,
    block_mapping: torch.Tensor
) -> None:
    """Swap cache blocks between ``src`` and ``dst`` per ``block_mapping``."""
    # FIXME: Remove this conversion after
    # tmo.swap_blocks support block_mapping tensor.
    # (Renamed the comprehension variables so they no longer shadow the
    # `src`/`dst` parameters.)
    pairs = block_mapping.tolist()
    mapping = {s: d for s, d in pairs}
    return tmo.swap_blocks(dst, src, mapping)
def copy_blocks(
    k_caches: List[torch.Tensor],
    v_caches: List[torch.Tensor],
    block_mapping: torch.Tensor
) -> None:
    """Copy cache blocks, fanning each source block out to its destinations."""
    # FIXME: Remove this conversion after
    # tmo.swap_blocks support block_mapping tensor.
    # Group rows by their first column: {src_block: [dst_block, ...]}.
    result_dict = {}
    for row in block_mapping.tolist():
        result_dict.setdefault(row[0], []).extend(row[1:])
    return tmo.copy_blocks(k_caches, v_caches, result_dict)
def ffn(
    input: torch.Tensor,
    up_fc_weight: torch.Tensor,
    up_fc_bias: Optional[torch.Tensor],
    down_proj_weight: torch.Tensor,
    down_proj_bias: Optional[torch.Tensor],
    gate_up_proj_weight: Optional[torch.Tensor] = None,
    gate_up_proj_bias: Optional[torch.Tensor] = None,
    act_mode: str = "none"
) -> torch.Tensor:
    """Fused feed-forward (up-projection, activation, down-projection) via TMO."""
    return tmo.ffn(input, up_fc_weight, up_fc_bias,
                   down_proj_weight, down_proj_bias,
                   gate_up_proj_weight, gate_up_proj_bias, act_mode)
def active(
    input: torch.Tensor,
    act_mode: str,
    is_gated: bool
) -> torch.Tensor:
    """Apply the activation named by ``act_mode`` (optionally gated) via TMO."""
    return tmo.active(input, act_mode, is_gated)
def fused_moe(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: Optional[torch.Tensor],
    bias2: Optional[torch.Tensor],
    residual: Optional[torch.Tensor],
    input_smooth: Optional[torch.Tensor],
    act_smooth: Optional[torch.Tensor],
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    topk: int,
    renormalize: bool,
    gated: bool,
    act_mode: str,
    start_expert_id: int = 0,
    block_n: int = 0,
    cncl_comm: int = 0
) -> torch.Tensor:
    """Fused mixture-of-experts layer (routing + expert FFNs) via TMO."""
    return tmo.fused_moe(hidden_states, gating_output, w1, w2, bias1, bias2,
                         residual, input_smooth, act_smooth, w1_scale,
                         w2_scale, topk, renormalize, gated, act_mode,
                         start_expert_id, block_n, cncl_comm)
def matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    act_mode: str = 'none',
    alpha: float = 1.0,
    beta: float = .0
) -> torch.Tensor:
    """Matrix multiply with optional bias/accumulator/activation via TMO."""
    return tmo.matmul(a, b, bias, c, act_mode, alpha, beta)
def weight_only_quant_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor = None,
    bias: torch.Tensor = None,
    c: torch.Tensor = None,
    act_mode: str = "none",
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0
) -> torch.Tensor:
    """Matmul with a weight-only-quantized ``b`` (dequantized by scale/zero) via TMO."""
    return tmo.weight_only_quant_matmul(a, b, scale, zero, bias, c,
                                        act_mode, quant_bit_size,
                                        alpha, beta)
def smooth_quant_matmul(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    dtype: torch.dtype,
    bias: torch.Tensor = None,
    c: torch.Tensor = None,
    act_mode: str = "none",
    alpha: float = 1.0,
    beta: float = 1.0
) -> torch.Tensor:
    """SmoothQuant-style quantized matmul (per-tensor a/b scales) via TMO."""
    return tmo.smooth_quant_matmul(a, a_scale, b, b_scale, dtype,
                                   bias, c, act_mode, alpha, beta)
def per_token_smooth_quantize(
    x: torch.Tensor,
    smooth: torch.Tensor,
    zero: torch.Tensor = None,
    token_count: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token smooth quantization of ``x`` via TMO."""
    return tmo.per_token_smooth_quantize(x, smooth, zero, token_count)
def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor = None
) -> torch.Tensor:
    """Quantize ``x`` using the given scale (and optional zero point) via TMO."""
    return tmo.quantize(x, scale, zero)
def quant_to_paged_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k_cache_quant_scale: torch.Tensor,
    v_cache_quant_scale: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Quantize k/v and scatter them into a paged KV cache via TMO."""
    return tmo.quant_to_paged_cache(k, v, k_cache, v_cache,
                                    k_cache_quant_scale,
                                    v_cache_quant_scale, slot_mapping)
def quant_to_linear_cache(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    key_cache_quant_scale: torch.Tensor,
    value_cache_quant_scale: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    packed: bool,
    context_seq_offset: Optional[torch.Tensor],
    cache_bs_id: Optional[torch.Tensor],
    cache_seqlen_offset: Optional[torch.Tensor],
) -> None:
    """Quantize key/value and write them into a linear cache via TMO."""
    return tmo.quant_to_linear_cache(
        key, value, key_cache, value_cache,
        key_cache_quant_scale, value_cache_quant_scale,
        context_lengths, max_context_len, packed,
        context_seq_offset, cache_bs_id, cache_seqlen_offset)
def advance_step(num_seqs: int,
                 num_queries: int,
                 block_size: int,
                 input_tokens: torch.Tensor,
                 sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor,
                 seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor,
                 TILE_SIZE: int = 64) -> None:
    """
    Advance a step on MLU for existing inputs for a multi-step runner, which
    will update input_tokens/seq_lens/input_positions/slot_mapping inplace.
    """

    def verify_tensor(
        name: str,
        tensor: torch.Tensor,
        size_0: int,
        size_1: int,
        dtype: torch.dtype,
    ):
        """
        Auxiliary function to check whether input is valid.

        A ``size_0``/``size_1`` of -1 means "any size" for that dimension.
        """
        size_0_cond = (size_0 == -1 or tensor.size(0) == size_0)
        size_1_cond = (size_1 == -1 or tensor.size(1) == size_1)
        # Bug fix: the original tested `tensor.is_contiguous` without
        # calling it; a bound method is always truthy, so contiguity was
        # never actually enforced.
        if not (size_0_cond and size_1_cond and tensor.is_contiguous()
                and tensor.dtype == dtype):
            raise ValueError(
                f"The input to advance_step is invalid with tensor name = {name}, "
                f"shape = {tensor.shape}, "
                f"is_cont = {tensor.is_contiguous()}, "
                f"type = {tensor.dtype}, "
                f"is not as expected: shape[{size_0}, {size_1}], type = {dtype}"
            )

    @triton.jit
    def _triton_advance_step(input_tokens_ptr,
                             sampled_token_ids_ptr,
                             input_positions_ptr,
                             seq_lens_ptr,
                             slot_mapping_ptr,
                             block_tables_ptr,
                             block_tables_stride,
                             num_seqs,
                             num_queries,
                             block_size,
                             TILE_SIZE: tl.constexpr,
                             ):
        """
        The triton implementation of advance step.
        Reference: https://github.com/vllm-project/vllm/blob/v0.6.1/csrc/prepare_inputs/advance_step.cu#L14-L55
        """
        # Set meta info: each program instance handles TILE_SIZE queries.
        pid = tl.program_id(axis=0)
        offsets = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)
        mask = offsets < num_queries
        # Update input_tokens with the tokens sampled in the previous step.
        sampled_token_ids = tl.load(sampled_token_ids_ptr + offsets, mask=mask)
        tl.store(input_tokens_ptr + offsets, sampled_token_ids, mask=mask)
        seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
        next_seq_lens = seq_lens + 1
        next_input_pos = next_seq_lens - 1
        # Update seq_lens.
        tl.store(seq_lens_ptr + offsets, next_seq_lens, mask=mask)
        # Update input_positions.
        tl.store(input_positions_ptr + offsets, next_input_pos, mask=mask)
        # Calculate slot num from the block table entry for the new position.
        block_index = next_input_pos // block_size
        block_offset = next_input_pos % block_size
        block_tables = tl.load(block_tables_ptr + block_tables_stride * offsets + block_index, mask=mask)
        slot_num = block_tables * block_size + block_offset
        # Update slot_mapping.
        tl.store(slot_mapping_ptr + offsets, slot_num, mask=mask)

    # Validate every in/out tensor before launching the kernel.
    verify_tensor("input_tokens", input_tokens, num_seqs, -1, torch.int64)
    verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, torch.int64)
    verify_tensor("input_positions", input_positions, num_seqs, -1, torch.int32)
    verify_tensor("seq_lens", seq_lens, num_seqs, -1, torch.int32)
    verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, torch.int32)
    verify_tensor("block_tables", block_tables, num_seqs, -1, torch.int32)

    # One program instance per TILE_SIZE queries.
    grid = (math.ceil(num_queries / TILE_SIZE), )
    _triton_advance_step[grid](input_tokens,
                               sampled_token_ids,
                               input_positions,
                               seq_lens,
                               slot_mapping,
                               block_tables,
                               block_tables.stride(0),
                               num_seqs,
                               num_queries,
                               block_size,
                               TILE_SIZE)
def preload(
    weight: torch.Tensor,
    size: int
) -> None:
    """Preload a layer's weights via TMO.

    Args:
        weight (torch.Tensor): Weight tensor to preload.
        size (int): Preload size in bytes.

    Returns:
        None
    """
    return tmo.preload(weight, size)
def matmul_allreduce(
    cncl_comm,
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = .0,
    block_m: int = 0
) -> torch.Tensor:
    """Matmul fused with an all-reduce over the given CNCL communicator via TMO."""
    return tmo.matmul_allreduce(cncl_comm=cncl_comm, a=a, b=b, bias=bias,
                                c=c, alpha=alpha, beta=beta,
                                block_m=block_m)
def smooth_quant_matmul_allreduce(
    cncl_comm,
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    dtype: torch.dtype,
    bias: torch.Tensor = None,
    c: torch.Tensor = None,
    alpha: float = 1.0,
    beta: float = 1.0,
    block_m: int = 0):
    """SmoothQuant matmul fused with an all-reduce over ``cncl_comm`` via TMO."""
    return tmo.smooth_quant_matmul_allreduce(
        cncl_comm=cncl_comm,
        a=a, a_scale=a_scale,
        b=b, b_scale=b_scale,
        dtype=dtype, bias=bias, c=c,
        alpha=alpha, beta=beta, block_m=block_m)
def quant_matmul_allreduce(
    cncl_comm,
    a_tensor: torch.Tensor,
    a_scale: Optional[torch.Tensor],
    a_zero: Optional[torch.Tensor],
    b_tensor: torch.Tensor,
    b_scale: Optional[torch.Tensor],
    b_zero: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    c_tensor: Optional[torch.Tensor],
    c_scale: Optional[torch.Tensor],
    c_zero: Optional[torch.Tensor],
    gemm_output_scale: Optional[torch.Tensor],
    gemm_output_zero: Optional[torch.Tensor],
    data_type: Optional[str],
    quant_algo: str,
    a_quant_layout: str,
    b_quant_layout: str,
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0,
    trans_a: bool = False,
    trans_b: bool = True,
    block_m: int = 0
) -> torch.Tensor:
    """General quantized matmul fused with an all-reduce via TMO.

    All arguments are forwarded to ``tmo.quant_matmul_allreduce`` by keyword.
    """
    return tmo.quant_matmul_allreduce(
        cncl_comm=cncl_comm, a_tensor=a_tensor, a_scale=a_scale, a_zero=a_zero,
        b_tensor=b_tensor, b_scale=b_scale, b_zero=b_zero, bias=bias,
        c_tensor=c_tensor, c_scale=c_scale, c_zero=c_zero,
        gemm_output_scale=gemm_output_scale, gemm_output_zero=gemm_output_zero,
        data_type=data_type, quant_algo=quant_algo,
        a_quant_layout=a_quant_layout, b_quant_layout=b_quant_layout,
        quant_bit_size=quant_bit_size,
        alpha=alpha, beta=beta, trans_a=trans_a, trans_b=trans_b,
        block_m=block_m)
def flash_attn_sq_mm_allreduce(
    cncl_comm: int,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seq_lens_q: Optional[torch.Tensor],
    cu_seq_lens_kv: Optional[torch.Tensor],
    alibi_slope: Optional[torch.Tensor],
    attn_bias: Optional[torch.Tensor],
    smooth: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    max_seq_len_q: int,
    max_seq_len_kv: int,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    compute_dtype: torch.dtype = torch.float,
    block_seq: int = 0) -> torch.Tensor:
    """Flash attention fused with smooth-quant matmul and all-reduce via TMO."""
    return tmo.flash_attn_sq_mm_allreduce(
        cncl_comm, q, k, v, cu_seq_lens_q, cu_seq_lens_kv, alibi_slope,
        attn_bias, smooth, weight, weight_scale, bias, max_seq_len_q,
        max_seq_len_kv, softmax_scale, is_causal, window_size_left,
        window_size_right, compute_dtype, block_seq)
# MoE inner-kernel wrappers
def moe_softmax_topk(input: torch.Tensor,
                     topk: int,
                     normalize: bool = False,
                     num_expert_group: int = -1,
                     topk_group: int = 0) -> Tuple[torch.Tensor]:
    """Softmax + top-k expert selection for MoE routing via TMO."""
    return tmo.moe_softmax_topk(input, topk, normalize,
                                num_expert_group, topk_group)
def moe_gen_idx(expert_id: torch.Tensor,
                expert_num: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate MoE gather/scatter index tensors from per-token expert ids via TMO."""
    return tmo.moe_gen_idx(expert_id, expert_num)
def moe_expand_input(input: torch.Tensor,
                     gather_idx: torch.Tensor,
                     cusum_token_count: Optional[torch.Tensor] = None,
                     start_expert_id: int = 0,
                     expert_size: int = 0) -> torch.Tensor:
    """Gather/expand tokens into expert order using ``gather_idx`` via TMO."""
    return tmo.moe_expand_input(input, gather_idx, cusum_token_count,
                                start_expert_id, expert_size)
def moe_active(input: torch.Tensor,
               act_mode: str,
               is_gated: bool,
               output: Optional[torch.Tensor] = None,
               bias: Optional[torch.Tensor] = None,
               cusum_token_count: Optional[torch.Tensor] = None,
               start_expert_id: int = 0,
               expert_size: int = 0) -> torch.Tensor:
    """Per-expert activation (optionally gated, with bias) for MoE via TMO."""
    return tmo.moe_active(input, act_mode, is_gated, output, bias,
                          cusum_token_count, start_expert_id, expert_size)
def group_gemm(a: torch.Tensor,
               b: torch.Tensor,
               m_list: torch.Tensor,
               expand_idx: Optional[torch.Tensor],
               c: Optional[torch.Tensor],
               alpha: Optional[torch.Tensor],
               beta: Optional[torch.Tensor],
               max_m: int = 0
               ) -> torch.Tensor:
    """Grouped GEMM (one matmul per group sized by ``m_list``) via TMO."""
    return tmo.group_gemm(a, b, m_list, expand_idx, c, alpha, beta, max_m)
def smooth_quant_group_gemm(a: torch.Tensor,
                            b: torch.Tensor,
                            m_list: torch.Tensor,
                            expand_idx: Optional[torch.Tensor],
                            c: Optional[torch.Tensor],
                            alpha: Optional[torch.Tensor],
                            beta: Optional[torch.Tensor],
                            a_scale: torch.Tensor,
                            b_scale: torch.Tensor,
                            dtype,
                            max_m: int = 0
                            ) -> torch.Tensor:
    """SmoothQuant variant of the grouped GEMM via TMO."""
    return tmo.smooth_quant_group_gemm(a, b, m_list, expand_idx, c,
                                       alpha, beta, a_scale, b_scale,
                                       dtype, max_m)
def moe_combine_result(input: torch.Tensor,
                       reduce_weight: torch.Tensor,
                       gather_ids: torch.Tensor,
                       residual: Optional[torch.Tensor],
                       cusum_token_count: Optional[torch.Tensor],
                       start_expert_id: int,
                       expert_size: int,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Weighted combine of per-expert outputs back into token order via TMO."""
    return tmo.moe_combine_result(input, reduce_weight, gather_ids,
                                  residual, cusum_token_count,
                                  start_expert_id, expert_size, bias)
def moe_quantize(x: torch.Tensor,
                 smooth: torch.Tensor,
                 zero: Optional[torch.Tensor] = None,
                 token_count: Optional[torch.Tensor] = None,
                 gather_index: Optional[torch.Tensor] = None,
                 gather_index_start_position: Optional[torch.Tensor] = None,
                 output: Optional[torch.Tensor] = None,
                 output_scale: Optional[torch.Tensor] = None,
                 dynamic_quant: bool = True
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize MoE activations (optionally gathered by index) via TMO."""
    return tmo.moe_quantize(x, smooth, zero, token_count, gather_index,
                            gather_index_start_position, output,
                            output_scale, dynamic_quant)