# NOTE(review): the following lines were web-scrape artifacts from a code
# hosting page (file listing header, size, timestamp) and are not Python;
# kept here commented out so the module parses.
# Files
# 779 lines / 27 KiB / Python
# Raw Permalink Normal View History
# 2026-02-04 17:22:39 +08:00
from typing import List, Optional, Tuple
import torch
import math
import triton
import triton.language as tl
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import torch_mlu_ops as tmo
except ImportError as e:
logger.warning("Failed to import from TMO OPS with %r", e)
def rotary_embedding(
    input: torch.Tensor,
    sin_cache: torch.Tensor,
    cos_cache: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    cu_seqlens: Optional[torch.Tensor],
    interleaved: bool,
    discrete: bool,
    dynamic_ntk: bool,
    max_seqlen: int,
) -> torch.Tensor:
    """Apply rotary position embedding to ``input``.

    Thin wrapper around ``tmo.apply_rotary`` (torch_mlu_ops); every argument
    is forwarded unchanged, so argument semantics follow the TMO kernel.
    """
    return tmo.apply_rotary(
        input, sin_cache, cos_cache,
        position_ids, cu_seqlens, interleaved,
        discrete, dynamic_ntk, max_seqlen)
def fused_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    store_output_before_norm: bool,
    quant_scale: Optional[torch.Tensor] = None,
    dynamic_quant: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Fused residual-add + RMSNorm (optionally quantizing the output).

    Thin wrapper around ``tmo.fused_rms_norm``; the extra ``None`` positional
    argument before ``dynamic_quant`` is a TMO slot this wrapper does not use.
    """
    return tmo.fused_rms_norm(
        x, residual, gamma, beta, bias,
        eps, store_output_before_norm, quant_scale,
        None, dynamic_quant)
def fused_layer_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    store_output_before_norm: bool,
    quant_scale: Optional[torch.Tensor] = None,
    dynamic_quant: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Fused residual-add + LayerNorm (optionally quantizing the output).

    Thin wrapper around ``tmo.fused_layer_norm``; the extra ``None``
    positional argument before ``dynamic_quant`` is a TMO slot this wrapper
    does not use.
    """
    return tmo.fused_layer_norm(
        x, residual, gamma, beta, bias,
        eps, store_output_before_norm, quant_scale,
        None, dynamic_quant)
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    cu_seq_lens_q: torch.Tensor,
    cu_seq_lens_kv: torch.Tensor,
    alibi_slope: torch.Tensor,
    attn_bias: torch.Tensor,
    max_seq_len_q: int,
    max_seq_len_kv: int,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    compute_dtype: torch.dtype = torch.float,
    return_lse: bool = False,
    block_tables: Optional[torch.Tensor] = None,
    k_cache_quant_scale: Optional[torch.Tensor] = None,
    v_cache_quant_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Variable-length (cu_seqlens-based) flash attention on MLU.

    Thin wrapper around ``tmo.flash_attention``; all arguments are forwarded
    unchanged. ``out`` is supplied by the caller; ``-1`` window sizes mean
    an unrestricted attention window on that side (TMO convention).
    """
    return tmo.flash_attention(
        q, k, v, out,
        cu_seq_lens_q, cu_seq_lens_kv,
        alibi_slope, attn_bias,
        max_seq_len_q, max_seq_len_kv,
        softmax_scale, is_causal,
        window_size_left, window_size_right,
        compute_dtype, return_lse,
        block_tables, k_cache_quant_scale,
        v_cache_quant_scale)
def split_head_nums(q_head_num, kv_head_num, max_q_head_num):
    """Split q/kv head counts into TMO-sized parts.

    Splits ``q_head_num`` so that:
      1. every q part is at most ``max_q_head_num``;
      2. ``kv_head_num`` is split into the same number of parts;
      3. each q part is divisible by its matching kv part;
      4. a kv part that would fall below 1 is clamped to 1.

    Args:
        q_head_num: positive number of query heads to split.
        kv_head_num: positive number of key/value heads to split.
        max_q_head_num: maximum allowed size of a single q part.

    Returns:
        Tuple ``(q_splits, kv_splits)`` of per-part head counts.

    Raises:
        ValueError: if ``q_head_num`` or ``kv_head_num`` is not positive.
        RuntimeError: if no valid split exists.
    """
    if q_head_num <= 0 or kv_head_num <= 0:
        # BUGFIX: previously *returned* an error string, which callers that
        # unpack two values would mishandle silently; raise instead.
        raise ValueError(
            "q_head_num and kv_head_num must be positive integers, got "
            f"q_head_num={q_head_num}, kv_head_num={kv_head_num}")
    q_splits = []
    kv_splits = []
    remaining_q = q_head_num
    remaining_kv = kv_head_num
    while remaining_q > 0:
        # Greedily take the largest q part that divides the remainder and is
        # within the max_q_head_num budget.
        for q_part in range(min(max_q_head_num, remaining_q), 0, -1):
            if remaining_q % q_part == 0:
                # Matching kv part, clamped to at least 1.
                kv_part = max(remaining_kv // (remaining_q // q_part), 1)
                if q_part % kv_part == 0:  # q part must be divisible by kv part
                    q_splits.append(q_part)
                    kv_splits.append(kv_part)
                    remaining_q -= q_part
                    remaining_kv -= kv_part
                    break
        else:
            err_msg = f"Unable to find split method for q_head_num:{q_head_num}, kv_head_num:{kv_head_num}"
            raise RuntimeError(err_msg)
    return q_splits, kv_splits
def repeat_elements(input_list, n):
    """Repeat each member of ``input_list`` ``n`` times consecutively.

    Args:
        input_list: the source list.
        n: non-negative repetition count per element.

    Returns:
        A new list where every element of ``input_list`` appears ``n``
        consecutive times.

    Raises:
        ValueError: if ``input_list`` is not a list or ``n`` is not a
            non-negative integer.
    """
    valid_list = isinstance(input_list, list)
    valid_count = isinstance(n, int) and n >= 0
    if not (valid_list and valid_count):
        raise ValueError("输入必须是一个列表,并且重复次数 n 必须是大于或等于 0 的整数。")
    repeated = []
    for element in input_list:
        repeated.extend([element] * n)
    return repeated
def single_query_cached_kv_attn(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    k_cache_quant_scale: Optional[torch.Tensor],
    v_cache_quant_scale: Optional[torch.Tensor],
    alibi_slopes: Optional[torch.Tensor],
    max_contxt_len: int,
    windows_size_left: int,
    windows_size_right: int,
    softmax_scale: float,
    q_head_dim: Optional[int] = 2,
    kv_head_dim: Optional[int] = 1,
    seq_q_dim: Optional[int] = 1,
    max_seq_q_mul_q_divide_kv: Optional[int] = 48,
) -> None:
    """Decode-time paged-KV attention via ``tmo.single_query_cached_kv_attn``.

    Works around a TMO limit (``seq_q * q_heads/kv_heads`` must not exceed
    ``max_seq_q_mul_q_divide_kv``): oversized inputs are split along the
    head dimension and processed part by part. Results are written into
    ``out`` in place.
    """
    # FIXME(chenxiaobing): TMO only support windows_size_right = -1 yet.
    windows_size_right = -1
    # singleQwithkvCache limits seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv now.
    # When the limitation is fixed, we should delete the split process.
    seq_q = q.shape[seq_q_dim]
    q_head_num = q.shape[q_head_dim]
    kv_head_num = k_cache.shape[kv_head_dim]
    q_divide_kv = q_head_num // kv_head_num
    if seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv:
        # Fast path: within the TMO limit, no splitting needed.
        tmo.single_query_cached_kv_attn(
            q, k_cache, v_cache, out,
            block_tables, context_lens,
            k_cache_quant_scale, v_cache_quant_scale,
            alibi_slopes, max_contxt_len,
            windows_size_left, windows_size_right, softmax_scale)
    else:
        max_q_head_num = max_seq_q_mul_q_divide_kv * kv_head_num // seq_q
        q_head_num_sizes, kv_head_num_sizes = split_head_nums(q_head_num, kv_head_num, max_q_head_num)
        parts_num = len(q_head_num_sizes)
        q_parts = torch.split(q, q_head_num_sizes, dim=q_head_dim)
        out_parts = torch.split(out, q_head_num_sizes, dim=q_head_dim)
        alibi_slopes_parts = [None] * parts_num
        # BUGFIX: `if alibi_slopes:` applies truthiness to a tensor, which
        # raises RuntimeError for multi-element tensors; use `is not None`.
        if alibi_slopes is not None:
            alibi_slopes_parts = torch.split(alibi_slopes, q_head_num_sizes, dim=0)
        kv_parts_num = parts_num
        if parts_num > kv_head_num:
            assert parts_num % kv_head_num == 0, f"parts_num:{parts_num} need by divided by kv_head_num:{kv_head_num} when parts_num > kv_head_num"
            kv_parts_num = kv_head_num
            kv_head_num_sizes = kv_head_num_sizes[:kv_parts_num]
        if len(kv_head_num_sizes) > 1:
            k_cache_parts = torch.split(k_cache, kv_head_num_sizes, dim=kv_head_dim)
            v_cache_parts = torch.split(v_cache, kv_head_num_sizes, dim=kv_head_dim)
            k_cache_quant_scale_parts = [None] * kv_parts_num
            v_cache_quant_scale_parts = [None] * kv_parts_num
            # BUGFIX (same as above): explicit None checks for optional
            # quant-scale tensors instead of tensor truthiness.
            if k_cache_quant_scale is not None:
                # 2-D scale tensors are split along dim 1, otherwise along
                # the kv head dim.
                k_cache_quant_scale_dim = 1 if k_cache_quant_scale.dim() == 2 else kv_head_dim
                k_cache_quant_scale_parts = torch.split(k_cache_quant_scale, kv_head_num_sizes, dim=k_cache_quant_scale_dim)
            if v_cache_quant_scale is not None:
                v_cache_quant_scale_dim = 1 if v_cache_quant_scale.dim() == 2 else kv_head_dim
                v_cache_quant_scale_parts = torch.split(v_cache_quant_scale, kv_head_num_sizes, dim=v_cache_quant_scale_dim)
        else:
            # Single kv part: pass the originals through unsplit.
            k_cache_parts = [k_cache]
            v_cache_parts = [v_cache]
            k_cache_quant_scale_parts = [k_cache_quant_scale]
            v_cache_quant_scale_parts = [v_cache_quant_scale]
        if parts_num > kv_parts_num:
            # Reuse each kv part for several consecutive q parts (GQA-style).
            repeat_num = parts_num // kv_parts_num
            k_cache_parts = repeat_elements(k_cache_parts, repeat_num)
            v_cache_parts = repeat_elements(v_cache_parts, repeat_num)
            k_cache_quant_scale_parts = repeat_elements(k_cache_quant_scale_parts, repeat_num)
            v_cache_quant_scale_parts = repeat_elements(v_cache_quant_scale_parts, repeat_num)
        for q_value, k_cache_value, v_cache_value, out_value, k_cache_quant_scale_value, v_cache_quant_scale_value, alibi_slopes_value in zip(
                q_parts, k_cache_parts, v_cache_parts, out_parts, k_cache_quant_scale_parts, v_cache_quant_scale_parts,
                alibi_slopes_parts):
            tmo.single_query_cached_kv_attn(
                q_value, k_cache_value.contiguous(), v_cache_value.contiguous(), out_value,
                block_tables, context_lens,
                k_cache_quant_scale_value, v_cache_quant_scale_value,
                alibi_slopes_value, max_contxt_len,
                windows_size_left, windows_size_right, softmax_scale)
def reshape_linear_cache(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    packed: bool,
    context_seq_offset: Optional[torch.Tensor],
    cache_bs_id: Optional[torch.Tensor],
    cache_seqlen_offset: Optional[torch.Tensor],
) -> None:
    """Write key/value tensors into a linear (non-paged) KV cache.

    Thin wrapper around ``tmo.reshape_linear_cache``; all arguments are
    forwarded unchanged. The caches are updated in place.
    """
    tmo.reshape_linear_cache(
        key, value,
        key_cache, value_cache,
        context_lengths, max_context_len,
        packed, context_seq_offset,
        cache_bs_id, cache_seqlen_offset)
def reshape_paged_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor
) -> None:
    """Scatter k/v tensors into a paged KV cache at the slots given by
    ``slot_mapping``. Thin wrapper around ``tmo.reshape_paged_cache``;
    the caches are updated in place.
    """
    tmo.reshape_paged_cache(k, v, k_cache, v_cache, slot_mapping)
def swap_blocks(
    dst: torch.Tensor,
    src: torch.Tensor,
    block_mapping: torch.Tensor
) -> None:
    """Swap cache blocks between ``src`` and ``dst`` via ``tmo.swap_blocks``.

    ``block_mapping`` is converted from a tensor of (src_block, dst_block)
    pairs into the dict form TMO currently requires.
    """
    # FIXME: Remove this conversion after
    # tmo.swap_blocks support block_mapping tensor.
    # NOTE: renamed the comprehension variables — they previously shadowed
    # the `src`/`dst` parameters, which was confusing and error-prone.
    mapping_pairs = block_mapping.tolist()
    mapping_dict = {src_block: dst_block for src_block, dst_block in mapping_pairs}
    return tmo.swap_blocks(dst, src, mapping_dict)
def copy_blocks(
    k_caches: List[torch.Tensor],
    v_caches: List[torch.Tensor],
    block_mapping: torch.Tensor
) -> None:
    """Copy cache blocks according to ``block_mapping`` via ``tmo.copy_blocks``.

    Each row of ``block_mapping`` is a source block id followed by one or
    more destination block ids; rows sharing a source are merged into one
    dict entry before the TMO call.
    """
    # FIXME: Remove this conversion after
    # tmo.swap_blocks support block_mapping tensor.
    result_dict = {}
    for src_block, *dst_blocks in block_mapping.tolist():
        result_dict.setdefault(src_block, []).extend(dst_blocks)
    return tmo.copy_blocks(k_caches, v_caches, result_dict)
def ffn(
    input: torch.Tensor,
    up_fc_weight: torch.Tensor,
    up_fc_bias: Optional[torch.Tensor],
    down_proj_weight: torch.Tensor,
    down_proj_bias: Optional[torch.Tensor],
    gate_up_proj_weight: Optional[torch.Tensor] = None,
    gate_up_proj_bias: Optional[torch.Tensor] = None,
    act_mode: str = "none"
) -> torch.Tensor:
    """Fused feed-forward network (up/gate projection, activation, down
    projection). Thin wrapper around ``tmo.ffn``; all arguments forwarded
    unchanged.
    """
    return tmo.ffn(input, up_fc_weight, up_fc_bias, down_proj_weight, down_proj_bias,
                   gate_up_proj_weight, gate_up_proj_bias, act_mode)
def active(
    input: torch.Tensor,
    act_mode: str,
    is_gated: bool
) -> torch.Tensor:
    """Apply the TMO activation kernel ``act_mode`` to ``input``
    (optionally in gated form). Thin wrapper around ``tmo.active``.
    """
    return tmo.active(input, act_mode, is_gated)
def fused_moe(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: Optional[torch.Tensor],
    bias2: Optional[torch.Tensor],
    residual: Optional[torch.Tensor],
    input_smooth: Optional[torch.Tensor],
    act_smooth: Optional[torch.Tensor],
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    topk: int,
    renormalize: bool,
    gated: bool,
    act_mode: str,
    start_expert_id: int = 0,
    block_n: int = 0,
    cncl_comm: int = 0
) -> torch.Tensor:
    """Fused mixture-of-experts layer (routing + expert MLPs + combine).

    Thin wrapper around ``tmo.fused_moe``; all arguments are forwarded
    unchanged, so parameter semantics follow the TMO kernel.
    """
    return tmo.fused_moe(
        hidden_states, gating_output,
        w1, w2, bias1, bias2, residual,
        input_smooth, act_smooth,
        w1_scale, w2_scale, topk,
        renormalize, gated, act_mode, start_expert_id,
        block_n, cncl_comm)
def matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    act_mode: str = 'none',
    alpha: float = 1.0,
    beta: float = .0
) -> torch.Tensor:
    """Matrix multiply with optional bias, accumulator ``c``, activation and
    alpha/beta scaling. Thin wrapper around ``tmo.matmul``.
    """
    return tmo.matmul(a, b, bias, c, act_mode, alpha, beta)
def weight_only_quant_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    scale: torch.Tensor,
    zero: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    act_mode: str = "none",
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0
) -> torch.Tensor:
    """Matmul with a weight-only-quantized ``b`` (dequantized via
    ``scale``/``zero``). Thin wrapper around ``tmo.weight_only_quant_matmul``.
    """
    return tmo.weight_only_quant_matmul(
        a, b,
        scale, zero, bias, c,
        act_mode, quant_bit_size, alpha, beta)
def smooth_quant_matmul(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    act_mode: str = "none",
    alpha: float = 1.0,
    beta: float = 1.0
) -> torch.Tensor:
    """Smooth-quant matmul: both operands carry quantization scales and the
    result is produced in ``dtype``. Thin wrapper around
    ``tmo.smooth_quant_matmul``.
    """
    return tmo.smooth_quant_matmul(
        a, a_scale,
        b, b_scale,
        dtype, bias, c,
        act_mode, alpha, beta)
def per_token_smooth_quantize(
    x: torch.Tensor,
    smooth: torch.Tensor,
    zero: Optional[torch.Tensor] = None,
    token_count: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token quantization of ``x`` with a smoothing factor; returns the
    quantized tensor and its scales (per TMO's convention). Thin wrapper
    around ``tmo.per_token_smooth_quantize``.
    """
    return tmo.per_token_smooth_quantize(x, smooth, zero, token_count)
def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Quantize ``x`` with the given ``scale`` (and optional zero point).
    Thin wrapper around ``tmo.quantize``.
    """
    return tmo.quantize(x, scale, zero)
def quant_to_paged_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k_cache_quant_scale: torch.Tensor,
    v_cache_quant_scale: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Quantize k/v and scatter them into a paged KV cache at the slots in
    ``slot_mapping``. Thin wrapper around ``tmo.quant_to_paged_cache``;
    caches and scale tensors are updated in place.
    """
    return tmo.quant_to_paged_cache(
        k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping
    )
def quant_to_linear_cache(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    key_cache_quant_scale: torch.Tensor,
    value_cache_quant_scale: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    packed: bool,
    context_seq_offset: Optional[torch.Tensor],
    cache_bs_id: Optional[torch.Tensor],
    cache_seqlen_offset: Optional[torch.Tensor],
) -> None:
    """Quantize key/value and write them into a linear (non-paged) KV cache.

    Thin wrapper around ``tmo.quant_to_linear_cache``; all arguments are
    forwarded unchanged and the caches are updated in place.
    """
    return tmo.quant_to_linear_cache(
        key,
        value,
        key_cache,
        value_cache,
        key_cache_quant_scale,
        value_cache_quant_scale,
        context_lengths,
        max_context_len,
        packed,
        context_seq_offset,
        cache_bs_id,
        cache_seqlen_offset,
    )
def advance_step(num_seqs: int,
                 num_queries: int,
                 block_size: int,
                 input_tokens: torch.Tensor,
                 sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor,
                 seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor,
                 TILE_SIZE: int = 64) -> None:
    """
    Advance a step on MLU for existing inputs for a multi-step runner, which
    will update input_tokens/seq_lens/input_positions/slot_mapping inplace.

    Raises:
        ValueError: if any input tensor has an unexpected shape/dtype or is
            not contiguous.
    """
    def verify_tensor(
        name: str,
        tensor: torch.Tensor,
        size_0: int,
        size_1: int,
        dtype: torch.dtype,
    ):
        """
        Auxiliary function to check whether input is valid.
        A size of -1 means "any size" for that dimension.
        """
        size_0_cond = (size_0 == -1 or tensor.size(0) == size_0)
        size_1_cond = (size_1 == -1 or tensor.size(1) == size_1)
        # BUGFIX: `tensor.is_contiguous` (no parens) is a bound method and
        # therefore always truthy — non-contiguous inputs slipped through.
        if not (size_0_cond and size_1_cond and tensor.is_contiguous() and tensor.dtype == dtype):
            raise ValueError(
                f"The input to advance_step is invalid with tensor name = {name}, "
                f"shape = {tensor.shape}, "
                f"is_cont = {tensor.is_contiguous()}, "
                f"type = {tensor.dtype}, "
                f"is not as expected: shape[{size_0}, {size_1}], type = {dtype}"
            )

    # Validate all inputs before defining/launching the kernel (fail fast).
    verify_tensor("input_tokens", input_tokens, num_seqs, -1, torch.int64)
    verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, torch.int64)
    verify_tensor("input_positions", input_positions, num_seqs, -1, torch.int32)
    verify_tensor("seq_lens", seq_lens, num_seqs, -1, torch.int32)
    verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, torch.int32)
    verify_tensor("block_tables", block_tables, num_seqs, -1, torch.int32)

    @triton.jit
    def _triton_advance_step(input_tokens_ptr,
                             sampled_token_ids_ptr,
                             input_positions_ptr,
                             seq_lens_ptr,
                             slot_mapping_ptr,
                             block_tables_ptr,
                             block_tables_stride,
                             num_seqs,
                             num_queries,
                             block_size,
                             TILE_SIZE: tl.constexpr,
                             ):
        """
        The triton implementation of advance step.
        Reference: https://github.com/vllm-project/vllm/blob/v0.6.1/csrc/prepare_inputs/advance_step.cu#L14-L55
        """
        # Set meta info.
        pid = tl.program_id(axis=0)
        offsets = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)
        mask = offsets < num_queries
        # Update input_tokens with the tokens sampled in the previous step.
        sampled_token_ids = tl.load(sampled_token_ids_ptr + offsets, mask=mask)
        tl.store(input_tokens_ptr + offsets, sampled_token_ids, mask=mask)
        seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
        next_seq_lens = seq_lens + 1
        next_input_pos = next_seq_lens - 1
        # Update seq_lens.
        tl.store(seq_lens_ptr + offsets, next_seq_lens, mask=mask)
        # Update input_positions.
        tl.store(input_positions_ptr + offsets, next_input_pos, mask=mask)
        # Calculate slot num from the block table entry for the new position.
        block_index = next_input_pos // block_size
        block_offset = next_input_pos % block_size
        block_tables = tl.load(block_tables_ptr + block_tables_stride * offsets + block_index, mask=mask)
        slot_num = block_tables * block_size + block_offset
        # Update slot_mapping.
        tl.store(slot_mapping_ptr + offsets, slot_num, mask=mask)

    # One program instance per TILE_SIZE queries.
    grid = (math.ceil(num_queries / TILE_SIZE), )
    _triton_advance_step[grid](input_tokens,
                               sampled_token_ids,
                               input_positions,
                               seq_lens,
                               slot_mapping,
                               block_tables,
                               block_tables.stride(0),
                               num_seqs,
                               num_queries,
                               block_size,
                               TILE_SIZE)
def preload(
    weight: torch.Tensor,
    size: int
) -> None:
    """
    Preload weights of layer.
    Thin wrapper around ``tmo.preload``.

    Args:
        weight (torch.Tensor): Weight to preload
        size (int): Preload size (byte)
    Returns:
        None
    """
    return tmo.preload(weight, size)
def matmul_allreduce(
    cncl_comm,
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = .0,
    block_m: int = 0
) -> torch.Tensor:
    """Matmul fused with an all-reduce over the CNCL communicator
    ``cncl_comm``. Thin wrapper around ``tmo.matmul_allreduce``.
    """
    return tmo.matmul_allreduce(cncl_comm=cncl_comm,
                                a=a, b=b,
                                bias=bias, c=c,
                                alpha=alpha,
                                beta=beta,
                                block_m=block_m)
def smooth_quant_matmul_allreduce(
    cncl_comm,
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = 1.0,
    block_m: int = 0):
    """Smooth-quant matmul fused with an all-reduce over ``cncl_comm``.
    Thin wrapper around ``tmo.smooth_quant_matmul_allreduce``.
    """
    return tmo.smooth_quant_matmul_allreduce(
        cncl_comm=cncl_comm,
        a=a, a_scale=a_scale,
        b=b, b_scale=b_scale,
        dtype=dtype, bias=bias, c=c,
        alpha=alpha, beta=beta, block_m=block_m)
def quant_matmul_allreduce(
    cncl_comm,
    a_tensor: torch.Tensor,
    a_scale: Optional[torch.Tensor],
    a_zero: Optional[torch.Tensor],
    b_tensor: torch.Tensor,
    b_scale: Optional[torch.Tensor],
    b_zero: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    c_tensor: Optional[torch.Tensor],
    c_scale: Optional[torch.Tensor],
    c_zero: Optional[torch.Tensor],
    gemm_output_scale: Optional[torch.Tensor],
    gemm_output_zero: Optional[torch.Tensor],
    data_type: Optional[str],
    quant_algo: str,
    a_quant_layout: str,
    b_quant_layout: str,
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0,
    trans_a: bool = False,
    trans_b: bool = True,
    block_m: int = 0
) -> torch.Tensor:
    """Generic quantized matmul fused with an all-reduce over ``cncl_comm``.

    Thin wrapper around ``tmo.quant_matmul_allreduce``; every argument is
    forwarded by keyword unchanged, so parameter semantics follow the TMO
    kernel.
    """
    return tmo.quant_matmul_allreduce(
        cncl_comm=cncl_comm, a_tensor=a_tensor, a_scale=a_scale, a_zero=a_zero,
        b_tensor=b_tensor, b_scale=b_scale, b_zero=b_zero, bias=bias,
        c_tensor=c_tensor, c_scale=c_scale, c_zero=c_zero,
        gemm_output_scale=gemm_output_scale, gemm_output_zero=gemm_output_zero,
        data_type=data_type, quant_algo=quant_algo,
        a_quant_layout=a_quant_layout, b_quant_layout=b_quant_layout,
        quant_bit_size=quant_bit_size,
        alpha=alpha, beta=beta, trans_a=trans_a, trans_b=trans_b, block_m=block_m)
def flash_attn_sq_mm_allreduce(
    cncl_comm: int,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seq_lens_q: Optional[torch.Tensor],
    cu_seq_lens_kv: Optional[torch.Tensor],
    alibi_slope: Optional[torch.Tensor],
    attn_bias: Optional[torch.Tensor],
    smooth: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    max_seq_len_q: int,
    max_seq_len_kv: int,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    compute_dtype: torch.dtype = torch.float,
    block_seq: int = 0) -> torch.Tensor:
    """Flash attention fused with a smooth-quant matmul and an all-reduce
    over ``cncl_comm``. Thin wrapper around
    ``tmo.flash_attn_sq_mm_allreduce``; all arguments forwarded unchanged.
    """
    return tmo.flash_attn_sq_mm_allreduce(cncl_comm, q, k, v,
        cu_seq_lens_q, cu_seq_lens_kv, alibi_slope, attn_bias, smooth, weight, weight_scale,
        bias, max_seq_len_q, max_seq_len_kv, softmax_scale, is_causal, window_size_left,
        window_size_right, compute_dtype, block_seq)
# Moe inner kernels
def moe_softmax_topk(input: torch.Tensor,
                     topk: int,
                     normalize: bool = False,
                     num_expert_group: int = -1,
                     topk_group: int = 0) -> Tuple[torch.Tensor]:
    """Softmax over the gating logits followed by top-k expert selection
    (optionally grouped / renormalized). Thin wrapper around
    ``tmo.moe_softmax_topk``.
    """
    return tmo.moe_softmax_topk(input, topk, normalize, num_expert_group, topk_group)
def moe_gen_idx(expert_id: torch.Tensor,
                expert_num: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build gather/scatter index tensors for routing tokens to experts.
    Thin wrapper around ``tmo.moe_gen_idx``.
    """
    return tmo.moe_gen_idx(expert_id, expert_num)
def moe_expand_input(input: torch.Tensor,
                     gather_idx: torch.Tensor,
                     cusum_token_count: Optional[torch.Tensor] = None,
                     start_expert_id: int = 0,
                     expert_size: int = 0) -> torch.Tensor:
    """Gather/duplicate input rows per ``gather_idx`` so each expert sees
    its routed tokens. Thin wrapper around ``tmo.moe_expand_input``.
    """
    return tmo.moe_expand_input(input, gather_idx,
                                cusum_token_count,
                                start_expert_id, expert_size)
def moe_active(input: torch.Tensor,
               act_mode: str,
               is_gated: bool,
               output: Optional[torch.Tensor] = None,
               bias: Optional[torch.Tensor] = None,
               cusum_token_count: Optional[torch.Tensor] = None,
               start_expert_id: int = 0,
               expert_size: int = 0) -> torch.Tensor:
    """Per-expert activation kernel for MoE intermediate results.
    Thin wrapper around ``tmo.moe_active``.
    """
    return tmo.moe_active(input, act_mode, is_gated, output,
                          bias, cusum_token_count,
                          start_expert_id, expert_size)
def group_gemm(a: torch.Tensor,
               b: torch.Tensor,
               m_list: torch.Tensor,
               expand_idx: Optional[torch.Tensor],
               c: Optional[torch.Tensor],
               alpha: Optional[torch.Tensor],
               beta: Optional[torch.Tensor],
               max_m: int = 0
               ) -> torch.Tensor:
    """Grouped GEMM: one matmul per expert with per-group row counts
    ``m_list``. Thin wrapper around ``tmo.group_gemm``.
    """
    return tmo.group_gemm(a, b, m_list, expand_idx,
                          c, alpha, beta, max_m)
def smooth_quant_group_gemm(a: torch.Tensor,
                            b: torch.Tensor,
                            m_list: torch.Tensor,
                            expand_idx: Optional[torch.Tensor],
                            c: Optional[torch.Tensor],
                            alpha: Optional[torch.Tensor],
                            beta: Optional[torch.Tensor],
                            a_scale: torch.Tensor,
                            b_scale: torch.Tensor,
                            dtype,
                            max_m: int = 0
                            ) -> torch.Tensor:
    """Smooth-quant variant of the grouped GEMM (quantized operands with
    ``a_scale``/``b_scale``, output in ``dtype``). Thin wrapper around
    ``tmo.smooth_quant_group_gemm``.
    """
    return tmo.smooth_quant_group_gemm(a, b, m_list, expand_idx, c, alpha, beta,
                                       a_scale, b_scale, dtype, max_m)
def moe_combine_result(input: torch.Tensor,
                       reduce_weight: torch.Tensor,
                       gather_ids: torch.Tensor,
                       residual: Optional[torch.Tensor],
                       cusum_token_count: Optional[torch.Tensor],
                       start_expert_id: int,
                       expert_size: int,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Weighted combine of per-expert outputs back into token order
    (optionally adding ``residual``/``bias``). Thin wrapper around
    ``tmo.moe_combine_result``.
    """
    return tmo.moe_combine_result(input, reduce_weight, gather_ids,
                                  residual, cusum_token_count,
                                  start_expert_id, expert_size, bias)
def moe_quantize(x: torch.Tensor,
                 smooth: torch.Tensor,
                 zero: Optional[torch.Tensor] = None,
                 token_count: Optional[torch.Tensor] = None,
                 gather_index: Optional[torch.Tensor] = None,
                 gather_index_start_position: Optional[torch.Tensor] = None,
                 output: Optional[torch.Tensor] = None,
                 output_scale: Optional[torch.Tensor] = None,
                 dynamic_quant: bool = True
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize MoE inputs (optionally gathering rows first); returns the
    quantized tensor and its scales (per TMO's convention). Thin wrapper
    around ``tmo.moe_quantize``.
    """
    return tmo.moe_quantize(x, smooth, zero, token_count, gather_index, gather_index_start_position,
                            output, output_scale, dynamic_quant)