from contextlib import contextmanager
from typing import List, Optional, Tuple, Union
import torch
import math
import triton
import triton.language as tl
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import torch_mlu_ops as tmo
import torch_mlu_ops.triton_ops as triton_ops
except ImportError as e:
logger.warning("Failed to import from TMO OPS with %r", e)
from vllm.distributed import (
get_ep_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
get_data_parallel_group_world_size,
get_tp_group,
get_tp_world_group,
get_dp_group,
get_data_parallel_group_rank,
get_tp_world_world_size,
get_tp_world_rank,
get_parallel_rank_with_group,
)
@triton.jit
def _triton_advance_step(input_tokens_ptr,
sampled_token_ids_ptr,
input_positions_ptr,
seq_lens_ptr,
slot_mapping_ptr,
block_tables_ptr,
block_tables_stride,
num_seqs,
num_queries,
block_size,
TILE_SIZE: tl.constexpr,
):
"""
The triton implementation of advance step.
Reference: https://github.com/vllm-project/vllm/blob/v0.6.1/csrc/prepare_inputs/advance_step.cu#L14-L55
"""
# Set meta info.
pid = tl.program_id(axis=0)
offsets = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)
mask = offsets < num_queries
# Update input_tokens.
sampled_token_ids = tl.load(sampled_token_ids_ptr + offsets, mask=mask)
tl.store(input_tokens_ptr + offsets, sampled_token_ids, mask=mask)
seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
next_seq_lens = seq_lens + 1
next_input_pos = next_seq_lens - 1
# Update seq_lens.
tl.store(seq_lens_ptr + offsets, next_seq_lens, mask=mask)
# Update input_positions.
tl.store(input_positions_ptr + offsets, next_input_pos, mask=mask)
# Calculate slot num.
block_index = next_input_pos // block_size
block_offset = next_input_pos % block_size
block_tables = tl.load(block_tables_ptr + block_tables_stride * offsets + block_index, mask=mask)
slot_num = block_tables * block_size + block_offset
# Update slot_mapping.
tl.store(slot_mapping_ptr + offsets, slot_num, mask=mask)
def rotary_embedding(
input: torch.Tensor,
sin_cache: torch.Tensor,
cos_cache: torch.Tensor,
position_ids: Optional[torch.Tensor],
cu_seqlens: Optional[torch.Tensor],
interleaved: bool,
discrete: bool,
dynamic_ntk: bool,
max_seqlen: int,
) -> torch.Tensor:
return tmo.apply_rotary(
input, sin_cache, cos_cache,
position_ids, cu_seqlens, interleaved,
discrete, dynamic_ntk, max_seqlen)
def fused_rms_norm(
x: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
beta: torch.Tensor,
bias: torch.Tensor,
eps: float,
store_output_before_norm: bool,
quant_scale: torch.Tensor = None,
dynamic_quant: bool = False,
out: torch.Tensor = None,
) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
return tmo.fused_rms_norm(
x, residual, gamma, beta, bias,
eps, store_output_before_norm, quant_scale,
out, dynamic_quant)
def fused_layer_norm(
x: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
beta: torch.Tensor,
bias: torch.Tensor,
eps: float,
store_output_before_norm: bool,
quant_scale: torch.Tensor = None,
dynamic_quant: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
return tmo.fused_layer_norm(
x, residual, gamma, beta, bias,
eps, store_output_before_norm, quant_scale,
None, dynamic_quant)
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: Optional[torch.Tensor],
cu_seq_lens_q: Optional[torch.Tensor],
cu_seq_lens_kv: Optional[torch.Tensor],
alibi_slope: Optional[torch.Tensor],
attn_bias: Optional[torch.Tensor],
max_seq_len_q: int,
max_seq_len_kv: int,
softmax_scale: float,
is_causal: bool,
window_size_left: int = -1,
window_size_right: int = -1,
compute_dtype: torch.dtype = torch.float,
return_lse: bool = False,
block_tables: Optional[torch.Tensor] = None,
out_quant_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
    q_quant_dtype: Optional[torch.dtype] = None,
    k_quant_dtype: Optional[torch.dtype] = None,
    v_quant_dtype: Optional[torch.dtype] = None,
    # Optional per-tensor quantization scales forwarded to tmo.flash_attention.
    q_quant_scale: Optional[torch.Tensor] = None,
    k_quant_scale: Optional[torch.Tensor] = None,
    v_quant_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if v is None:
v = k
return tmo.flash_attention(
q, k, v, out,
cu_seq_lens_q, cu_seq_lens_kv,
alibi_slope, attn_bias,
max_seq_len_q, max_seq_len_kv,
softmax_scale, is_causal,
window_size_left, window_size_right,
compute_dtype, return_lse,
block_tables, k_quant_scale,
v_quant_scale, q_quant_scale,
out_quant_scale, out_dtype)
def split_head_nums(q_head_num, kv_head_num, max_q_head_num):
"""
Split q_head_num such that:
1. The maximum value of the split q_head_num does not exceed max_q_head_num.
2. kv_head_num is split into the same number of parts as q_head_num.
3. Each split q_head_num can be evenly divided by the corresponding kv_head_num.
4. If kv_head_num < 1, it is adjusted to 1.
Parameters:
- q_head_num: int, the q_head_num to be split.
- kv_head_num: int, the kv_head_num to be split.
- max_q_head_num: int, the maximum supported q_head_num after splitting.
Returns:
- q_splits: list, the split q_head_num.
- kv_splits: list, the split kv_head_num.
"""
    if q_head_num <= 0 or kv_head_num <= 0:
        raise ValueError("q_head_num and kv_head_num must be positive integers!")
q_splits = []
kv_splits = []
# Residual value
remaining_q = q_head_num
remaining_kv = kv_head_num
while remaining_q > 0:
# Attempt to split q_head_num such that the maximum value does not exceed max_q_head_num.
for q_part in range(min(max_q_head_num, remaining_q), 0, -1):
# Ensure that q_part can be allocated and the corresponding kv_part is greater than or equal to 1.
if remaining_q % q_part == 0:
# Ensure that kv_part is greater than or equal to 1.
kv_part = max(remaining_kv // (remaining_q // q_part), 1)
# Ensure that q_part is divisible by kv_part.
if q_part % kv_part == 0:
# Record the split values.
q_splits.append(q_part)
kv_splits.append(kv_part)
remaining_q -= q_part
remaining_kv -= kv_part
break
else:
err_msg = f"Unable to find split method for q_head_num:{q_head_num}, kv_head_num:{kv_head_num}"
raise RuntimeError(err_msg)
return q_splits, kv_splits
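
# A minimal, hypothetical usage sketch of split_head_nums (the values below are
# illustrative, not from the TMO API): 32 q heads and 8 kv heads split under a
# cap of 16 q heads per part into two (16, 4) pairs.
def _demo_split_head_nums():
    q_splits, kv_splits = split_head_nums(32, 8, 16)
    assert q_splits == [16, 16]
    assert kv_splits == [4, 4]
    # Each split q part is evenly divisible by its kv part.
    assert all(q % kv == 0 for q, kv in zip(q_splits, kv_splits))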
def repeat_elements(input_list, n):
"""
Repeat each element in the list n times consecutively.
Parameters:
- input_list: list, the input list.
- n: int, the number of times each element should be repeated.
Returns:
- list, a new list containing the repeated elements.
"""
if not isinstance(input_list, list) or not isinstance(n, int) or n < 0:
raise ValueError("The input must be a list, and the repetition count n must be an integer greater than or equal to 0.")
# Repeat each element n times using a list comprehension.
return [item for item in input_list for _ in range(n)]
def single_query_cached_kv_attn(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
out: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
k_cache_quant_scale: Optional[torch.Tensor],
v_cache_quant_scale: Optional[torch.Tensor],
alibi_slopes: Optional[torch.Tensor],
max_contxt_len: int,
windows_size_left: int,
windows_size_right: int,
softmax_scale: float,
return_lse: bool = False,
q_head_dim: Optional[int] = 2,
kv_head_dim: Optional[int] = 1,
seq_q_dim: Optional[int] = 1,
max_seq_q_mul_q_divide_kv: Optional[int] = 128,
head_size_v: Optional[int] = -1,
compute_dtype: Optional[torch.dtype] = torch.float32,
q_quant_scale: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
q_quant_dtype: Optional[torch.dtype] = None,
q_scale_dtype: Optional[torch.dtype] = None,
learnable_sink: Optional[torch.Tensor] = None,
) -> None:
    windows_size_right = -1  # the right window is always forced to -1 (disabled) here
seq_q = q.shape[seq_q_dim]
if q_quant_dtype is not None and q.dtype != q_quant_dtype and q_quant_scale is None:
q, q_quant_scale = tmo.scaled_quantize(q.contiguous(), quant_type=q_quant_dtype, quant_mode="dynamic_per_token")
if k_cache is not None and k_cache.dtype == torch.uint8:
k_cache = k_cache.view(torch.float8_e4m3fn)
if v_cache is not None and v_cache.dtype == torch.uint8:
v_cache = v_cache.view(torch.float8_e4m3fn)
if k_cache is not None and k_cache.dtype == torch.bfloat16:
max_seq_q_mul_q_divide_kv = 256
# single_query_cached_kv_attn limits seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv now,
# and this limitation only applies when using kv8 or floating point computation.
# When the limitation is fixed, we should delete the split process.
q_head_num = q.shape[q_head_dim]
kv_head_num = k_cache.shape[kv_head_dim]
q_divide_kv = q_head_num // kv_head_num
if seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv or q_quant_scale is not None:
tmo.single_query_cached_kv_attn(
q, k_cache, v_cache, out,
block_tables, context_lens,
k_cache_quant_scale, v_cache_quant_scale,
alibi_slopes, max_contxt_len,
windows_size_left, windows_size_right, softmax_scale, return_lse,
q_quant_scale=q_quant_scale,
head_size_v=head_size_v,
compute_dtype=compute_dtype,
mask=mask,
learnable_sink=learnable_sink)
else:
max_q_head_num = max_seq_q_mul_q_divide_kv * kv_head_num // seq_q
q_head_num_sizes, kv_head_num_sizes = split_head_nums(q_head_num, kv_head_num, max_q_head_num)
parts_num = len(q_head_num_sizes)
q_parts = torch.split(q, q_head_num_sizes, dim=q_head_dim)
out_parts = torch.split(out, q_head_num_sizes, dim=q_head_dim)
alibi_slopes_parts = [None] * parts_num
        if alibi_slopes is not None:
alibi_slopes_parts = torch.split(alibi_slopes, q_head_num_sizes, dim=0)
kv_parts_num = parts_num
if parts_num > kv_head_num:
            assert parts_num % kv_head_num == 0, f"parts_num:{parts_num} must be divisible by kv_head_num:{kv_head_num} when parts_num > kv_head_num"
kv_parts_num = kv_head_num
kv_head_num_sizes = kv_head_num_sizes[:kv_parts_num]
if len(kv_head_num_sizes) > 1:
k_cache_parts = torch.split(k_cache, kv_head_num_sizes, dim=kv_head_dim)
v_cache_parts = torch.split(v_cache, kv_head_num_sizes, dim=kv_head_dim)
k_cache_quant_scale_parts = [None] * kv_parts_num
v_cache_quant_scale_parts = [None] * kv_parts_num
            if k_cache_quant_scale is not None:
k_cache_quant_scale_dim = 1 if k_cache_quant_scale.dim() == 2 else kv_head_dim
k_cache_quant_scale_parts = torch.split(k_cache_quant_scale, kv_head_num_sizes, dim=k_cache_quant_scale_dim)
            if v_cache_quant_scale is not None:
v_cache_quant_scale_dim = 1 if v_cache_quant_scale.dim() == 2 else kv_head_dim
v_cache_quant_scale_parts = torch.split(v_cache_quant_scale, kv_head_num_sizes, dim=v_cache_quant_scale_dim)
else:
k_cache_parts = [k_cache]
v_cache_parts = [v_cache]
k_cache_quant_scale_parts = [k_cache_quant_scale]
v_cache_quant_scale_parts = [v_cache_quant_scale]
if parts_num > kv_parts_num:
            repeat_num = parts_num // kv_parts_num
            k_cache_parts = repeat_elements(k_cache_parts, repeat_num)
            v_cache_parts = repeat_elements(v_cache_parts, repeat_num)
            k_cache_quant_scale_parts = repeat_elements(k_cache_quant_scale_parts, repeat_num)
            v_cache_quant_scale_parts = repeat_elements(v_cache_quant_scale_parts, repeat_num)
for q_value, k_cache_value, v_cache_value, out_value, k_cache_quant_scale_value, v_cache_quant_scale_value, alibi_slopes_value in zip(
q_parts, k_cache_parts, v_cache_parts, out_parts, k_cache_quant_scale_parts, v_cache_quant_scale_parts,
alibi_slopes_parts):
tmo.single_query_cached_kv_attn(
q_value, k_cache_value.contiguous(), v_cache_value.contiguous() if v_cache_value is not None else None,
out_value, block_tables, context_lens,
k_cache_quant_scale_value, v_cache_quant_scale_value,
alibi_slopes_value, max_contxt_len,
windows_size_left, windows_size_right, softmax_scale, return_lse,
head_size_v=head_size_v,
compute_dtype=compute_dtype)
    return (None, None)  # TODO(liangxuegang): to fix return (output, lse)
def reshape_paged_cache(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
slot_mapping: torch.Tensor
) -> None:
tmo.reshape_paged_cache(k, v, k_cache, v_cache, slot_mapping)
def swap_blocks(
dst: torch.Tensor,
src: torch.Tensor,
block_mapping: torch.Tensor
) -> None:
# FIXME: Remove this conversion after
# tmo.swap_blocks support block_mapping tensor.
block_mapping = block_mapping.tolist()
block_mapping = {src: dst for src, dst in block_mapping}
return tmo.swap_blocks(dst, src, block_mapping)
def copy_blocks(
k_caches: List[torch.Tensor],
v_caches: List[torch.Tensor],
block_mapping: torch.Tensor
) -> None:
# FIXME: Remove this conversion after
# tmo.swap_blocks support block_mapping tensor.
block_mapping = block_mapping.tolist()
result_dict = {}
for row in block_mapping:
key = row[0]
values = row[1:]
if key in result_dict:
result_dict[key].extend(values)
else:
result_dict[key] = values
return tmo.copy_blocks(k_caches, v_caches, result_dict)
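
# Illustrative sketch of the mapping conversion done in copy_blocks above: a
# [[src, dst], ...] tensor becomes the {src: [dst, ...]} dict that
# tmo.copy_blocks expects. The example values are hypothetical.
def _demo_copy_blocks_mapping():
    block_mapping = torch.tensor([[0, 4], [0, 5], [2, 6]]).tolist()
    result_dict = {}
    for row in block_mapping:
        result_dict.setdefault(row[0], []).extend(row[1:])
    # Block 0 is copied to blocks 4 and 5, block 2 to block 6.
    assert result_dict == {0: [4, 5], 2: [6]}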
def active(
input: torch.Tensor,
act_mode: str,
is_gated: bool
) -> torch.Tensor:
return tmo.active(input, act_mode, is_gated)
def fused_moe(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
bias1: Optional[torch.Tensor],
bias2: Optional[torch.Tensor],
residual: Optional[torch.Tensor],
input_smooth: Optional[torch.Tensor],
act_smooth: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
topk: int,
renormalize: bool,
gated: bool,
act_mode: str,
start_expert_id: int = 0,
block_n: int = 0,
cncl_comm: int = 0,
avg_moe: bool=False,
class_reduce_weight: Optional[torch.Tensor] = None,
class_expert_id: Optional[torch.Tensor] = None,
w1_quant_flag: Optional[List] = None,
w2_quant_flag: Optional[List] = None,
world_size: int = 0,
shared_expert_num: int = 0,
parallel_mode: str = 'ep'):
dtype = hidden_states.dtype
ori_input_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
tokens = hidden_states.size(0)
gating_output = gating_output.reshape(-1, gating_output.size(-1))
residual = residual.reshape(-1, residual.size(-1)) if residual is not None else None
expert_num = gating_output.size(-1)
expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1)
per_token_sq = False
# check quant
check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
if all(x is not None for x in check_list):
per_token_sq = True
if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
raise ValueError(
"input_smooth, act_smooth, w1_scale and w2_scale must be "
"present and absent at the same time."
)
# softmax_topk
reduce_weight, expert_id = tmo.moe_softmax_topk(gating_output, topk, renormalize)
# append shared
if shared_expert_num > 0:
reduce_weight, expert_id = tmo.moe_append_shared_expert(reduce_weight, expert_id, expert_num,
shared_expert_num, world_size, parallel_mode)
if parallel_mode == "ep":
avg_shared_expert_num = (world_size + shared_expert_num - 1) // world_size
expert_num += avg_shared_expert_num * world_size
else:
expert_num += shared_expert_num
if avg_moe:
n_tokens = hidden_states.shape[0]
reduce_weight = class_reduce_weight[:n_tokens]
expert_id = class_expert_id[:n_tokens]
# gen_idx
expand_idx, combine_idx, token_count, cusum_token_count = tmo.moe_gen_idx(expert_id, expert_num)
if per_token_sq:
quant_input, input_scale = tmo.moe_quantize(hidden_states,
input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx,
cusum_token_count[start_expert_id].unsqueeze(0))
else:
expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx,
cusum_token_count, start_expert_id, expert_size)
# group gemm
if per_token_sq:
gemm1_out = tmo.smooth_quant_group_gemm(quant_input,
w1,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None,
input_scale, w1_scale, dtype, tokens, quant_flag = w1_quant_flag)
else:
gemm1_out = tmo.group_gemm(expand_hidden_states,
w1,
token_count[start_expert_id:start_expert_id+expert_size],
None,
None,
None,
None, tokens)
if per_token_sq:
quant_input = quant_input[:, :gemm1_out.shape[-1] // 2] if gated else quant_input[:, :gemm1_out.shape[-1]]
input_scale = input_scale[:gemm1_out.shape[0]]
quant_input, input_scale = tmo.moe_quantize(gemm1_out, act_smooth, None,
token_count[start_expert_id:start_expert_id+expert_size],
output=quant_input,
output_scale=input_scale,
act_mode=act_mode,
is_gated=gated)
else:
act_out = gemm1_out[:, :gemm1_out.shape[-1] // 2] if gated else gemm1_out
act_out = tmo.moe_active(gemm1_out, act_mode, gated, act_out, bias1, cusum_token_count, start_expert_id, expert_size)
if cncl_comm > 0:
raise ValueError("not support communication and computing fusion currently.")
else:
if per_token_sq:
gemm2_out = tmo.smooth_quant_group_gemm(quant_input,
w2, token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, input_scale, w2_scale, dtype, tokens, quant_flag = w2_quant_flag)
else:
gemm2_out = tmo.group_gemm(act_out,
w2,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, tokens)
output = tmo.moe_combine_result(gemm2_out, reduce_weight, combine_idx,
residual, cusum_token_count, start_expert_id,
expert_size, bias2)
return output.reshape(ori_input_shape)
def matmul(
a: torch.Tensor,
b: torch.Tensor,
bias: Optional[torch.Tensor] = None,
c: Optional[torch.Tensor] = None,
act_mode: str = 'none',
alpha: float = 1.0,
beta: float = .0
) -> torch.Tensor:
return tmo.matmul(a, b, bias, c, act_mode, alpha, beta)
def weight_only_quant_matmul(
a: torch.Tensor,
b: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor = None,
bias: torch.Tensor = None,
c: torch.Tensor = None,
act_mode: str = "none",
quant_bit_size: int = 8,
alpha: float = 1.0,
beta: float = 1.0
) -> torch.Tensor:
assert False, "[weight_only_quant_matmul] is deprecated."
def smooth_quant_matmul(
a: torch.Tensor,
a_scale: torch.Tensor,
b: torch.Tensor,
b_scale: torch.Tensor,
dtype: torch.dtype,
bias: torch.Tensor = None,
c: torch.Tensor = None,
act_mode: str = "none",
alpha: float = 1.0,
beta: float = 1.0,
use_hp_active: bool = False,
b_quant_bit_size: int = 8,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return tmo.scaled_matmul(a, b, a_scale, b_scale, dtype, bias, c, act_mode,
b_quant_bit_size, alpha, beta, use_hp_active)
def per_token_smooth_quantize(x: torch.Tensor,
smooth: torch.Tensor,
zero: torch.Tensor = None,
token_count: torch.Tensor = None,
act_mode: str = "none",
active_coef: float = 1.0,
is_gated: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if act_mode == "none":
is_gated = False
if token_count is None:
output, output_scale = tmo.scaled_quantize(x, smooth, zero, None, torch.int8,
"dynamic_per_token", act_mode, active_coef,
is_gated)
else:
output, output_scale = tmo.moe_quantize(x, smooth, zero, token_count, None, None, None,
None, True, act_mode, active_coef, is_gated)
return (output, output_scale)
def quantize(
x: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor = None
) -> torch.Tensor:
assert False, "[quantize] is deprecated."
def quant_to_paged_cache(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
k_cache_quant_scale: torch.Tensor,
v_cache_quant_scale: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
if k_cache is not None and k_cache.dtype == torch.uint8:
k_cache = k_cache.view(torch.float8_e4m3fn)
if v_cache is not None and v_cache.dtype == torch.uint8:
v_cache = v_cache.view(torch.float8_e4m3fn)
return tmo.quant_to_paged_cache(
k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping
)
def advance_step(num_seqs: int,
num_queries: int,
block_size: int,
input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor,
seq_lens: torch.Tensor,
slot_mapping: torch.Tensor,
block_tables: torch.Tensor,
TILE_SIZE: int = 64) -> None:
"""
    Advance a step on MLU for existing inputs of a multi-step runner, which
    will update input_tokens/seq_lens/input_positions/slot_mapping in place.
"""
def verify_tensor(
name: str,
tensor: torch.Tensor,
size_0: int,
size_1: int,
dtype: torch.dtype,
):
"""
Auxiliary function to check whether input is valid.
"""
size_0_cond = (size_0 == -1 or tensor.size(0) == size_0)
size_1_cond = (size_1 == -1 or tensor.size(1) == size_1)
        if not (size_0_cond and size_1_cond and tensor.is_contiguous() and tensor.dtype == dtype):
raise ValueError(
f"The input to advance_step is invalid with tensor name = {name}, "
f"shape = {tensor.shape}, "
f"is_cont = {tensor.is_contiguous()}, "
f"type = {tensor.dtype}, "
f"is not as expected: shape[{size_0}, {size_1}], type = {dtype}"
)
verify_tensor("input_tokens", input_tokens, num_seqs, -1, torch.int64)
verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, torch.int64)
verify_tensor("input_positions", input_positions, num_seqs, -1, torch.int32)
verify_tensor("seq_lens", seq_lens, num_seqs, -1, torch.int32)
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, torch.int32)
verify_tensor("block_tables", block_tables, num_seqs, -1, torch.int32)
grid = (math.ceil(num_queries / TILE_SIZE), )
_triton_advance_step[grid](input_tokens,
sampled_token_ids,
input_positions,
seq_lens,
slot_mapping,
block_tables,
block_tables.stride(0),
num_seqs,
num_queries,
block_size,
TILE_SIZE)
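
# Pure-torch reference of the slot arithmetic in _triton_advance_step, for
# illustration only; the real update runs in the Triton kernel. Tensor values
# here are hypothetical.
def _demo_advance_step_slot_math():
    block_size = 4
    seq_lens = torch.tensor([5, 9], dtype=torch.int32)
    block_tables = torch.tensor([[7, 3, 0], [2, 8, 6]], dtype=torch.int32)
    # The new token's position equals the old sequence length.
    next_input_pos = seq_lens
    block_index = (next_input_pos // block_size).long()
    block_offset = next_input_pos % block_size
    block_ids = block_tables[torch.arange(2), block_index]
    slot_num = block_ids * block_size + block_offset
    # Token 5 -> block 3, offset 1 -> slot 13; token 9 -> block 6, offset 1 -> slot 25.
    assert slot_num.tolist() == [13, 25]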
# MoE inner kernels
def moe_softmax_topk(input: torch.Tensor,
topk: int,
normalize: bool = False,
num_expert_group: int = -1,
topk_group: int = 0,
mask: Optional[torch.Tensor] = None,
normed_by : str = "topk_logit",
route_scale : float = 1.0,
reduce_weight: Optional[torch.Tensor] = None,
expert_id: Optional[torch.Tensor] = None,
score_bias: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor]:
return tmo.moe_softmax_topk(input, topk, normalize, num_expert_group,
topk_group, mask, normed_by, route_scale,
reduce_weight, expert_id, score_bias)
def moe_sigmoid_topk(input: torch.Tensor,
topk: int,
normalize: bool = False,
num_expert_group: int = -1,
topk_group: int = 0,
route_scale: float = 1.0,
score_bias: Optional[torch.Tensor] = None,
reduce_weight: Optional[torch.Tensor] = None,
expert_id: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor]:
return tmo.moe_sigmoid_topk(input, topk, normalize, num_expert_group,
topk_group, route_scale = route_scale,
score_bias = score_bias,
reduce_weight=reduce_weight,
expert_id=expert_id)
def moe_softplus_topk(
input: torch.Tensor,
topk: int,
input_ids: Optional[torch.Tensor] = None,
tid2eid: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
route_scale: float = 1.0,
reduce_weight: Optional[torch.Tensor] = None,
expert_id: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
return tmo.moe_softplus_topk(
input,
topk,
input_ids,
tid2eid,
bias,
route_scale,
reduce_weight,
expert_id,
)
def moe_gen_idx(expert_id: torch.Tensor,
expert_num: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return tmo.moe_gen_idx(expert_id, expert_num)
def moe_expand_input(input: torch.Tensor,
gather_idx: torch.Tensor,
cusum_token_count: Optional[torch.Tensor] = None,
start_expert_id: int = 0,
expert_size: int = 0) -> torch.Tensor:
return tmo.moe_expand_input(input, gather_idx,
cusum_token_count,
start_expert_id, expert_size)
def moe_active(input: torch.Tensor,
act_mode: str,
is_gated: bool,
output: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cusum_token_count: Optional[torch.Tensor] = None,
start_expert_id: int = 0,
expert_size: int = 0) -> torch.Tensor:
return tmo.moe_active(input, act_mode, is_gated, output,
bias, cusum_token_count,
start_expert_id, expert_size)
def group_gemm(a: torch.Tensor,
b: torch.Tensor,
m_list: torch.Tensor,
expand_idx: Optional[torch.Tensor],
c: Optional[torch.Tensor],
alpha: Optional[torch.Tensor],
beta: Optional[torch.Tensor],
max_m: int = 0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return tmo.group_gemm(a, b, m_list, expand_idx,
c, alpha, beta, max_m, d=output)
def smooth_quant_group_gemm(a: torch.Tensor,
b: torch.Tensor,
m_list: torch.Tensor,
expand_idx: Optional[torch.Tensor],
c: Optional[torch.Tensor],
alpha: Optional[torch.Tensor],
beta: Optional[torch.Tensor],
a_scale: torch.Tensor,
b_scale: torch.Tensor,
dtype,
max_m: int = 0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return tmo.smooth_quant_group_gemm(a, b, m_list, expand_idx, c, alpha, beta,
a_scale, b_scale, dtype, max_m, d=output)
def moe_combine_result(input: torch.Tensor,
reduce_weight: torch.Tensor,
gather_ids: torch.Tensor,
residual: Optional[torch.Tensor],
cusum_token_count: Optional[torch.Tensor],
start_expert_id: int,
expert_size: int,
bias: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return tmo.moe_combine_result(input, reduce_weight, gather_ids,
residual, cusum_token_count,
start_expert_id, expert_size, bias, output=output)
def moe_quantize(x: torch.Tensor,
smooth: torch.Tensor,
zero: Optional[torch.Tensor] = None,
token_count: Optional[torch.Tensor] = None,
gather_index: Optional[torch.Tensor] = None,
gather_index_start_position: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
dynamic_quant: bool = True,
act_mode: str = "none",
active_coef: float = 1.0,
is_gated: bool = False,
quant_type: torch.dtype = torch.int8
) -> Tuple[torch.Tensor, torch.Tensor]:
return tmo.moe_quantize(x, smooth, zero, token_count, gather_index, gather_index_start_position,
output, output_scale, dynamic_quant, act_mode, active_coef, is_gated, quant_type)
def dequant_from_paged_cache(key: torch.Tensor,
value: Optional[torch.Tensor],
key_cache: torch.Tensor,
value_cache: Optional[torch.Tensor],
key_cache_quant_scale: torch.Tensor,
value_cache_quant_scale: Optional[torch.Tensor],
context_lengths: torch.Tensor,
max_context_len: int,
context_seq_offset: Optional[torch.Tensor],
block_tables: torch.Tensor,
quant_mode: int = 0,
quant_bit: int = 8) -> None:
tmo.dequant_from_paged_cache(
key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale,
context_lengths, max_context_len, context_seq_offset, block_tables, quant_mode, quant_bit)
def random_sample(
probs: torch.Tensor,
is_gumbel_max: bool,
generators: dict[int, torch.Generator],
) -> torch.Tensor:
return tmo.random_sample(probs, is_gumbel_max, generators)
def rejection_sample(draft_token_ids: torch.Tensor,
num_draft_tokens: torch.Tensor,
cu_num_draft_tokens: torch.Tensor,
draft_probs: torch.Tensor,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
uniform_rand: torch.Tensor,
uniform_probs: torch.Tensor,
max_spec_len: int,
high_acc: bool = True) -> torch.Tensor:
return tmo.rejection_sample(
draft_token_ids, num_draft_tokens, cu_num_draft_tokens, draft_probs,
target_probs, bonus_token_ids, uniform_rand, uniform_probs, max_spec_len, high_acc)
def apply_topkp_v2(logits: torch.Tensor,
index_in: torch.Tensor,
temperature_list: torch.Tensor,
minp_list: torch.Tensor,
topk_list: torch.Tensor,
topp_list: torch.Tensor,
logits_out: Optional[torch.Tensor] = None,
sorted_logits_out: Optional[torch.Tensor] = None,
index_out: Optional[torch.Tensor] = None,
true_select_len: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return tmo.apply_topkp_v2(logits, index_in, temperature_list, minp_list, topk_list, topp_list,
logits_out, sorted_logits_out, index_out, true_select_len)
def scaled_quantize(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
zero: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
quant_type: torch.dtype = torch.int8,
quant_mode: str = "dynamic_per_token",
act_mode: str = "none",
active_coef: float = 1.0,
is_gated: bool = False
) -> Tuple[torch.Tensor]:
"""
Apply activation and quantization to the input tensor x.
Args:
        input (torch.Tensor): The tensor to be quantized, shape is (..., C); dimensions 0 through -2 must be contiguous.
        scale (Optional[torch.Tensor], optional): The scale multiplied into the input tensor. Shape is (C) or (1).
zero (Optional[torch.Tensor], optional): Not supported, must pass None.
scale_ub (Optional[torch.Tensor], optional): The output_scale upper bound.
Take effect only if quant_type == torch.float8_e4m3fn and quant_mode == "dynamic_per_token".
quant_type (optional): Output data type, can be torch.int8, torch.float8_e4m3fn. Defaults to torch.int8.
quant_mode (str, optional): quantize mode, which can be "dynamic_per_token", "dynamic_per_tensor", "static_per_tensor"
and "static_per_channel". Defaults to "dynamic_per_token".
act_mode (str): The mode of activation, must be "none", "gelu", "silu", "swish".
active_coef(float): The coefficient used in the swish activation. Default is 1.0.
is_gated (bool): A boolean parameter that indicates whether a gating mechanism is applied. It only
takes effect when act_mode is not "none".
Type:
input: float, half, bfloat16.
scale: float.
scale_ub: float.
act_mode: str
active_coef: float
is_gated: bool
Returns:
Tuple[torch.Tensor]: Returns (output, output_scale) if quant_mode is "dynamic_per_token" or "dynamic_per_tensor",
otherwise returns output only.
"""
return tmo.scaled_quantize(input,
scale,
zero,
scale_ub,
quant_type,
quant_mode,
act_mode,
active_coef,
is_gated)
def scaled_matmul(a: torch.Tensor,
b: torch.Tensor,
a_scale: Optional[torch.Tensor],
b_scale: torch.Tensor,
output_dtype: torch.dtype,
bias: torch.Tensor = None,
c: torch.Tensor = None,
act_mode: str = "none",
quant_bit_size: int = 8,
alpha: float = 1.0,
beta: float = 1.0,
use_hp_active: bool = False,
a_quant_bit_size: int = 8,
a_calib: Optional[torch.Tensor] = None,
b_calib: Optional[torch.Tensor] = None,):
"""
Perform quantized matrix multiplication on tensor a and b.
Args:
a (torch.Tensor): Shape is (M, K).
b (torch.Tensor): If quant_bit_size = 8, shape is (N, K).
If quant_bit_size = 4, shape is (N, K//2).
a_scale (Optional[torch.Tensor]): Shape can be (M).
b_scale (torch.Tensor): If use groupwise quantization, shape must be (N, group_num), data type must be
the same as a; otherwise shape must be (N), data type must be float.
output_dtype (torch.dtype): Specify the data type of output, must be torch.half or torch.bfloat16.
bias (torch.Tensor, optional): Shape is (N).
c (torch.Tensor, optional): Shape is (M, N).
act_mode (str, optional): Choose the activation algorithm, must be 'silu', 'gelu' or 'none'. If use groupwise
quantization, act_mode must be 'none'.
quant_bit_size (int, optional): The data format of b. Defaults to 8.
alpha (float, optional): coefficient of acted. Defaults to 1.0.
beta (float, optional): coefficient of c. Defaults to 1.0.
        use_hp_active (bool, optional): Selects the algorithm used to implement the activation function.
            When True, use the high-precision algorithm; otherwise use the fastest activation algorithm.
            Defaults to False.
        a_quant_bit_size (int, optional): The data format of a. Defaults to 8.
a_calib (Optional[torch.Tensor]): The calibration of a, shape can be (M, 2).
b_calib (Optional[torch.Tensor]): The calibration of b, shape can be (M, 2).
Type:
a: int8, half, bfloat16, float8_e4m3fn, int4X2
a_scale: float
b: int8, float8_e4m3fn, int4X2
b_scale: float, half, bfloat16
bias: half, float, bfloat16
c: half, float, bfloat16
output: half, bfloat16
a_calib: float
b_calib: float
Returns:
A tensor with the shape of (M, N).
"""
return tmo.scaled_matmul(a,
b,
a_scale,
b_scale,
output_dtype,
bias,
c,
act_mode,
quant_bit_size,
alpha,
beta,
use_hp_active,
a_quant_bit_size,
a_calib,
b_calib,)
def fused_mla_kv(kv: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
position_id: torch.Tensor,
gamma: torch.Tensor,
kv_cache: torch.Tensor,
kv_cache_scale: Optional[torch.Tensor],
slot_mapping: Optional[torch.Tensor],
cache_bs_id: Optional[torch.Tensor] = None,
cache_seq_offset: Optional[torch.Tensor] = None,
is_paged_cache: bool = True,
eps: float = 1e-5,
interleaved: bool = True):
quant_mode = "static_per_channel" if kv_cache_scale is None else "dynamic_per_token"
return tmo.fused_mla_kv(
kv, sin, cos, position_id, gamma, kv_cache, kv_cache_scale, slot_mapping, cache_bs_id,
cache_seq_offset,
quant_mode=quant_mode,
is_paged_cache=is_paged_cache,
eps=eps,
interleaved=interleaved,
)
def fused_mla_q(q: torch.Tensor,
gamma: torch.Tensor,
smooth_quant_scale: torch.Tensor,
weight_b: torch.Tensor,
weight_b_scale: torch.Tensor,
weight_c: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
position_id: torch.Tensor,
output: Optional[torch.Tensor] = None,
eps: float = 1e-6,
interleaved: bool = True,
output_quant_mode: str = 'none',
output_scale: Optional[torch.Tensor] = None,
output_norm: Optional[torch.Tensor] = None) -> torch.Tensor:
return tmo.fused_mla_q(
q, gamma, smooth_quant_scale, weight_b, weight_b_scale, weight_c, sin, cos, position_id,
output, eps, interleaved, output_quant_mode, output_scale,
store_norm=(output_norm is not None),
output_norm= output_norm,
)
def gather_cache(
kv_cache: List[torch.Tensor], # [[1, num_blocks, num_kv_heads, block_size, head_size]
# [1, num_blocks, num_kv_heads, block_size] if kv_cache_dtype=int8]
dst: torch.Tensor, # [tot_tokens, entrys...]
block_table: torch.Tensor, # [batch, block_indices]
cu_seq_lens: torch.Tensor, # [batch+1]
batch_size: int,
seq_starts: torch.Tensor = None, # Optional: [batch]
kv_cache_dtype: str = 'auto',
) -> None:
"""
Gathers sequences from src_cache into dst based on block_table and cu_seq_lens.
Args:
src_cache: Source KV cache tensor of shape [[1, num_blocks, num_kv_heads, block_size, head_size],
[1, num_blocks, num_kv_heads, block_size] if cache_dtype=int8].
dst: Destination tensor of shape [tot_tokens, entrys...].
block_table: Tensor of shape [batch, block_indices] mapping sequences to blocks.
cu_seq_lens: Tensor of shape [batch+1] with cumulative sequence lengths.
batch_size: Number of sequences in the batch.
seq_starts: Optional tensor of shape [batch] for block index offsets.
"""
assert len(kv_cache) > 0 and kv_cache[0].numel() > 0, "kv cache can't be empty in gather_cache"
src_cache = kv_cache[0][0]
# Validate inputs
assert src_cache.device == dst.device == block_table.device == cu_seq_lens.device, \
"All tensors must be on the same device"
assert block_table.dtype == torch.int32, "block_table must be int32"
assert cu_seq_lens.dtype == torch.int32, "cu_seq_lens must be int32"
quant_kv_cache = kv_cache_dtype != 'auto'
if not quant_kv_cache:
        assert src_cache.dtype == dst.dtype, "src_cache and dst must have the same dtype when not quantized"
if seq_starts is not None:
assert seq_starts.dtype == torch.int32, "seq_starts must be int32"
assert seq_starts.device == src_cache.device, "seq_starts must be on the same device"
# Extract dimensions
num_blocks, num_kv_heads, block_size, head_size = src_cache.shape
    # When using MLA during decode it becomes MQA, so num_kv_heads is fixed to 1
    # and src_cache can be viewed as [num_blocks, block_size, head_size].
    assert num_kv_heads == 1, "MLA forces num_kv_heads to 1"
src_cache = src_cache.view(num_blocks, block_size, -1)
entry_shape = src_cache.shape[2:] # ENTRIES...
tot_tokens = cu_seq_lens[-1]
    assert tot_tokens > 0, "tot_tokens must be > 0"
    assert tot_tokens <= dst.shape[0], "tot_tokens must be <= dst.shape[0]"
dst_cache = dst[:tot_tokens]
# Ensure cu_seq_lens matches batch_size
assert cu_seq_lens.size(0) == batch_size + 1, "cu_seq_lens must have batch_size + 1 elements"
# Compute sequence lengths
seq_lens = cu_seq_lens[1:] - cu_seq_lens[:-1] # [BATCH]
tot_blocks_per_seq = (seq_lens + block_size - 1) // block_size # ceil_div
# Handle seq_starts offset
block_offsets = torch.zeros(batch_size, dtype=torch.int32, device=src_cache.device)
if seq_starts is not None:
block_offsets = seq_starts // block_size
# Flatten src_cache for easier indexing: [NUM_BLOCKS * BLOCK_SIZE, ENTRIES...]
src_flat = src_cache.view(num_blocks * block_size, *entry_shape)
# Prepare output indices
dst_indices = []
for bid in range(batch_size):
seq_len = seq_lens[bid]
if seq_len <= 0:
continue
seq_start = cu_seq_lens[bid]
tot_blocks = tot_blocks_per_seq[bid]
offset = block_offsets[bid]
# Compute block indices for this sequence
block_ids = block_table[bid, offset:offset + tot_blocks]
# Compute token indices within blocks
token_indices = torch.arange(seq_len, device=src_cache.device)
block_indices = token_indices // block_size
within_block = token_indices % block_size
# Map to src_flat indices
src_indices = block_ids[block_indices] * block_size + within_block
dst_indices.append(src_indices)
# Concatenate all indices
dst_indices = torch.cat(dst_indices)
# Gather data
dst_flat = src_flat[dst_indices]
if quant_kv_cache:
src_cache_scale = kv_cache[1][0]
src_scale_flat = src_cache_scale.view(num_blocks * block_size)
dst_scale_flat = src_scale_flat[dst_indices]
dst_flat = dst_flat * dst_scale_flat.unsqueeze(-1)
dst_cache.view(-1, *entry_shape).copy_(dst_flat.view(tot_tokens, *entry_shape))
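
# Minimal CPU sanity check for gather_cache (shapes and values are
# hypothetical): one sequence of 3 tokens spread over two blocks of size 2,
# MLA-style num_kv_heads == 1, no quantization.
def _demo_gather_cache():
    num_blocks, block_size, head_size = 4, 2, 3
    src = torch.arange(num_blocks * block_size * head_size,
                       dtype=torch.float32).view(num_blocks, 1, block_size, head_size)
    dst = torch.zeros(3, head_size)
    block_table = torch.tensor([[2, 0]], dtype=torch.int32)
    cu_seq_lens = torch.tensor([0, 3], dtype=torch.int32)
    gather_cache([src.unsqueeze(0)], dst, block_table, cu_seq_lens, batch_size=1)
    # Tokens 0-1 come from block 2, token 2 from block 0.
    assert torch.equal(dst, src.view(-1, head_size)[[4, 5, 0]])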
def merge_attn_states(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: Optional[torch.Tensor] = None,
) -> None:
"""
Merges partial attention states (prefix and suffix) into a single output.
Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005.
Args:
output: Output tensor of shape [num_tokens, num_query_heads, head_size].
prefix_output: Prefix attention output, same shape as output.
prefix_lse: Prefix log-sum-exp, shape [num_query_heads, num_tokens].
suffix_output: Suffix attention output, same shape as output.
suffix_lse: Suffix log-sum-exp, same shape as prefix_lse.
output_lse: Optional output log-sum-exp, same shape as prefix_lse.
"""
# Input validation
assert output.shape == prefix_output.shape == suffix_output.shape, \
"Output and input tensors must have the same shape"
assert prefix_lse.shape == suffix_lse.shape, \
"Prefix and suffix LSE tensors must have the same shape"
if output_lse is not None:
assert output_lse.shape == prefix_lse.shape, \
"Output LSE must have the same shape as input LSE tensors"
# Handle inf values (replace inf with -inf for consistency)
p_lse = torch.where(
prefix_lse == float('inf'),
torch.tensor(float('-inf'), device=prefix_lse.device),
prefix_lse
)
s_lse = torch.where(
suffix_lse == float('inf'),
torch.tensor(float('-inf'), device=suffix_lse.device),
suffix_lse
)
# Compute maximum LSE for numerical stability
max_lse = torch.maximum(p_lse, s_lse) # Shape: [num_query_heads, num_tokens]
# Normalize LSE terms
p_lse = p_lse - max_lse # Shape: [num_query_heads, num_tokens]
s_lse = s_lse - max_lse # Shape: [num_query_heads, num_tokens]
# Compute sum of exponentials
out_se = torch.exp(p_lse) + torch.exp(s_lse) # Shape: [num_query_heads, num_tokens]
# Compute output_lse if provided
if output_lse is not None:
output_lse.copy_(torch.log(out_se) + max_lse)
# Compute scaling factors
p_scale = torch.exp(p_lse) / out_se # Shape: [num_query_heads, num_tokens]
s_scale = torch.exp(s_lse) / out_se # Shape: [num_query_heads, num_tokens]
# Reshape scales for broadcasting
p_scale = p_scale.unsqueeze(-1) # Shape: [num_query_heads, num_tokens, 1]
s_scale = s_scale.unsqueeze(-1) # Shape: [num_query_heads, num_tokens, 1]
# Transpose outputs to match scaling dimensions
prefix_output = prefix_output.permute(1, 0, 2) # Shape: [num_query_heads, num_tokens, head_size]
suffix_output = suffix_output.permute(1, 0, 2) # Shape: [num_query_heads, num_tokens, head_size]
# Compute merged output
out = prefix_output * p_scale + suffix_output * s_scale # Shape: [num_query_heads, num_tokens, head_size]
# Transpose back and store in output
output.copy_(out.permute(1, 0, 2)) # Shape: [num_tokens, num_query_heads, head_size]
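
# Self-consistency sketch for merge_attn_states (hypothetical sizes): merging a
# state with itself must reproduce the same output, with the LSE shifted by
# log(2) because the softmax mass doubles.
def _demo_merge_attn_states():
    num_tokens, num_heads, head_size = 2, 3, 4
    part = torch.randn(num_tokens, num_heads, head_size)
    lse = torch.randn(num_heads, num_tokens)
    out = torch.empty_like(part)
    out_lse = torch.empty_like(lse)
    merge_attn_states(out, part, lse, part, lse, out_lse)
    assert torch.allclose(out, part, atol=1e-6)
    assert torch.allclose(out_lse, lse + math.log(2), atol=1e-6)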
def moe_all2all_create(dispatch_token_byte: int,
combine_token_byte: int,
max_expert_num: int,
max_token_num: int,
rank: int,
nrank: int) -> Tuple[int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Create the handle of MOE All-to-All communication.
API call order:
    1. Call torch_mlu_ops.moe_all2all_create(...) to obtain the CNCLEP handle and buffer tensors for All-to-All communication. Only needs to be done once.
    2. Gather all_exchange_info by performing an All-Gather operation on exchange_info across nrank processes. Only needs to be done once.
    3. Call torch.distributed.barrier() to ensure step 2 finishes. Only needs to be done once.
    4. Call torch_mlu_ops.moe_all2all_init(...) to configure the all_exchange_info into the handle. Only needs to be done once.
    5. Call torch_mlu_ops.moe_all2all_dispatch(...) to route tokens to their designated experts.
    6. Call torch_mlu_ops.moe_all2all_combine(...) to restore tokens to their original locations.
    7. Call torch_mlu_ops.moe_all2all_destroy(...) to release the CNCLEP handle. Only needs to be done once.
Args:
dispatch_token_byte (int): Byte size of a single token for dispatch All-to-All operation.
combine_token_byte (int): Byte size of a single token for combine All-to-All operation.
max_expert_num (int): Maximum number of experts participating in the All-to-All operation.
max_token_num (int): Maximum number of tokens to be processed.
rank (int): Rank ID of the current process [0~nrank-1].
nrank (int): Total number of processes in the distributed group.
Return:
        A tuple of (handle, exchange_info_size, exchange_info, dispatch_send, dispatch_recv, combine_send, combine_recv).
handle: The CNCLEP handle with type of integer.
exchange_info_size: The size of exchange_info.
exchange_info: CPU tensor, shape is [exchange_info_size], and data type is torch.int8.
dispatch_send: MLU tensor, shape is [max_token_num * dispatch_token_byte], and data type is torch.int8.
dispatch_recv: MLU tensor, shape is [nrank * max_token_num * dispatch_token_byte], and data type is torch.int8.
combine_send: MLU tensor, shape is [max_token_num * combine_token_byte], and data type is torch.int8.
combine_recv: MLU tensor, shape is [nrank * max_token_num * combine_token_byte], and data type is torch.int8.
"""
return tmo.moe_all2all_create(dispatch_token_byte, combine_token_byte, max_expert_num, max_token_num, rank, nrank)
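
# Hedged sketch of the call order documented above, using the wrappers in this
# module. Process-group setup, the buffer sizes, and the layout of the gathered
# exchange info are assumptions for illustration, not the confirmed API usage.
def _demo_moe_all2all_lifecycle(rank: int, nrank: int):
    import torch.distributed as dist
    # Step 1: create the CNCLEP handle and communication buffers (once).
    (handle, info_size, exchange_info, dispatch_send, dispatch_recv,
     combine_send, combine_recv) = moe_all2all_create(
         dispatch_token_byte=1024, combine_token_byte=2048,
         max_expert_num=64, max_token_num=4096, rank=rank, nrank=nrank)
    # Steps 2-3: all-gather exchange_info across ranks, then synchronize (once).
    gathered = [torch.empty_like(exchange_info) for _ in range(nrank)]
    dist.all_gather(gathered, exchange_info)
    dist.barrier()
    # Step 4: configure the gathered info into the handle (once); concatenating
    # the per-rank tensors is an assumption about the expected layout.
    moe_all2all_init(handle, torch.cat(gathered))
    # Steps 5-6: call moe_all2all_dispatch(...) / moe_all2all_combine(...) per step.
    # Step 7: release the handle (once).
    moe_all2all_destroy(handle)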
def moe_all2all_init(handle: int,
all_exchange_info: torch.Tensor) -> None:
tmo.moe_all2all_init(handle, all_exchange_info)
def moe_all2all_destroy(handle: int) -> None:
tmo.moe_all2all_destroy(handle)
def moe_all2all_dispatch(handle: int,
token_byte: int,
token_num: int,
send_layout: torch.Tensor,
send_token_num: torch.Tensor,
recv_layout: torch.Tensor,
recv_token_num: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
) -> None:
tmo.moe_all2all_dispatch(handle, token_byte, token_num, send_layout, send_token_num, recv_layout, recv_token_num, send_token, recv_token)
def moe_all2all_combine(handle: int,
token_byte: int,
token_num: int,
send_src_layout: torch.Tensor,
send_dst_layout: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
) -> None:
tmo.moe_all2all_combine(handle, token_byte, token_num, send_src_layout, send_dst_layout, send_token, recv_token)
def gather_split(input: torch.Tensor,
gather_index: torch.Tensor,
valid_token_num: torch.Tensor,
output1: torch.Tensor,
output2: Optional[torch.Tensor] = None) -> None:
tmo.gather_split(input,
gather_index,
valid_token_num,
output1,
output2)
def moe_all2all_gen_send_layout(token_count: torch.Tensor,
nrank: int) -> torch.Tensor:
return tmo.moe_all2all_gen_send_layout(token_count, nrank)
def moe_all2all_gen_gather_index(token_num: torch.Tensor, pad_num: int,
return_cusum_token_count: bool = False):
if not return_cusum_token_count:
gather_by_expert_index, gather_by_rank_index, token_count, token_sum = \
tmo.moe_all2all_gen_gather_index(token_num, pad_num)
return gather_by_expert_index, gather_by_rank_index, token_count, token_sum
else:
gather_by_expert_index, gather_by_rank_index, token_count, token_sum, cusum_token_count = \
tmo.moe_all2all_gen_gather_index(token_num, pad_num, return_cusum_token_count=True)
return gather_by_expert_index, gather_by_rank_index, token_count, token_sum, cusum_token_count
def reshape_from_cache(
key: torch.Tensor,
value: Optional[torch.Tensor],
key_cache: torch.Tensor,
value_cache: Optional[torch.Tensor],
context_lengths: torch.Tensor,
max_context_len: int,
context_seq_offset: Optional[torch.Tensor] = None,
block_tables: Optional[torch.Tensor] = None,
cache_seq_offset: Optional[torch.Tensor] = None,
) -> None:
tmo.reshape_from_cache(
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
context_lengths=context_lengths,
max_context_len=max_context_len,
context_seq_offset=context_seq_offset,
block_tables=block_tables,
cache_seq_offset=cache_seq_offset,
)
def masked_indexer_select_paged_kv(query: torch.Tensor,
k_cache: torch.Tensor,
weights: torch.Tensor,
kv_cache_block_table: torch.Tensor,
cu_seq_q_lens: Optional[torch.Tensor],
cu_seq_k_lens: Optional[torch.Tensor],
k_context_lens: Optional[torch.Tensor],
k_cache_block_table: Optional[torch.Tensor],
is_prefill: bool,
index_topk: int,
kv_cache_block_size: int,
softmax_scale: float,
q_scale: Optional[torch.Tensor] = None,
k_scale_cache: Optional[torch.Tensor] = None,
sparse_block_table: Optional[torch.Tensor] = None,
sparse_context_lens: Optional[torch.Tensor] = None):
tmo.masked_indexer_select_paged_kv(query=query,
k_cache=k_cache,
weights=weights,
kv_cache_block_table=kv_cache_block_table,
cu_seq_q_lens=cu_seq_q_lens,
cu_seq_k_lens=cu_seq_k_lens,
k_context_lens=k_context_lens,
k_cache_block_table=k_cache_block_table,
is_prefill=is_prefill,
index_topk=index_topk,
kv_cache_block_size=kv_cache_block_size,
softmax_scale=softmax_scale,
q_scale=q_scale,
k_scale_cache=k_scale_cache,
sparse_block_table=sparse_block_table,
sparse_context_lens=sparse_context_lens)
def masked_indexer_select_paged_kv_prefill(
query: torch.Tensor,
key_value: torch.Tensor,
weights: torch.Tensor,
kv_cache_block_table: torch.Tensor,
cu_seq_q_lens: torch.Tensor,
cu_seq_k_lens: torch.Tensor,
index_topk: int,
kv_cache_block_size: int,
softmax_scale: float,
q_scale: Optional[torch.Tensor] = None,
k_scale_cache: Optional[torch.Tensor] = None,
sparse_block_table: Optional[torch.Tensor] = None,
sparse_context_lens: Optional[torch.Tensor] = None,
kv_cache_block_table_offset: Optional[torch.Tensor] = None,
compress_ratio: int = 1,
):
return tmo.masked_indexer_select_paged_kv(
query=query,
k_cache=key_value,
weights=weights,
kv_cache_block_table=kv_cache_block_table,
cu_seq_q_lens=cu_seq_q_lens,
cu_seq_k_lens=cu_seq_k_lens,
k_context_lens=None,
k_cache_block_table=None,
is_prefill=True,
index_topk=index_topk,
kv_cache_block_size=kv_cache_block_size,
softmax_scale=softmax_scale,
q_scale=q_scale,
k_scale_cache=k_scale_cache,
sparse_block_table=sparse_block_table,
sparse_context_lens=sparse_context_lens,
kv_cache_block_table_offset=kv_cache_block_table_offset,
compress_ratio=compress_ratio,
is_score_float=True,
)
def masked_indexer_select_paged_kv_decode(
query: torch.Tensor,
k_cache: torch.Tensor,
weights: torch.Tensor,
kv_cache_block_table: torch.Tensor,
k_context_lens: Optional[torch.Tensor],
k_cache_block_table: Optional[torch.Tensor],
index_topk: int,
kv_cache_block_size: int,
softmax_scale: float,
q_scale: Optional[torch.Tensor] = None,
k_scale_cache: Optional[torch.Tensor] = None,
sparse_block_table: Optional[torch.Tensor] = None,
sparse_context_lens: Optional[torch.Tensor] = None,
kv_cache_block_table_offset: Optional[torch.Tensor] = None,
compress_ratio: int = 1,
):
query_len = query.shape[1]
#k_context_lens = k_context_lens // compress_ratio
return tmo.masked_indexer_select_paged_kv(
query=query,
k_cache=k_cache,
weights=weights,
kv_cache_block_table=kv_cache_block_table,
cu_seq_q_lens=None,
cu_seq_k_lens=None,
k_context_lens=k_context_lens,
k_cache_block_table=k_cache_block_table,
is_prefill=False,
index_topk=index_topk,
kv_cache_block_size=kv_cache_block_size,
softmax_scale=softmax_scale,
q_scale=q_scale,
k_scale_cache=k_scale_cache,
sparse_block_table=sparse_block_table,
sparse_context_lens=sparse_context_lens,
kv_cache_block_table_offset=kv_cache_block_table_offset,
compress_ratio=compress_ratio,
is_score_float=True,
)
def concat_block_table(
first_block_table: torch.Tensor,
first_context_lens: torch.Tensor,
second_block_table: torch.Tensor,
second_context_lens: torch.Tensor,
new_block_table: Optional[torch.Tensor] = None,
new_context_lens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Concatenate two different block tables, return the concatenated result.
Math:
new_context_lens = first_context_lens + second_context_lens
total_seq = first_context_lens.size(0)
for i in range(total_seq):
new_block_table[i, :first_context_lens[i]] = first_block_table[i, :first_context_lens[i]]
new_block_table[i, first_context_lens[i]:first_context_lens[i]+second_context_lens[i]] = second_block_table[i, :second_context_lens[i]]
Args:
first_block_table (torch.Tensor):
The first block table of shape `[total_seq, first_max_blkn]`.
first_context_lens (torch.Tensor):
The context lens of the first block table of shape `[total_seq,]`.
second_block_table (torch.Tensor):
The second block table of shape `[total_seq, second_max_blkn]`.
second_context_lens (torch.Tensor):
The context lens of the second block table of shape `[total_seq,]`.
new_block_table (Optional[torch.Tensor]):
The new block table of shape `[total_seq, max_new_block_number]`.
if not None, the max_new_block_number must be large enough for the concatenated block_table
Default: `None`.
new_context_lens (Optional[torch.Tensor]):
The new context lens of shape `[total_seq,]`. Default: `None`.
Returns:
new_block_table (torch.Tensor):
The concatenated block table of shape `[total_seq, max_new_block_number]`.
new_context_lens (torch.Tensor):
The new context lens of shape `[total_seq,]`, equals first_context_lens + second_context_lens
Type:
INT32
"""
return tmo.concat_block_table(
first_block_table,
first_context_lens,
second_block_table,
second_context_lens,
new_block_table,
new_context_lens,
)
def fused_mhc_post(
x: torch.Tensor, # (N, D) float|bf16
residual: torch.Tensor, # (N, HC, D) float|bf16
    post: torch.Tensor, # (N, HC) always float
    comb: torch.Tensor, # (N, HC, HC) always float
compute_rms: bool,
eps: float,
    output: Optional[torch.Tensor] = None, # (N, HC, D) same dtype as input
    output_rms: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Math:
output = post * x + (comb * residual).sum(dim=1)
output_rms = rsqrt(x.square().mean(dim=-1))
Args:
x (torch.Tensor): Shape is [N, D].
residual (torch.Tensor): Shape is [N, HC, D].
post (torch.Tensor): Shape is [N, HC].
comb (torch.Tensor): Shape is [N, HC, HC].
compute_rms (bool): Whether to compute output_rms.
eps (float): The eps of normalization.
output (torch.Tensor, optional): Shape is [N, HC, D]. Defaults to None.
output_rms (torch.Tensor, optional): Shape is [N]. Defaults to None.
Returns:
output, output_rms
Limitation:
D must be 4096.
HC must be 4.
"""
out = tmo.fused_mhc_post(
x,
residual,
post,
comb,
compute_rms,
eps,
output,
output_rms,
)
return out if compute_rms else (out, None)
def fused_compress_multi_kv(kv: torch.Tensor, # (BS, D) float|bf16
score: torch.Tensor, # (BS, D) float|bf16
kv_state: torch.Tensor, # (max_B, coff * R, D) float
score_state: torch.Tensor, # (max_B, coff * R, D) float
batch_ids: torch.Tensor, # (B,) int32
cu_seqlens: torch.Tensor, # (B,) int32
ape: torch.Tensor, # (R, D) float
max_seqlen:int,
overlap: bool,
compressed_kv: torch.Tensor # (BS, head_dim) float|bf16
):
tmo.fused_compress_multi_kv(
kv = kv,
score = score,
kv_state = kv_state,
score_state = score_state,
cu_seqlens = cu_seqlens,
batch_ids = batch_ids,
ape = ape,
max_seqlen = max_seqlen,
overlap = overlap,
compressed_kv = compressed_kv,
)
def fused_compress_single_kv(
kv: torch.Tensor, # (T, D) float|bf16
score: torch.Tensor, # (T, D) float|bf16
position: torch.Tensor, # (B,) int32
ape: torch.Tensor, # (ratio, D) float|bf16
kv_state: torch.Tensor, # (B, R, D) float|bf16
score_state: torch.Tensor, # (B, R, D) float|bf16
gamma: torch.Tensor, # (d)
sin: torch.Tensor, # (-1, rope_dim)
cos: torch.Tensor, # (-1, rope_dim)
hadamard_matrix: Optional[torch.Tensor], # (d, d)
slot_mapping: torch.Tensor, # (B,) int32
kv_cache: torch.Tensor, # (-1, BLKS, head_dim) bf16|int8|fp8
kv_cache_scale: Optional[torch.Tensor], # (-1, BLKS) float
eps: float,
overlap: bool,
rotate: bool,
state_idx: torch.Tensor,
cu_query_len: torch.Tensor | None = None,
):
"""
Math:
Args:
kv (torch.Tensor): Shape is [B, S, D].
score (torch.Tensor): Shape is [B, S, D].
position (torch.Tensor): Shape is [B].
ape (torch.Tensor): Shape is [ratio, D].
kv_state (torch.Tensor): Shape is [max_B, R, D].
score_state (torch.Tensor): Shape is [max_B, R, D].
gamma (torch.Tensor): Shape is [head_dim].
sin (torch.Tensor): Shape is [table_len, rope_dim].
cos (torch.Tensor): Shape is [table_len, rope_dim].
hadamard_matrix (torch.Tensor): Shape is [head_dim, head_dim].
slot_mapping (torch.Tensor): Shape is [B].
kv_cache (torch.Tensor): Shape is [cache_len, block_size, hs].
kv_cache_scale (torch.Tensor): Shape is [cache_len, block_size].
        eps (float): The eps of normalization.
overlap (bool): Whether to overlap.
rotate (bool): Whether to rotate.
Type:
kv: BF16, FP32
score: same as kv
position: INT32
ape: FP32
kv_state: FP32
score_state: FP32
gamma: same as kv
sin: same as kv
cos: same as kv
hadamard_matrix: same as kv
slot_mapping: INT32
kv_cache: BF16, FP32
kv_cache_scale: FP32
Returns:
        Only in-place outputs are supported, including kv_state, score_state, kv_cache and kv_cache_scale.
Note:
coff = overlap + 1
D = coff * head_dim
R = coff * ratio
"""
token_num, coff_dim = kv.shape
# TODO: force user_tmo = 0 after supporting mtp.
bsz = state_idx.numel()
kv = kv.unsqueeze(1)
score = score.unsqueeze(1)
if kv_cache.dim() == 4:
paged_num, head_num, block_size, head_dim = kv_cache.shape
assert head_num == 1
kv_cache = kv_cache.view(paged_num, block_size, head_dim)
return tmo.fused_compress_single_kv(
kv=kv,
score=score,
position=position,
state_ids=state_idx,
ape=ape,
kv_state=kv_state,
score_state=score_state,
gamma=gamma,
sin=sin,
cos=cos,
hadamard_matrix=hadamard_matrix if rotate else None,
slot_mapping=slot_mapping,
kv_cache=kv_cache,
kv_cache_scale=kv_cache_scale,
eps=eps,
overlap=overlap,
)
def convertBlockTable(block_table, blks, incseq):
if blks == 1:
return block_table
else:
expanded = block_table.unsqueeze(1).repeat(1, blks)
result = expanded * blks + incseq
return result.flatten()
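
# Illustrative example (hypothetical values): convertBlockTable re-expresses a
# block table at a finer granularity, turning each coarse block k into `blks`
# consecutive fine blocks starting at k * blks.
def _demo_convert_block_table():
    block_table = torch.tensor([5, 7])
    incseq = torch.arange(2)
    result = convertBlockTable(block_table, 2, incseq)
    # Coarse block 5 -> fine blocks [10, 11]; coarse block 7 -> [14, 15].
    assert result.tolist() == [10, 11, 14, 15]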
def get_window_block_tables(window_size : int,
block_size : int, #blocksize of block_table
seq_k_lens: torch.Tensor,
query_start_loc: torch.Tensor,
block_table: Optional[torch.Tensor]=None, # shape (batch, max_blocks)
window_block_tables:Optional[torch.Tensor]=None, # shape (total_seq, max_blocks)
window_context_lens:Optional[torch.Tensor]=None): # shape (total_seq)
tmo.get_window_block_tables(window_block_tables = window_block_tables,
window_context_lens = window_context_lens,
seq_k_lens = seq_k_lens,
query_start_loc = query_start_loc,
block_table = block_table,
block_size = block_size,
window_size = window_size,)
def get_compress_block_tables(ratio: int,
block_size: int,
seq_k_lens: torch.Tensor, # k lens before compression, shape (batch)
query_start_loc: torch.Tensor, # shape (batch+1)
offset: torch.Tensor, # shape (batch)
block_table: torch.Tensor, # shape (batch, max_blocks)
compress_block_tables: torch.Tensor, # shape (total_seq, max_blocks)
compress_context_lens: torch.Tensor): # shape (total_seq)
tmo.get_compress_block_tables(
compress_block_tables = compress_block_tables,
compress_context_lens = compress_context_lens,
seq_k_lens = seq_k_lens,
query_start_loc = query_start_loc,
offset = offset,
block_table = block_table,
block_size = block_size,
ratio = ratio,
)
def hc_split_sinkhorn(mixes: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
pre_scale: Optional[torch.Tensor] = None,
hc_mult: int = 4,
sinkhorn_iter: int = 20,
eps: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return tmo.hc_split_sinkhorn(
mixes = mixes,
hc_scale = hc_scale,
hc_base = hc_base,
pre_scale = pre_scale,
hc_mult = hc_mult,
sinkhorn_iter = sinkhorn_iter,
eps = eps,
)
def fused_indexer_q(q: torch.Tensor,
w_q: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
position_id: torch.Tensor,
output: Optional[torch.Tensor] = None,
hadamard_matrix: Optional[torch.Tensor] = None,
w_q_scale: Optional[torch.Tensor] = None,
output_quant_mode: str = 'none',
output_scale: Optional[torch.Tensor] = None,
interleaved: bool = True,
rope_at_front: bool = True):
return tmo.fused_indexer_q(
q = q,
w_q = w_q,
sin = sin,
cos = cos,
position_id = position_id,
output = output,
hadamard_matrix = hadamard_matrix,
w_q_scale = w_q_scale,
output_quant_mode = output_quant_mode,
output_scale = output_scale,
interleaved = interleaved,
rope_at_front = rope_at_front)
def fused_mla_q_v2(
input_q: torch.Tensor,
gamma: torch.Tensor,
smooth_quant_scale: Optional[torch.Tensor],
weight_b: torch.Tensor,
weight_b_scale: Optional[torch.Tensor],
sin: torch.Tensor,
cos: torch.Tensor,
position_id: torch.Tensor,
output: Optional[torch.Tensor] = None,
eps: float = 1e-6,
interleaved: bool = True,
store_norm: bool = False,
output_norm: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
This function applies MLA (Multi-head Latent Attention) v2 Query (Q) preprocessing.
The fusion logic includes: RMSNorm -> Quant(Optional) -> MatMul -> RMSNorm -> RoPE.
Math:
qr = rmsnorm(input_q, gamma, eps)
if quant:
            qr, q_scale = per_token_quant(qr, smooth_quant_scale)
q = matmul(qr, q_scale, weight_b, weight_b_scale)
q = q.reshape(batch, seq, n_local_heads, head_dim)
        q = q * rsqrt(q.square().mean(-1, keepdim=True) + eps)
out = apply_rotary_embedding(q, sin, cos, position_id, interleaved)
Args:
input_q (torch.Tensor):
The input latent query tensor. Shape is (batch, seq, q_lora_rank).
gamma (torch.Tensor):
The scaling parameter for the initial RMSNorm. Shape is (q_lora_rank).
smooth_quant_scale (Optional[torch.Tensor]):
Scale tensor for SmoothQuant migration. Can be None. Shape is (q_lora_rank).
weight_b (torch.Tensor):
The Q-projection weight tensor. Shape is (n_local_heads, head_dim, q_lora_rank).
weight_b_scale (Optional[torch.Tensor]):
The per-channel quantization scales for weight_b. Shape is (n_local_heads, head_dim).
sin (torch.Tensor):
Rotary embedding sine table. Shape is (max_rotary_seq_len, rotary_head_dim).
cos (torch.Tensor):
Rotary embedding cosine table. Shape is (max_rotary_seq_len, rotary_head_dim).
position_id (torch.Tensor):
Indices for the RoPE tables. Shape is (batch,).
output (Optional[torch.Tensor]):
Optional output tensor for the final processed Q. Shape is (batch, seq, n_local_heads, head_dim).
eps (float):
Small constant for RMSNorm numerical stability. Default: 1e-6.
interleaved (bool):
If True, apply interleaved rotary embedding, otherwise folded. Default: True.
store_norm (bool):
If True, the intermediate RMSNorm result (pre-MatMul) will be returned. Default: False.
output_norm (Optional[torch.Tensor]):
Optional tensor to store the intermediate RMSNorm result. Shape: (batch, seq, q_lora_rank).
Type:
input_q, gamma, sin, cos: bfloat16.
weight_b: int8, same as input_q.
weight_b_scale, smooth_quant_scale: float32.
position_id: int32.
output: same as input_q.
Return:
Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
- If store_norm=False: output
- If store_norm=True: (..., output_norm) is appended to the return.
"""
return tmo.fused_mla_q_v2(
input_q=input_q,
gamma=gamma,
smooth_quant_scale=smooth_quant_scale,
weight_b=weight_b,
weight_b_scale=weight_b_scale,
sin=sin,
cos=cos,
position_id=position_id,
output=output,
eps=eps,
interleaved=interleaved,
store_norm=store_norm,
output_norm=output_norm,
)
def update_compressor_states(
kv_state, # (max_batch, (overlap+1)*ratio + K, dim)
score_state, # (max_batch, (overlap+1)*ratio + K, dim)
accept_tokens: torch.Tensor, # (bsz,)
batch_to_kv_state: torch.Tensor, # (bsz,)
positions: torch.Tensor, # (bsz,)
cu_query_len: torch.Tensor, # (bsz+1,)
overlap: bool,
K: int
):
bsz = batch_to_kv_state.numel()
ratio = (kv_state.size(1) - K) // (overlap + 1)
start_positions = positions[cu_query_len[:bsz]]
end_positions = start_positions + accept_tokens
for i in range(bsz):
start_pos = start_positions[i]
end_pos = end_positions[i]
# Skip if sequence len does not exceed coff * ratio.
if (overlap and end_pos < 2 * ratio) or (not overlap and end_pos < ratio):
continue
        # Skip if the compression condition is not met.
if (start_pos // ratio) == (end_pos // ratio) and start_pos % ratio != 0:
continue
state_idx = batch_to_kv_state[i]
if overlap:
length = end_pos - start_pos + start_pos % ratio
else:
length = end_pos % ratio
start = ratio
end = start + length
if length == 0:
continue
kv_state[state_idx, :length] = kv_state[state_idx, start:end].clone()
score_state[state_idx, :length] = score_state[state_idx, start:end].clone()