forked from EngineX-Cambricon/enginex-mlu370-vllm
779 lines
27 KiB
Python
779 lines
27 KiB
Python
from typing import List, Optional, Tuple
|
||
|
||
import torch
|
||
|
||
import math
|
||
import triton
|
||
import triton.language as tl
|
||
|
||
from vllm.logger import init_logger
|
||
|
||
logger = init_logger(__name__)
|
||
|
||
try:
|
||
import torch_mlu_ops as tmo
|
||
except ImportError as e:
|
||
logger.warning("Failed to import from TMO OPS with %r", e)
|
||
|
||
|
||
def rotary_embedding(
    input: torch.Tensor,
    sin_cache: torch.Tensor,
    cos_cache: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    cu_seqlens: Optional[torch.Tensor],
    interleaved: bool,
    discrete: bool,
    dynamic_ntk: bool,
    max_seqlen: int,
) -> torch.Tensor:
    """Apply rotary position embedding using the TMO fused kernel.

    Thin pass-through to ``tmo.apply_rotary``; see torch_mlu_ops for the
    exact kernel semantics of each flag.
    """
    rotary_args = (input, sin_cache, cos_cache, position_ids, cu_seqlens,
                   interleaved, discrete, dynamic_ntk, max_seqlen)
    return tmo.apply_rotary(*rotary_args)
|
||
|
||
|
||
def fused_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    store_output_before_norm: bool,
    quant_scale: Optional[torch.Tensor] = None,
    dynamic_quant: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Fused RMSNorm via the TMO kernel.

    Args:
        x: Input tensor to normalize.
        residual: Residual tensor passed to the fused kernel.
        gamma / beta / bias: Normalization weight tensors.
        eps: Numerical-stability epsilon.
        store_output_before_norm: Whether the kernel also returns the
            pre-normalization sum (see TMO docs).
        quant_scale: Optional quantization scale tensor.
        dynamic_quant: Enable dynamic quantization in the kernel.

    Returns:
        Whatever ``tmo.fused_rms_norm`` returns (output and an optional
        secondary tensor).
    """
    # NOTE: annotation fixed to Optional[...] — default was None while the
    # annotation claimed a plain Tensor (implicit Optional, PEP 484).
    return tmo.fused_rms_norm(
        x, residual, gamma, beta, bias,
        eps, store_output_before_norm, quant_scale,
        None, dynamic_quant)
|
||
|
||
|
||
def fused_layer_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    store_output_before_norm: bool,
    quant_scale: Optional[torch.Tensor] = None,
    dynamic_quant: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Fused LayerNorm via the TMO kernel.

    Mirrors :func:`fused_rms_norm`; see that function for parameter
    descriptions. Returns the tuple produced by ``tmo.fused_layer_norm``.
    """
    # NOTE: annotation fixed to Optional[...] — default was None while the
    # annotation claimed a plain Tensor (implicit Optional, PEP 484).
    return tmo.fused_layer_norm(
        x, residual, gamma, beta, bias,
        eps, store_output_before_norm, quant_scale,
        None, dynamic_quant)
|
||
|
||
|
||
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    cu_seq_lens_q: torch.Tensor,
    cu_seq_lens_kv: torch.Tensor,
    alibi_slope: torch.Tensor,
    attn_bias: torch.Tensor,
    max_seq_len_q: int,
    max_seq_len_kv: int,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    compute_dtype: torch.dtype = torch.float,
    return_lse: bool = False,
    block_tables: Optional[torch.Tensor] = None,
    k_cache_quant_scale: Optional[torch.Tensor] = None,
    v_cache_quant_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Variable-length flash attention via the TMO kernel.

    Writes attention output into ``out`` and returns the kernel's result
    tuple (output and, when ``return_lse`` is set, the log-sum-exp tensor).
    ``window_size_left``/``window_size_right`` of -1 mean an unbounded
    window; ``block_tables`` enables paged-KV operation.
    """
    # NOTE: the three trailing tensor parameters default to None, so their
    # annotations are Optional[...] (they were implicit-Optional before).
    return tmo.flash_attention(
        q, k, v, out,
        cu_seq_lens_q, cu_seq_lens_kv,
        alibi_slope, attn_bias,
        max_seq_len_q, max_seq_len_kv,
        softmax_scale, is_causal,
        window_size_left, window_size_right,
        compute_dtype, return_lse,
        block_tables, k_cache_quant_scale,
        v_cache_quant_scale)
|
||
|
||
|
||
def split_head_nums(q_head_num, kv_head_num, max_q_head_num):
    """Split ``q_head_num`` into chunks compatible with the kernel limits.

    The split guarantees:
      1. every split q_head_num is at most ``max_q_head_num``;
      2. ``kv_head_num`` is split into the same number of parts;
      3. every split q_head_num is divisible by its split kv_head_num;
      4. any kv part that would fall below 1 is clamped to 1.

    Args:
        q_head_num: Total number of query heads to split (positive int).
        kv_head_num: Total number of key/value heads to split (positive int).
        max_q_head_num: Upper bound for each split q_head_num.

    Returns:
        (q_splits, kv_splits): two lists of per-part head counts.

    Raises:
        ValueError: If q_head_num or kv_head_num is not a positive integer.
        RuntimeError: If no valid split can be found.
    """
    # BUGFIX: the original *returned* an error string here, which callers
    # would try to unpack as a tuple; raise a proper exception instead.
    if q_head_num <= 0 or kv_head_num <= 0:
        raise ValueError("q_head_num and kv_head_num must be positive integers, "
                         f"got q_head_num={q_head_num}, kv_head_num={kv_head_num}")

    q_splits = []
    kv_splits = []

    # Remaining head counts still to be assigned to parts.
    remaining_q = q_head_num
    remaining_kv = kv_head_num

    while remaining_q > 0:
        # Greedily try the largest q part not exceeding max_q_head_num.
        for q_part in range(min(max_q_head_num, remaining_q), 0, -1):
            # q_part must evenly tile the remaining q heads, and the
            # proportional kv part (clamped to >= 1) must divide q_part.
            if remaining_q % q_part == 0:
                kv_part = max(remaining_kv // (remaining_q // q_part), 1)  # keep kv_part >= 1
                if q_part % kv_part == 0:
                    q_splits.append(q_part)
                    kv_splits.append(kv_part)
                    remaining_q -= q_part
                    remaining_kv -= kv_part
                    break
        else:
            err_msg = f"Unable to find split method for q_head_num:{q_head_num}, kv_head_num:{kv_head_num}"
            raise RuntimeError(err_msg)

    return q_splits, kv_splits
|
||
|
||
|
||
def repeat_elements(input_list, n):
    """Repeat every member of a list ``n`` consecutive times.

    Args:
        input_list: The source list.
        n: Non-negative repetition count for each element.

    Returns:
        A new list where each element of ``input_list`` appears ``n``
        times in a row (empty when ``n`` is 0).

    Raises:
        ValueError: If ``input_list`` is not a list or ``n`` is not a
            non-negative integer.
    """
    valid_list = isinstance(input_list, list)
    valid_count = isinstance(n, int) and n >= 0
    if not (valid_list and valid_count):
        raise ValueError("输入必须是一个列表,并且重复次数 n 必须是大于或等于 0 的整数。")

    repeated = []
    for member in input_list:
        repeated.extend([member] * n)
    return repeated
|
||
|
||
|
||
def single_query_cached_kv_attn(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    k_cache_quant_scale: Optional[torch.Tensor],
    v_cache_quant_scale: Optional[torch.Tensor],
    alibi_slopes: Optional[torch.Tensor],
    max_contxt_len: int,
    windows_size_left: int,
    windows_size_right: int,
    softmax_scale: float,
    q_head_dim: Optional[int] = 2,
    kv_head_dim: Optional[int] = 1,
    seq_q_dim: Optional[int] = 1,
    max_seq_q_mul_q_divide_kv: Optional[int] = 48,
) -> None:
    """Decode-phase (single-query) attention over a paged KV cache.

    The TMO kernel limits ``seq_q * (q_heads // kv_heads)``; when the input
    exceeds that limit, the q/kv heads (and the matching per-head tensors)
    are split along their head dimensions and the kernel is invoked once
    per part, writing results into the corresponding slice of ``out``.
    """
    # FIXME(chenxiaobing): TMO only support windows_size_right = -1 yet.
    windows_size_right = -1

    # singleQwithkvCache limits seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv now.
    # When the limitation is fixed, we should delete the split process.
    seq_q = q.shape[seq_q_dim]
    q_head_num = q.shape[q_head_dim]
    kv_head_num = k_cache.shape[kv_head_dim]
    q_divide_kv = q_head_num // kv_head_num
    if seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv:
        tmo.single_query_cached_kv_attn(
            q, k_cache, v_cache, out,
            block_tables, context_lens,
            k_cache_quant_scale, v_cache_quant_scale,
            alibi_slopes, max_contxt_len,
            windows_size_left, windows_size_right, softmax_scale)
    else:
        max_q_head_num = max_seq_q_mul_q_divide_kv * kv_head_num // seq_q
        q_head_num_sizes, kv_head_num_sizes = split_head_nums(q_head_num, kv_head_num, max_q_head_num)
        parts_num = len(q_head_num_sizes)
        q_parts = torch.split(q, q_head_num_sizes, dim=q_head_dim)
        out_parts = torch.split(out, q_head_num_sizes, dim=q_head_dim)
        alibi_slopes_parts = [None] * parts_num
        # BUGFIX: `if alibi_slopes:` evaluates tensor truthiness, which
        # raises for multi-element tensors; test against None instead.
        if alibi_slopes is not None:
            alibi_slopes_parts = torch.split(alibi_slopes, q_head_num_sizes, dim=0)

        kv_parts_num = parts_num
        if parts_num > kv_head_num:
            assert parts_num % kv_head_num == 0, f"parts_num:{parts_num} need by divided by kv_head_num:{kv_head_num} when parts_num > kv_head_num"
            kv_parts_num = kv_head_num
            kv_head_num_sizes = kv_head_num_sizes[:kv_parts_num]

        if len(kv_head_num_sizes) > 1:
            k_cache_parts = torch.split(k_cache, kv_head_num_sizes, dim=kv_head_dim)
            v_cache_parts = torch.split(v_cache, kv_head_num_sizes, dim=kv_head_dim)
            k_cache_quant_scale_parts = [None] * kv_parts_num
            v_cache_quant_scale_parts = [None] * kv_parts_num
            # BUGFIX: same tensor-truthiness issue as alibi_slopes above.
            if k_cache_quant_scale is not None:
                # 2-D scales are laid out [?, head]; otherwise split on the cache head dim.
                k_cache_quant_scale_dim = 1 if k_cache_quant_scale.dim() == 2 else kv_head_dim
                k_cache_quant_scale_parts = torch.split(k_cache_quant_scale, kv_head_num_sizes, dim=k_cache_quant_scale_dim)
            if v_cache_quant_scale is not None:
                v_cache_quant_scale_dim = 1 if v_cache_quant_scale.dim() == 2 else kv_head_dim
                v_cache_quant_scale_parts = torch.split(v_cache_quant_scale, kv_head_num_sizes, dim=v_cache_quant_scale_dim)
        else:
            k_cache_parts = [k_cache]
            v_cache_parts = [v_cache]
            k_cache_quant_scale_parts = [k_cache_quant_scale]
            v_cache_quant_scale_parts = [v_cache_quant_scale]

        if parts_num > kv_parts_num:
            # Reuse each kv part for several q parts (GQA-style sharing).
            repeate_num = parts_num // kv_parts_num
            k_cache_parts = repeat_elements(k_cache_parts, repeate_num)
            v_cache_parts = repeat_elements(v_cache_parts, repeate_num)
            k_cache_quant_scale_parts = repeat_elements(k_cache_quant_scale_parts, repeate_num)
            v_cache_quant_scale_parts = repeat_elements(v_cache_quant_scale_parts, repeate_num)

        for q_value, k_cache_value, v_cache_value, out_value, k_cache_quant_scale_value, v_cache_quant_scale_value, alibi_slopes_value in zip(
                q_parts, k_cache_parts, v_cache_parts, out_parts, k_cache_quant_scale_parts, v_cache_quant_scale_parts,
                alibi_slopes_parts):
            tmo.single_query_cached_kv_attn(
                q_value, k_cache_value.contiguous(), v_cache_value.contiguous(), out_value,
                block_tables, context_lens,
                k_cache_quant_scale_value, v_cache_quant_scale_value,
                alibi_slopes_value, max_contxt_len,
                windows_size_left, windows_size_right, softmax_scale)
|
||
|
||
|
||
def reshape_linear_cache(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    packed: bool,
    context_seq_offset: Optional[torch.Tensor],
    cache_bs_id: Optional[torch.Tensor],
    cache_seqlen_offset: Optional[torch.Tensor],
) -> None:
    """Scatter new key/value tensors into a linear (contiguous) KV cache.

    Pass-through to ``tmo.reshape_linear_cache``; the cache tensors are
    updated in place.
    """
    cache_args = (key, value, key_cache, value_cache,
                  context_lengths, max_context_len,
                  packed, context_seq_offset,
                  cache_bs_id, cache_seqlen_offset)
    tmo.reshape_linear_cache(*cache_args)
|
||
|
||
|
||
def reshape_paged_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor
) -> None:
    """Scatter key/value tensors into a paged KV cache at the given slots.

    In-place update of ``k_cache``/``v_cache`` via ``tmo.reshape_paged_cache``.
    """
    paged_args = (k, v, k_cache, v_cache, slot_mapping)
    tmo.reshape_paged_cache(*paged_args)
|
||
|
||
|
||
def swap_blocks(
    dst: torch.Tensor,
    src: torch.Tensor,
    block_mapping: torch.Tensor
) -> None:
    """Swap cache blocks between ``src`` and ``dst`` tensors.

    ``block_mapping`` is an N x 2 tensor of (src_block, dst_block) pairs.
    """
    # FIXME: Remove this conversion after
    # tmo.swap_blocks support block_mapping tensor.
    mapping_pairs = block_mapping.tolist()
    mapping_dict = {}
    for src_block, dst_block in mapping_pairs:
        mapping_dict[src_block] = dst_block
    return tmo.swap_blocks(dst, src, mapping_dict)
|
||
|
||
|
||
def copy_blocks(
    k_caches: List[torch.Tensor],
    v_caches: List[torch.Tensor],
    block_mapping: torch.Tensor
) -> None:
    """Copy cache blocks, fanning one source block out to many destinations.

    ``block_mapping`` rows are [src_block, dst_block, ...]; rows sharing a
    source are merged into one src -> [dsts...] entry before the kernel call.
    """
    # FIXME: Remove this conversion after
    # tmo.copy_blocks support block_mapping tensor.
    rows = block_mapping.tolist()
    grouped = {}
    for row in rows:
        src_block, dst_blocks = row[0], row[1:]
        grouped.setdefault(src_block, []).extend(dst_blocks)
    return tmo.copy_blocks(k_caches, v_caches, grouped)
|
||
|
||
|
||
def ffn(
    input: torch.Tensor,
    up_fc_weight: torch.Tensor,
    up_fc_bias: Optional[torch.Tensor],
    down_proj_weight: torch.Tensor,
    down_proj_bias: Optional[torch.Tensor],
    gate_up_proj_weight: Optional[torch.Tensor] = None,
    gate_up_proj_bias: Optional[torch.Tensor] = None,
    act_mode: str = "none"
) -> torch.Tensor:
    """Fused feed-forward network (up-proj, activation, down-proj) via TMO."""
    ffn_args = (input, up_fc_weight, up_fc_bias,
                down_proj_weight, down_proj_bias,
                gate_up_proj_weight, gate_up_proj_bias, act_mode)
    return tmo.ffn(*ffn_args)
|
||
|
||
|
||
def active(
    input: torch.Tensor,
    act_mode: str,
    is_gated: bool
) -> torch.Tensor:
    """Apply the (optionally gated) activation kernel from TMO."""
    activation_args = (input, act_mode, is_gated)
    return tmo.active(*activation_args)
|
||
|
||
|
||
def fused_moe(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: Optional[torch.Tensor],
    bias2: Optional[torch.Tensor],
    residual: Optional[torch.Tensor],
    input_smooth: Optional[torch.Tensor],
    act_smooth: Optional[torch.Tensor],
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    topk: int,
    renormalize: bool,
    gated: bool,
    act_mode: str,
    start_expert_id: int = 0,
    block_n: int = 0,
    cncl_comm: int = 0
) -> torch.Tensor:
    """Fully fused mixture-of-experts layer via the TMO kernel.

    Routes tokens with ``gating_output``/``topk`` and applies the expert
    MLPs (w1/w2 with optional biases, smoothing and quant scales).
    """
    moe_args = (hidden_states, gating_output,
                w1, w2, bias1, bias2, residual,
                input_smooth, act_smooth,
                w1_scale, w2_scale, topk,
                renormalize, gated, act_mode, start_expert_id,
                block_n, cncl_comm)
    return tmo.fused_moe(*moe_args)
|
||
|
||
|
||
def matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    act_mode: str = 'none',
    alpha: float = 1.0,
    beta: float = .0
) -> torch.Tensor:
    """GEMM via TMO: alpha * (a @ b) + bias + beta * c, then activation."""
    gemm_args = (a, b, bias, c, act_mode, alpha, beta)
    return tmo.matmul(*gemm_args)
|
||
|
||
|
||
def weight_only_quant_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    scale: torch.Tensor,
    zero: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    act_mode: str = "none",
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0
) -> torch.Tensor:
    """Weight-only quantized GEMM via TMO.

    ``b`` is the quantized weight with per-channel ``scale`` (and optional
    ``zero`` point); ``quant_bit_size`` selects the weight bit width.
    """
    # NOTE: zero/bias/c default to None, so their annotations are
    # Optional[...] (they were implicit-Optional before).
    return tmo.weight_only_quant_matmul(
        a, b,
        scale, zero, bias, c,
        act_mode, quant_bit_size, alpha, beta)
|
||
|
||
|
||
def smooth_quant_matmul(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    act_mode: str = "none",
    alpha: float = 1.0,
    beta: float = 1.0
) -> torch.Tensor:
    """SmoothQuant GEMM via TMO with activation/weight scales.

    ``dtype`` selects the output dtype of the dequantized result.
    """
    # NOTE: bias/c default to None, so their annotations are Optional[...]
    # (they were implicit-Optional before).
    return tmo.smooth_quant_matmul(
        a, a_scale,
        b, b_scale,
        dtype, bias, c,
        act_mode, alpha, beta)
|
||
|
||
|
||
def per_token_smooth_quantize(
    x: torch.Tensor,
    smooth: torch.Tensor,
    zero: Optional[torch.Tensor] = None,
    token_count: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token smooth quantization via TMO.

    Returns the quantized tensor and its per-token scales, as produced by
    ``tmo.per_token_smooth_quantize``.
    """
    # NOTE: zero/token_count default to None, so their annotations are
    # Optional[...] (they were implicit-Optional before).
    return tmo.per_token_smooth_quantize(x, smooth, zero, token_count)
|
||
|
||
|
||
def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Quantize ``x`` with the given scale (and optional zero point) via TMO."""
    # NOTE: zero defaults to None, so its annotation is Optional[...]
    # (it was implicit-Optional before).
    return tmo.quantize(x, scale, zero)
|
||
|
||
def quant_to_paged_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k_cache_quant_scale: torch.Tensor,
    v_cache_quant_scale: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Quantize key/value tensors and scatter them into a paged KV cache."""
    quant_args = (k, v, k_cache, v_cache,
                  k_cache_quant_scale, v_cache_quant_scale, slot_mapping)
    return tmo.quant_to_paged_cache(*quant_args)
|
||
|
||
|
||
def quant_to_linear_cache(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    key_cache_quant_scale: torch.Tensor,
    value_cache_quant_scale: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    packed: bool,
    context_seq_offset: Optional[torch.Tensor],
    cache_bs_id: Optional[torch.Tensor],
    cache_seqlen_offset: Optional[torch.Tensor],
) -> None:
    """Quantize key/value tensors and scatter them into a linear KV cache.

    Pass-through to ``tmo.quant_to_linear_cache``; caches are updated in
    place.
    """
    quant_args = (key, value, key_cache, value_cache,
                  key_cache_quant_scale, value_cache_quant_scale,
                  context_lengths, max_context_len, packed,
                  context_seq_offset, cache_bs_id, cache_seqlen_offset)
    return tmo.quant_to_linear_cache(*quant_args)
|
||
|
||
|
||
def advance_step(num_seqs: int,
                 num_queries: int,
                 block_size: int,
                 input_tokens: torch.Tensor,
                 sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor,
                 seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor,
                 TILE_SIZE: int = 64) -> None:
    """
    Advance a step on MLU for existing inputs for a multi-step runner, which
    will update input_tokens/seq_lens/input_positions/slot_mapping inplace.
    """
    def verify_tensor(
        name: str,
        tensor: torch.Tensor,
        size_0: int,
        size_1: int,
        dtype: torch.dtype,
    ):
        """
        Auxiliary function to check whether input is valid.

        A size of -1 means "don't check that dimension".
        """
        size_0_cond = (size_0 == -1 or tensor.size(0) == size_0)
        size_1_cond = (size_1 == -1 or tensor.size(1) == size_1)
        # BUGFIX: the original tested `tensor.is_contiguous` (the bound
        # method object, which is always truthy) instead of calling it, so
        # non-contiguous tensors slipped through the check.
        if not (size_0_cond and size_1_cond and tensor.is_contiguous() and tensor.dtype == dtype):
            raise ValueError(
                f"The input to advance_step is invalid with tensor name = {name}, "
                f"shape = {tensor.shape}, "
                f"is_cont = {tensor.is_contiguous()}, "
                f"type = {tensor.dtype}, "
                f"is not as expected: shape[{size_0}, {size_1}], type = {dtype}"
            )


    @triton.jit
    def _triton_advance_step(input_tokens_ptr,
                             sampled_token_ids_ptr,
                             input_positions_ptr,
                             seq_lens_ptr,
                             slot_mapping_ptr,
                             block_tables_ptr,
                             block_tables_stride,
                             num_seqs,
                             num_queries,
                             block_size,
                             TILE_SIZE: tl.constexpr,
                             ):
        """
        The triton implementation of advance step.
        Reference: https://github.com/vllm-project/vllm/blob/v0.6.1/csrc/prepare_inputs/advance_step.cu#L14-L55
        """
        # Set meta info.
        pid = tl.program_id(axis=0)
        offsets = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)
        mask = offsets < num_queries

        # Update input_tokens.
        sampled_token_ids = tl.load(sampled_token_ids_ptr + offsets, mask=mask)
        tl.store(input_tokens_ptr + offsets, sampled_token_ids, mask=mask)

        seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
        next_seq_lens = seq_lens + 1
        next_input_pos = next_seq_lens - 1

        # Update seq_lens.
        tl.store(seq_lens_ptr + offsets, next_seq_lens, mask=mask)

        # Update input_positions.
        tl.store(input_positions_ptr + offsets, next_input_pos, mask=mask)

        # Calculate slot num.
        block_index = next_input_pos // block_size
        block_offset = next_input_pos % block_size
        block_tables = tl.load(block_tables_ptr + block_tables_stride * offsets + block_index, mask=mask)
        slot_num = block_tables * block_size + block_offset

        # Update slot_mapping.
        tl.store(slot_mapping_ptr + offsets, slot_num, mask=mask)


    # Validate every in/out tensor before launching the kernel.
    verify_tensor("input_tokens", input_tokens, num_seqs, -1, torch.int64)
    verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, torch.int64)
    verify_tensor("input_positions", input_positions, num_seqs, -1, torch.int32)
    verify_tensor("seq_lens", seq_lens, num_seqs, -1, torch.int32)
    verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, torch.int32)
    verify_tensor("block_tables", block_tables, num_seqs, -1, torch.int32)

    # One program instance per TILE_SIZE queries.
    grid = (math.ceil(num_queries / TILE_SIZE), )
    _triton_advance_step[grid](input_tokens,
                               sampled_token_ids,
                               input_positions,
                               seq_lens,
                               slot_mapping,
                               block_tables,
                               block_tables.stride(0),
                               num_seqs,
                               num_queries,
                               block_size,
                               TILE_SIZE)
|
||
|
||
def preload(
    weight: torch.Tensor,
    size: int
) -> None:
    """Preload a layer's weights into on-chip memory.

    Args:
        weight (torch.Tensor): Weight tensor to preload.
        size (int): Number of bytes to preload.

    Returns:
        None
    """
    return tmo.preload(weight, size)
|
||
|
||
|
||
def matmul_allreduce(
    cncl_comm,
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = .0,
    block_m: int = 0
) -> torch.Tensor:
    """Fused GEMM + all-reduce over the given CNCL communicator."""
    return tmo.matmul_allreduce(cncl_comm=cncl_comm,
                                a=a,
                                b=b,
                                bias=bias,
                                c=c,
                                alpha=alpha,
                                beta=beta,
                                block_m=block_m)
|
||
|
||
|
||
def smooth_quant_matmul_allreduce(
    cncl_comm,
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = 1.0,
    block_m: int = 0):
    """Fused SmoothQuant GEMM + all-reduce over the CNCL communicator.

    See :func:`smooth_quant_matmul` for the GEMM parameters; the result is
    all-reduced across ``cncl_comm``.
    """
    # NOTE: bias/c default to None, so their annotations are Optional[...]
    # (they were implicit-Optional before).
    return tmo.smooth_quant_matmul_allreduce(
        cncl_comm=cncl_comm,
        a=a, a_scale=a_scale,
        b=b, b_scale=b_scale,
        dtype=dtype, bias=bias, c=c,
        alpha=alpha, beta=beta, block_m=block_m)
|
||
|
||
|
||
def quant_matmul_allreduce(
    cncl_comm,
    a_tensor: torch.Tensor,
    a_scale: Optional[torch.Tensor],
    a_zero: Optional[torch.Tensor],
    b_tensor: torch.Tensor,
    b_scale: Optional[torch.Tensor],
    b_zero: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    c_tensor: Optional[torch.Tensor],
    c_scale: Optional[torch.Tensor],
    c_zero: Optional[torch.Tensor],
    gemm_output_scale: Optional[torch.Tensor],
    gemm_output_zero: Optional[torch.Tensor],
    data_type: Optional[str],
    quant_algo: str,
    a_quant_layout: str,
    b_quant_layout: str,
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0,
    trans_a: bool = False,
    trans_b: bool = True,
    block_m: int = 0
) -> torch.Tensor:
    """Generic quantized GEMM + all-reduce over the CNCL communicator.

    Forwards every argument by keyword to ``tmo.quant_matmul_allreduce``;
    see torch_mlu_ops for the quantization layouts and algorithms.
    """
    kernel_kwargs = dict(
        cncl_comm=cncl_comm,
        a_tensor=a_tensor, a_scale=a_scale, a_zero=a_zero,
        b_tensor=b_tensor, b_scale=b_scale, b_zero=b_zero,
        bias=bias,
        c_tensor=c_tensor, c_scale=c_scale, c_zero=c_zero,
        gemm_output_scale=gemm_output_scale,
        gemm_output_zero=gemm_output_zero,
        data_type=data_type, quant_algo=quant_algo,
        a_quant_layout=a_quant_layout, b_quant_layout=b_quant_layout,
        quant_bit_size=quant_bit_size,
        alpha=alpha, beta=beta,
        trans_a=trans_a, trans_b=trans_b,
        block_m=block_m)
    return tmo.quant_matmul_allreduce(**kernel_kwargs)
|
||
|
||
|
||
def flash_attn_sq_mm_allreduce(
    cncl_comm: int,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seq_lens_q: Optional[torch.Tensor],
    cu_seq_lens_kv: Optional[torch.Tensor],
    alibi_slope: Optional[torch.Tensor],
    attn_bias: Optional[torch.Tensor],
    smooth: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    max_seq_len_q: int,
    max_seq_len_kv: int,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    compute_dtype: torch.dtype = torch.float,
    block_seq: int = 0) -> torch.Tensor:
    """Fused flash attention + smooth-quant matmul + all-reduce via TMO."""
    fused_args = (cncl_comm, q, k, v,
                  cu_seq_lens_q, cu_seq_lens_kv,
                  alibi_slope, attn_bias,
                  smooth, weight, weight_scale, bias,
                  max_seq_len_q, max_seq_len_kv,
                  softmax_scale, is_causal,
                  window_size_left, window_size_right,
                  compute_dtype, block_seq)
    return tmo.flash_attn_sq_mm_allreduce(*fused_args)
|
||
|
||
# MoE inner kernels
|
||
def moe_softmax_topk(input: torch.Tensor,
                     topk: int,
                     normalize: bool = False,
                     num_expert_group: int = -1,
                     topk_group: int = 0) -> Tuple[torch.Tensor]:
    """Softmax + top-k expert routing (optionally grouped) via TMO."""
    routing_args = (input, topk, normalize, num_expert_group, topk_group)
    return tmo.moe_softmax_topk(*routing_args)
|
||
|
||
|
||
def moe_gen_idx(expert_id: torch.Tensor,
                expert_num: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build gather/scatter index tensors for MoE token dispatch via TMO."""
    return tmo.moe_gen_idx(expert_id, expert_num)
|
||
|
||
def moe_expand_input(input: torch.Tensor,
                     gather_idx: torch.Tensor,
                     cusum_token_count: Optional[torch.Tensor] = None,
                     start_expert_id: int = 0,
                     expert_size: int = 0) -> torch.Tensor:
    """Gather/duplicate tokens per routed expert for the MoE GEMMs via TMO."""
    expand_args = (input, gather_idx, cusum_token_count,
                   start_expert_id, expert_size)
    return tmo.moe_expand_input(*expand_args)
|
||
|
||
def moe_active(input: torch.Tensor,
               act_mode: str,
               is_gated: bool,
               output: Optional[torch.Tensor] = None,
               bias: Optional[torch.Tensor] = None,
               cusum_token_count: Optional[torch.Tensor] = None,
               start_expert_id: int = 0,
               expert_size: int = 0) -> torch.Tensor:
    """Per-expert (optionally gated) activation for MoE hidden states via TMO."""
    activation_args = (input, act_mode, is_gated, output,
                       bias, cusum_token_count,
                       start_expert_id, expert_size)
    return tmo.moe_active(*activation_args)
|
||
|
||
def group_gemm(a: torch.Tensor,
               b: torch.Tensor,
               m_list: torch.Tensor,
               expand_idx: Optional[torch.Tensor],
               c: Optional[torch.Tensor],
               alpha: Optional[torch.Tensor],
               beta: Optional[torch.Tensor],
               max_m: int = 0
               ) -> torch.Tensor:
    """Grouped GEMM over variable-sized row groups (``m_list``) via TMO."""
    grouped_args = (a, b, m_list, expand_idx, c, alpha, beta, max_m)
    return tmo.group_gemm(*grouped_args)
|
||
|
||
def smooth_quant_group_gemm(a: torch.Tensor,
                            b: torch.Tensor,
                            m_list: torch.Tensor,
                            expand_idx: Optional[torch.Tensor],
                            c: Optional[torch.Tensor],
                            alpha: Optional[torch.Tensor],
                            beta: Optional[torch.Tensor],
                            a_scale: torch.Tensor,
                            b_scale: torch.Tensor,
                            dtype,
                            max_m: int = 0
                            ) -> torch.Tensor:
    """SmoothQuant variant of the grouped GEMM (see :func:`group_gemm`) via TMO."""
    grouped_args = (a, b, m_list, expand_idx, c, alpha, beta,
                    a_scale, b_scale, dtype, max_m)
    return tmo.smooth_quant_group_gemm(*grouped_args)
|
||
|
||
def moe_combine_result(input: torch.Tensor,
                       reduce_weight: torch.Tensor,
                       gather_ids: torch.Tensor,
                       residual: Optional[torch.Tensor],
                       cusum_token_count: Optional[torch.Tensor],
                       start_expert_id: int,
                       expert_size: int,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Weighted scatter-reduce of per-expert outputs back to token order via TMO."""
    combine_args = (input, reduce_weight, gather_ids,
                    residual, cusum_token_count,
                    start_expert_id, expert_size, bias)
    return tmo.moe_combine_result(*combine_args)
|
||
|
||
def moe_quantize(x: torch.Tensor,
                 smooth: torch.Tensor,
                 zero: Optional[torch.Tensor] = None,
                 token_count: Optional[torch.Tensor] = None,
                 gather_index: Optional[torch.Tensor] = None,
                 gather_index_start_position: Optional[torch.Tensor] = None,
                 output: Optional[torch.Tensor] = None,
                 output_scale: Optional[torch.Tensor] = None,
                 dynamic_quant: bool = True
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize MoE activations (with optional gather) via TMO.

    Returns the quantized tensor and its scales, as produced by
    ``tmo.moe_quantize``.
    """
    quantize_args = (x, smooth, zero, token_count,
                     gather_index, gather_index_start_position,
                     output, output_scale, dynamic_quant)
    return tmo.moe_quantize(*quantize_args)
|