# torch_mlu_ops (TMO) kernel wrappers for vLLM on MLU devices.
from contextlib import contextmanager
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
|
|
import math
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
try:
|
|
import torch_mlu_ops as tmo
|
|
import torch_mlu_ops.triton_ops as triton_ops
|
|
except ImportError as e:
|
|
logger.warning("Failed to import from TMO OPS with %r", e)
|
|
|
|
|
|
from vllm.distributed import (
|
|
get_ep_group,
|
|
get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
tensor_model_parallel_all_gather,
|
|
get_data_parallel_group_world_size,
|
|
get_tp_group,
|
|
get_tp_world_group,
|
|
get_dp_group,
|
|
get_data_parallel_group_rank,
|
|
get_tp_world_world_size,
|
|
get_tp_world_rank,
|
|
get_parallel_rank_with_group,
|
|
)
|
|
|
|
|
|
@triton.jit
def _triton_advance_step(input_tokens_ptr,
                         sampled_token_ids_ptr,
                         input_positions_ptr,
                         seq_lens_ptr,
                         slot_mapping_ptr,
                         block_tables_ptr,
                         block_tables_stride,
                         num_seqs,
                         num_queries,
                         block_size,
                         TILE_SIZE: tl.constexpr,
                         ):
    """
    The triton implementation of advance step.

    Updates, in place and per query sequence: the next input token (the one
    just sampled), the sequence length (+1), the input position, and the KV
    cache slot for that position looked up through the paged block table.

    Reference: https://github.com/vllm-project/vllm/blob/v0.6.1/csrc/prepare_inputs/advance_step.cu#L14-L55
    """
    # Set meta info: each program handles a TILE_SIZE-wide tile of queries.
    pid = tl.program_id(axis=0)
    offsets = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)
    # Only the first num_queries entries are valid; mask out the tail.
    mask = offsets < num_queries

    # Update input_tokens: the sampled token becomes the next step's input.
    sampled_token_ids = tl.load(sampled_token_ids_ptr + offsets, mask=mask)
    tl.store(input_tokens_ptr + offsets, sampled_token_ids, mask=mask)

    seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
    next_seq_lens = seq_lens + 1
    # The new token's position index is the old sequence length
    # (0-based position == new length - 1).
    next_input_pos = next_seq_lens - 1

    # Update seq_lens.
    tl.store(seq_lens_ptr + offsets, next_seq_lens, mask=mask)

    # Update input_positions.
    tl.store(input_positions_ptr + offsets, next_input_pos, mask=mask)

    # Calculate slot num: translate the logical position into a physical
    # KV-cache slot via the paged block table.
    block_index = next_input_pos // block_size
    block_offset = next_input_pos % block_size
    block_tables = tl.load(block_tables_ptr + block_tables_stride * offsets + block_index, mask=mask)
    slot_num = block_tables * block_size + block_offset

    # Update slot_mapping.
    tl.store(slot_mapping_ptr + offsets, slot_num, mask=mask)
|
|
|
|
|
|
|
|
def rotary_embedding(
    input: torch.Tensor,
    sin_cache: torch.Tensor,
    cos_cache: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    cu_seqlens: Optional[torch.Tensor],
    interleaved: bool,
    discrete: bool,
    dynamic_ntk: bool,
    max_seqlen: int,
) -> torch.Tensor:
    """Apply rotary positional embedding via the TMO ``apply_rotary`` kernel.

    All arguments are forwarded positionally to ``tmo.apply_rotary``;
    ``position_ids``/``cu_seqlens`` may be None depending on the layout.
    """
    return tmo.apply_rotary(input,
                            sin_cache,
                            cos_cache,
                            position_ids,
                            cu_seqlens,
                            interleaved,
                            discrete,
                            dynamic_ntk,
                            max_seqlen)
|
|
|
|
|
|
def fused_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    store_output_before_norm: bool,
    quant_scale: torch.Tensor = None,
    dynamic_quant: bool = False,
    out: torch.Tensor = None,
) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
    """Fused (residual-add +) RMSNorm, with optional quantized output.

    Thin forwarding wrapper over ``tmo.fused_rms_norm``; note the kernel
    takes ``out`` before ``dynamic_quant``.
    """
    return tmo.fused_rms_norm(x,
                              residual,
                              gamma,
                              beta,
                              bias,
                              eps,
                              store_output_before_norm,
                              quant_scale,
                              out,
                              dynamic_quant)
|
|
|
|
|
|
def fused_layer_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    store_output_before_norm: bool,
    quant_scale: torch.Tensor = None,
    dynamic_quant: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Fused (residual-add +) LayerNorm, with optional quantized output.

    Forwards to ``tmo.fused_layer_norm``; the ``None`` argument is the
    kernel's pre-allocated output slot, which this wrapper never supplies.
    """
    return tmo.fused_layer_norm(x,
                                residual,
                                gamma,
                                beta,
                                bias,
                                eps,
                                store_output_before_norm,
                                quant_scale,
                                None,
                                dynamic_quant)
|
|
|
|
|
|
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: Optional[torch.Tensor],
    cu_seq_lens_q: Optional[torch.Tensor],
    cu_seq_lens_kv: Optional[torch.Tensor],
    alibi_slope: Optional[torch.Tensor],
    attn_bias: Optional[torch.Tensor],
    max_seq_len_q: int,
    max_seq_len_kv: int,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    compute_dtype: torch.dtype = torch.float,
    return_lse: bool = False,
    block_tables: Optional[torch.Tensor] = None,
    out_quant_scale: Optional[torch.Tensor] = None,
    out_dtype: torch.dtype = torch.half,
    q_quant_dtype: Optional[torch.dtype] = None,
    k_quant_dtype: Optional[torch.dtype] = None,
    v_quant_dtype: Optional[torch.dtype] = None,
    q_quant_scale: Optional[torch.Tensor] = None,
    k_quant_scale: Optional[torch.Tensor] = None,
    v_quant_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Variable-length flash attention via ``tmo.flash_attention``.

    BUGFIX: the original body forwarded ``k_quant_scale``/``v_quant_scale``/
    ``q_quant_scale`` which were never declared as parameters, so every call
    raised NameError. They are now explicit keyword parameters defaulting to
    None (backward compatible; existing callers are unaffected).

    NOTE(review): ``q_quant_dtype``/``k_quant_dtype``/``v_quant_dtype`` are
    accepted for interface compatibility but are not consumed here — confirm
    whether they should be forwarded to the kernel.
    """
    # MQA/GQA convenience: a missing v means k doubles as the value tensor.
    if v is None:
        v = k
    return tmo.flash_attention(
        q, k, v, out,
        cu_seq_lens_q, cu_seq_lens_kv,
        alibi_slope, attn_bias,
        max_seq_len_q, max_seq_len_kv,
        softmax_scale, is_causal,
        window_size_left, window_size_right,
        compute_dtype, return_lse,
        block_tables, k_quant_scale,
        v_quant_scale, q_quant_scale,
        out_quant_scale, out_dtype)
|
|
|
|
|
|
def split_head_nums(q_head_num, kv_head_num, max_q_head_num):
    """
    Split q_head_num such that:
    1. The maximum value of the split q_head_num does not exceed max_q_head_num.
    2. kv_head_num is split into the same number of parts as q_head_num.
    3. Each split q_head_num can be evenly divided by the corresponding kv_head_num.
    4. If kv_head_num < 1, it is adjusted to 1.

    Parameters:
    - q_head_num: int, the q_head_num to be split.
    - kv_head_num: int, the kv_head_num to be split.
    - max_q_head_num: int, the maximum supported q_head_num after splitting.

    Returns:
    - q_splits: list, the split q_head_num.
    - kv_splits: list, the split kv_head_num.

    Raises:
    - ValueError: if q_head_num or kv_head_num is not positive.
    - RuntimeError: if no valid split can be found.
    """
    # BUGFIX: the original returned an error *string* here, which callers
    # unpacking two lists would silently mis-handle. Raise instead.
    if q_head_num <= 0 or kv_head_num <= 0:
        raise ValueError("q_head_num and kv_head_num must be positive integers!")

    q_splits = []
    kv_splits = []

    # Residual values still to be distributed.
    remaining_q = q_head_num
    remaining_kv = kv_head_num

    while remaining_q > 0:
        # Attempt to split q_head_num such that the maximum value does not
        # exceed max_q_head_num; prefer the largest feasible part.
        for q_part in range(min(max_q_head_num, remaining_q), 0, -1):
            # q_part must divide the remaining q heads evenly.
            if remaining_q % q_part == 0:
                # Ensure that kv_part is greater than or equal to 1.
                kv_part = max(remaining_kv // (remaining_q // q_part), 1)
                # Ensure that q_part is divisible by kv_part (GQA constraint).
                if q_part % kv_part == 0:
                    # Record the split values.
                    q_splits.append(q_part)
                    kv_splits.append(kv_part)
                    remaining_q -= q_part
                    remaining_kv -= kv_part
                    break
        else:
            # for-loop exhausted without break: no feasible split exists.
            err_msg = f"Unable to find split method for q_head_num:{q_head_num}, kv_head_num:{kv_head_num}"
            raise RuntimeError(err_msg)

    return q_splits, kv_splits
|
|
|
|
|
|
def repeat_elements(input_list, n):
    """
    Repeat each element in the list n times consecutively.

    Parameters:
    - input_list: list, the input list.
    - n: int, the number of times each element should be repeated.

    Returns:
    - list, a new list containing the repeated elements.

    Raises:
    - ValueError: if input_list is not a list or n is not a non-negative int.
    """
    valid_list = isinstance(input_list, list)
    valid_count = isinstance(n, int) and n >= 0
    if not (valid_list and valid_count):
        raise ValueError("The input must be a list, and the repetition count n must be an integer greater than or equal to 0.")

    # Expand element-by-element; each element contributes n consecutive copies.
    repeated = []
    for element in input_list:
        repeated.extend([element] * n)
    return repeated
|
|
|
|
|
|
def single_query_cached_kv_attn(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    k_cache_quant_scale: Optional[torch.Tensor],
    v_cache_quant_scale: Optional[torch.Tensor],
    alibi_slopes: Optional[torch.Tensor],
    max_contxt_len: int,
    windows_size_left: int,
    windows_size_right: int,
    softmax_scale: float,
    return_lse: bool = False,
    q_head_dim: Optional[int] = 2,
    kv_head_dim: Optional[int] = 1,
    seq_q_dim: Optional[int] = 1,
    max_seq_q_mul_q_divide_kv: Optional[int] = 128,
    head_size_v: Optional[int] = -1,
    compute_dtype: Optional[torch.dtype] = torch.float32,
    q_quant_scale: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    q_quant_dtype: Optional[torch.dtype] = None,
    q_scale_dtype: Optional[torch.dtype] = None,
    learnable_sink: Optional[torch.Tensor] = None,
) -> None:
    """Decode-phase attention over a paged KV cache, writing into ``out``.

    When the kernel limit ``seq_q * (q_heads // kv_heads)`` is exceeded, the
    heads are split into chunks (see ``split_head_nums``) and the kernel is
    invoked once per chunk.

    BUGFIX: ``alibi_slopes``/``k_cache_quant_scale``/``v_cache_quant_scale``
    were tested with ``if tensor:``, which raises RuntimeError for tensors
    with more than one element; all three now use ``is not None``.
    """
    # Right-side windowing is force-disabled for the decode path.
    windows_size_right = -1
    seq_q = q.shape[seq_q_dim]
    # Dynamically quantize q when a quant dtype is requested but no scale
    # was supplied.
    if q_quant_dtype is not None and q.dtype != q_quant_dtype and q_quant_scale is None:
        q, q_quant_scale = tmo.scaled_quantize(q.contiguous(), quant_type=q_quant_dtype, quant_mode="dynamic_per_token")
    # uint8 caches are a storage alias for fp8 e4m3.
    if k_cache is not None and k_cache.dtype == torch.uint8:
        k_cache = k_cache.view(torch.float8_e4m3fn)
    if v_cache is not None and v_cache.dtype == torch.uint8:
        v_cache = v_cache.view(torch.float8_e4m3fn)

    # bf16 caches support a larger seq_q * q_divide_kv product.
    if k_cache is not None and k_cache.dtype == torch.bfloat16:
        max_seq_q_mul_q_divide_kv = 256

    # single_query_cached_kv_attn limits seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv now,
    # and this limitation only applies when using kv8 or floating point computation.
    # When the limitation is fixed, we should delete the split process.
    q_head_num = q.shape[q_head_dim]
    kv_head_num = k_cache.shape[kv_head_dim]
    q_divide_kv = q_head_num // kv_head_num
    if seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv or q_quant_scale is not None:
        # Fast path: a single kernel launch covers all heads.
        tmo.single_query_cached_kv_attn(
            q, k_cache, v_cache, out,
            block_tables, context_lens,
            k_cache_quant_scale, v_cache_quant_scale,
            alibi_slopes, max_contxt_len,
            windows_size_left, windows_size_right, softmax_scale, return_lse,
            q_quant_scale=q_quant_scale,
            head_size_v=head_size_v,
            compute_dtype=compute_dtype,
            mask=mask,
            learnable_sink=learnable_sink)
    else:
        # Split path: chunk q (and kv) heads so each launch respects the
        # kernel's seq_q * q_divide_kv limit.
        max_q_head_num = max_seq_q_mul_q_divide_kv * kv_head_num // seq_q
        q_head_num_sizes, kv_head_num_sizes = split_head_nums(q_head_num, kv_head_num, max_q_head_num)
        parts_num = len(q_head_num_sizes)
        q_parts = torch.split(q, q_head_num_sizes, dim=q_head_dim)
        out_parts = torch.split(out, q_head_num_sizes, dim=q_head_dim)
        alibi_slopes_parts = [None] * parts_num
        # BUGFIX: was `if alibi_slopes:` (tensor truthiness).
        if alibi_slopes is not None:
            alibi_slopes_parts = torch.split(alibi_slopes, q_head_num_sizes, dim=0)

        kv_parts_num = parts_num
        if parts_num > kv_head_num:
            # More q chunks than kv heads: kv chunks will be reused
            # (repeated) across q chunks below.
            assert parts_num % kv_head_num == 0, f"parts_num:{parts_num} need by divided by kv_head_num:{kv_head_num} when parts_num > kv_head_num"
            kv_parts_num = kv_head_num
            kv_head_num_sizes = kv_head_num_sizes[:kv_parts_num]

        if len(kv_head_num_sizes) > 1:
            k_cache_parts = torch.split(k_cache, kv_head_num_sizes, dim=kv_head_dim)
            v_cache_parts = torch.split(v_cache, kv_head_num_sizes, dim=kv_head_dim)
            k_cache_quant_scale_parts = [None] * kv_parts_num
            v_cache_quant_scale_parts = [None] * kv_parts_num
            # BUGFIX: was `if k_cache_quant_scale:` / `if v_cache_quant_scale:`.
            if k_cache_quant_scale is not None:
                # 2-D scales are laid out (head, ...) starting at dim 1.
                k_cache_quant_scale_dim = 1 if k_cache_quant_scale.dim() == 2 else kv_head_dim
                k_cache_quant_scale_parts = torch.split(k_cache_quant_scale, kv_head_num_sizes, dim=k_cache_quant_scale_dim)
            if v_cache_quant_scale is not None:
                v_cache_quant_scale_dim = 1 if v_cache_quant_scale.dim() == 2 else kv_head_dim
                v_cache_quant_scale_parts = torch.split(v_cache_quant_scale, kv_head_num_sizes, dim=v_cache_quant_scale_dim)
        else:
            k_cache_parts = [k_cache]
            v_cache_parts = [v_cache]
            k_cache_quant_scale_parts = [k_cache_quant_scale]
            v_cache_quant_scale_parts = [v_cache_quant_scale]

        if parts_num > kv_parts_num:
            # Duplicate kv chunks so the zip below pairs each q chunk with
            # its (shared) kv chunk.
            repeate_num = parts_num // kv_parts_num
            k_cache_parts = repeat_elements(k_cache_parts, repeate_num)
            v_cache_parts = repeat_elements(v_cache_parts, repeate_num)
            k_cache_quant_scale_parts = repeat_elements(k_cache_quant_scale_parts, repeate_num)
            v_cache_quant_scale_parts = repeat_elements(v_cache_quant_scale_parts, repeate_num)

        # NOTE(review): the split path does not forward mask/learnable_sink/
        # q_quant_scale to the kernel — confirm whether that is intended.
        for q_value, k_cache_value, v_cache_value, out_value, k_cache_quant_scale_value, v_cache_quant_scale_value, alibi_slopes_value in zip(
                q_parts, k_cache_parts, v_cache_parts, out_parts, k_cache_quant_scale_parts, v_cache_quant_scale_parts,
                alibi_slopes_parts):
            tmo.single_query_cached_kv_attn(
                q_value, k_cache_value.contiguous(), v_cache_value.contiguous() if v_cache_value is not None else None,
                out_value, block_tables, context_lens,
                k_cache_quant_scale_value, v_cache_quant_scale_value,
                alibi_slopes_value, max_contxt_len,
                windows_size_left, windows_size_right, softmax_scale, return_lse,
                head_size_v=head_size_v,
                compute_dtype=compute_dtype)

    return (None, None)  # TODO(liangxuegang): to fix return (output, lse)
|
|
|
|
|
|
def reshape_paged_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor
) -> None:
    """Scatter new k/v tokens into the paged KV cache at ``slot_mapping``.

    In-place on ``k_cache``/``v_cache``; forwards directly to the TMO kernel.
    """
    tmo.reshape_paged_cache(k, v, k_cache, v_cache, slot_mapping)
|
|
|
|
|
|
def swap_blocks(
    dst: torch.Tensor,
    src: torch.Tensor,
    block_mapping: torch.Tensor
) -> None:
    """Swap cache blocks between ``src`` and ``dst`` per ``block_mapping``.

    ``block_mapping`` is an N x 2 tensor of (src_block, dst_block) pairs.
    """
    # FIXME: Remove this conversion after
    # tmo.swap_blocks support block_mapping tensor.
    mapping_dict = {}
    for src_block, dst_block in block_mapping.tolist():
        mapping_dict[src_block] = dst_block
    return tmo.swap_blocks(dst, src, mapping_dict)
|
|
|
|
|
|
def copy_blocks(
    k_caches: List[torch.Tensor],
    v_caches: List[torch.Tensor],
    block_mapping: torch.Tensor
) -> None:
    """Copy cache blocks, fanning each source block out to its destinations.

    Each row of ``block_mapping`` is (src_block, dst_block, ...); rows with
    the same source accumulate all their destinations.
    """
    # FIXME: Remove this conversion after
    # tmo.swap_blocks support block_mapping tensor.
    gathered = {}
    for src_block, *dst_blocks in block_mapping.tolist():
        gathered.setdefault(src_block, []).extend(dst_blocks)
    return tmo.copy_blocks(k_caches, v_caches, gathered)
|
|
|
|
|
|
def active(
    input: torch.Tensor,
    act_mode: str,
    is_gated: bool
) -> torch.Tensor:
    """Apply the activation named by ``act_mode`` (optionally gated) via TMO."""
    return tmo.active(input, act_mode, is_gated)
|
|
|
|
def fused_moe(hidden_states: torch.Tensor,
              gating_output: torch.Tensor,
              w1: torch.Tensor,
              w2: torch.Tensor,
              bias1: Optional[torch.Tensor],
              bias2: Optional[torch.Tensor],
              residual: Optional[torch.Tensor],
              input_smooth: Optional[torch.Tensor],
              act_smooth: Optional[torch.Tensor],
              w1_scale: Optional[torch.Tensor],
              w2_scale: Optional[torch.Tensor],
              topk: int,
              renormalize: bool,
              gated: bool,
              act_mode: str,
              start_expert_id: int = 0,
              block_n: int = 0,
              cncl_comm: int = 0,
              avg_moe: bool=False,
              class_reduce_weight: Optional[torch.Tensor] = None,
              class_expert_id: Optional[torch.Tensor] = None,
              w1_quant_flag: Optional[List] = None,
              w2_quant_flag: Optional[List] = None,
              world_size: int = 0,
              shared_expert_num: int = 0,
              parallel_mode: str = 'ep'):
    """Full fused MoE layer built from TMO kernels.

    Pipeline: softmax-topk routing -> (optional shared-expert append) ->
    token gather/expand -> group GEMM 1 -> activation -> group GEMM 2 ->
    weighted combine (+ optional residual). When all four smooth/scale
    tensors are given, the smooth-quant (int8 per-token) variants of the
    quantize/GEMM kernels are used instead.

    Returns the combined output reshaped to hidden_states' original shape.
    """
    dtype = hidden_states.dtype
    ori_input_shape = hidden_states.shape
    # Flatten to (tokens, hidden) for the token-level kernels below.
    hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
    tokens = hidden_states.size(0)
    gating_output = gating_output.reshape(-1, gating_output.size(-1))
    residual = residual.reshape(-1, residual.size(-1)) if residual is not None else None
    expert_num = gating_output.size(-1)
    # With quantized weights, w1's leading dim is packed, so the local
    # expert count is read from w1_scale instead.
    expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1)

    per_token_sq = False
    # check quant: smooth-quant requires all four tensors together.
    check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
    if all(x is not None for x in check_list):
        per_token_sq = True

    if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
        raise ValueError(
            "input_smooth, act_smooth, w1_scale and w2_scale must be "
            "present and absent at the same time."
        )

    # softmax_topk: per-token routing weights and expert ids.
    reduce_weight, expert_id = tmo.moe_softmax_topk(gating_output, topk, renormalize)

    # append shared experts to every token's routing result.
    if shared_expert_num > 0:
        reduce_weight, expert_id = tmo.moe_append_shared_expert(reduce_weight, expert_id, expert_num,
                                                                shared_expert_num, world_size, parallel_mode)
        if parallel_mode == "ep":
            # Under expert parallelism, shared experts are padded so each
            # rank owns the same number of them.
            avg_shared_expert_num = (world_size + shared_expert_num - 1) // world_size
            expert_num += avg_shared_expert_num * world_size
        else:
            expert_num += shared_expert_num

    if avg_moe:
        # Debug/benchmark mode: override routing with precomputed
        # class_* tensors (truncated to the current token count).
        n_tokens = hidden_states.shape[0]
        reduce_weight = class_reduce_weight[:n_tokens]
        expert_id = class_expert_id[:n_tokens]
    # gen_idx: permutation indices and per-expert token counts.
    expand_idx, combine_idx, token_count, cusum_token_count = tmo.moe_gen_idx(expert_id, expert_num)

    if per_token_sq:
        # Quantize + gather in one kernel (int8 per-token).
        quant_input, input_scale = tmo.moe_quantize(hidden_states,
            input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx,
            cusum_token_count[start_expert_id].unsqueeze(0))
    else:
        # Gather tokens into expert-contiguous order.
        expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx,
                                                    cusum_token_count, start_expert_id, expert_size)

    # group gemm: first projection (w1), one GEMM per local expert.
    if per_token_sq:
        gemm1_out = tmo.smooth_quant_group_gemm(quant_input,
                                                w1,
                                                token_count[start_expert_id:start_expert_id+expert_size],
                                                None, None, None, None,
                                                input_scale, w1_scale, dtype, tokens, quant_flag = w1_quant_flag)
    else:
        gemm1_out = tmo.group_gemm(expand_hidden_states,
                                   w1,
                                   token_count[start_expert_id:start_expert_id+expert_size],
                                   None,
                                   None,
                                   None,
                                   None, tokens)
    if per_token_sq:
        # Reuse the gemm1 input buffers for the activation+requantize step;
        # gated activations halve the feature dim.
        quant_input = quant_input[:, :gemm1_out.shape[-1] // 2] if gated else quant_input[:, :gemm1_out.shape[-1]]
        input_scale = input_scale[:gemm1_out.shape[0]]
        quant_input, input_scale = tmo.moe_quantize(gemm1_out, act_smooth, None,
                                                    token_count[start_expert_id:start_expert_id+expert_size],
                                                    output=quant_input,
                                                    output_scale=input_scale,
                                                    act_mode=act_mode,
                                                    is_gated=gated)
    else:
        # Activation in place over gemm1's output (first half when gated).
        act_out = gemm1_out[:, :gemm1_out.shape[-1] // 2] if gated else gemm1_out
        act_out = tmo.moe_active(gemm1_out, act_mode, gated, act_out, bias1, cusum_token_count, start_expert_id, expert_size)
    if cncl_comm > 0:
        raise ValueError("not support communication and computing fusion currently.")
    else:
        # Second projection (w2).
        if per_token_sq:
            gemm2_out = tmo.smooth_quant_group_gemm(quant_input,
                                                    w2, token_count[start_expert_id:start_expert_id+expert_size],
                                                    None, None, None, None, input_scale, w2_scale, dtype, tokens, quant_flag = w2_quant_flag)
        else:
            gemm2_out = tmo.group_gemm(act_out,
                                       w2,
                                       token_count[start_expert_id:start_expert_id+expert_size],
                                       None, None, None, None, tokens)

    # Scatter back to token order and apply routing weights (+ residual).
    output = tmo.moe_combine_result(gemm2_out, reduce_weight, combine_idx,
                                    residual, cusum_token_count, start_expert_id,
                                    expert_size, bias2)
    return output.reshape(ori_input_shape)
|
|
|
|
|
|
def matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    c: Optional[torch.Tensor] = None,
    act_mode: str = 'none',
    alpha: float = 1.0,
    beta: float = .0
) -> torch.Tensor:
    """Matrix multiply with optional bias/accumulator/activation.

    Computes ``act(alpha * a @ b + bias) + beta * c`` via ``tmo.matmul``.
    """
    return tmo.matmul(a, b, bias, c, act_mode, alpha, beta)
|
|
|
|
|
|
def weight_only_quant_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor = None,
    bias: torch.Tensor = None,
    c: torch.Tensor = None,
    act_mode: str = "none",
    quant_bit_size: int = 8,
    alpha: float = 1.0,
    beta: float = 1.0
) -> torch.Tensor:
    """Deprecated stub; always raises.

    BUGFIX: was ``assert False``, which is silently stripped under
    ``python -O``; raise explicitly so the deprecation always fires.
    """
    raise NotImplementedError("[weight_only_quant_matmul] is deprecated.")
|
|
|
|
|
|
def smooth_quant_matmul(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    dtype: torch.dtype,
    bias: torch.Tensor = None,
    c: torch.Tensor = None,
    act_mode: str = "none",
    alpha: float = 1.0,
    beta: float = 1.0,
    use_hp_active: bool = False,
    b_quant_bit_size: int = 8,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Smooth-quant matmul: dequantizing GEMM over pre-quantized a and b.

    Forwards to ``tmo.scaled_matmul``; ``dtype`` selects the output dtype.
    NOTE(review): ``output`` is accepted but not forwarded to the kernel —
    confirm whether it should be.
    """
    return tmo.scaled_matmul(a, b, a_scale, b_scale, dtype,
                             bias, c, act_mode,
                             b_quant_bit_size, alpha, beta,
                             use_hp_active)
|
|
|
|
|
|
def per_token_smooth_quantize(x: torch.Tensor,
                              smooth: torch.Tensor,
                              zero: torch.Tensor = None,
                              token_count: torch.Tensor = None,
                              act_mode: str = "none",
                              active_coef: float = 1.0,
                              is_gated: bool = False
                              ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dynamic per-token int8 smooth quantization, with optional activation.

    Dispatches to ``tmo.moe_quantize`` when per-expert ``token_count`` is
    given (MoE path), otherwise to ``tmo.scaled_quantize``.
    Returns (quantized output, per-token scale).
    """
    # Gating is meaningless without an activation.
    if act_mode == "none":
        is_gated = False
    if token_count is None:
        quantized, quant_scale = tmo.scaled_quantize(x, smooth, zero, None, torch.int8,
                                                     "dynamic_per_token", act_mode, active_coef,
                                                     is_gated)
    else:
        quantized, quant_scale = tmo.moe_quantize(x, smooth, zero, token_count, None, None, None,
                                                  None, True, act_mode, active_coef, is_gated)
    return (quantized, quant_scale)
|
|
|
|
|
|
def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor = None
) -> torch.Tensor:
    """Deprecated stub; always raises.

    BUGFIX: was ``assert False``, which is silently stripped under
    ``python -O``; raise explicitly so the deprecation always fires.
    """
    raise NotImplementedError("[quantize] is deprecated.")
|
|
|
|
|
|
def quant_to_paged_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k_cache_quant_scale: torch.Tensor,
    v_cache_quant_scale: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Quantize k/v and scatter them into the paged cache at ``slot_mapping``.

    uint8 cache storage is reinterpreted as fp8 e4m3 before the kernel call.
    """
    if k_cache is not None and k_cache.dtype == torch.uint8:
        k_cache = k_cache.view(torch.float8_e4m3fn)
    if v_cache is not None and v_cache.dtype == torch.uint8:
        v_cache = v_cache.view(torch.float8_e4m3fn)
    return tmo.quant_to_paged_cache(
        k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping
    )
|
|
|
|
|
|
def advance_step(num_seqs: int,
                 num_queries: int,
                 block_size: int,
                 input_tokens: torch.Tensor,
                 sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor,
                 seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor,
                 TILE_SIZE: int = 64) -> None:
    """
    Advance a step on MLU for existing inputs for a multi-step runner, which
    will update input_tokens/seq_lens/input_positions/slot_mapping inplace.

    BUGFIX: the contiguity check used ``tensor.is_contiguous`` (a bound
    method object, always truthy) instead of calling it, so non-contiguous
    tensors were never rejected.
    """
    def verify_tensor(
        name: str,
        tensor: torch.Tensor,
        size_0: int,
        size_1: int,
        dtype: torch.dtype,
    ):
        """
        Auxiliary function to check whether input is valid.

        A size of -1 means "don't check that dimension".
        """
        size_0_cond = (size_0 == -1 or tensor.size(0) == size_0)
        size_1_cond = (size_1 == -1 or tensor.size(1) == size_1)
        if not (size_0_cond and size_1_cond and tensor.is_contiguous() and tensor.dtype == dtype):
            raise ValueError(
                f"The input to advance_step is invalid with tensor name = {name}, "
                f"shape = {tensor.shape}, "
                f"is_cont = {tensor.is_contiguous()}, "
                f"type = {tensor.dtype}, "
                f"is not as expected: shape[{size_0}, {size_1}], type = {dtype}"
            )

    verify_tensor("input_tokens", input_tokens, num_seqs, -1, torch.int64)
    verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, torch.int64)
    verify_tensor("input_positions", input_positions, num_seqs, -1, torch.int32)
    verify_tensor("seq_lens", seq_lens, num_seqs, -1, torch.int32)
    verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, torch.int32)
    verify_tensor("block_tables", block_tables, num_seqs, -1, torch.int32)

    # One triton program per TILE_SIZE queries.
    grid = (math.ceil(num_queries / TILE_SIZE), )
    _triton_advance_step[grid](input_tokens,
                               sampled_token_ids,
                               input_positions,
                               seq_lens,
                               slot_mapping,
                               block_tables,
                               block_tables.stride(0),
                               num_seqs,
                               num_queries,
                               block_size,
                               TILE_SIZE)
|
|
|
|
|
|
#Moe inner kernels
|
|
def moe_softmax_topk(input: torch.Tensor,
|
|
topk: int,
|
|
normalize: bool = False,
|
|
num_expert_group: int = -1,
|
|
topk_group: int = 0,
|
|
mask: Optional[torch.Tensor] = None,
|
|
normed_by : str = "topk_logit",
|
|
route_scale : float = 1.0,
|
|
reduce_weight: Optional[torch.Tensor] = None,
|
|
expert_id: Optional[torch.Tensor] = None,
|
|
score_bias: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor]:
|
|
return tmo.moe_softmax_topk(input, topk, normalize, num_expert_group,
|
|
topk_group, mask, normed_by, route_scale,
|
|
reduce_weight, expert_id, score_bias)
|
|
|
|
def moe_sigmoid_topk(input: torch.Tensor,
                     topk: int,
                     normalize: bool = False,
                     num_expert_group: int = -1,
                     topk_group: int = 0,
                     route_scale: float = 1.0,
                     score_bias: Optional[torch.Tensor] = None,
                     reduce_weight: Optional[torch.Tensor] = None,
                     expert_id: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor]:
    """Sigmoid-then-topk MoE routing: returns routing weights and expert ids."""
    return tmo.moe_sigmoid_topk(input,
                                topk,
                                normalize,
                                num_expert_group,
                                topk_group,
                                route_scale = route_scale,
                                score_bias = score_bias,
                                reduce_weight=reduce_weight,
                                expert_id=expert_id)
|
|
|
|
def moe_softplus_topk(
    input: torch.Tensor,
    topk: int,
    input_ids: Optional[torch.Tensor] = None,
    tid2eid: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    route_scale: float = 1.0,
    reduce_weight: Optional[torch.Tensor] = None,
    expert_id: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Softplus-then-topk MoE routing: returns routing weights and expert ids."""
    return tmo.moe_softplus_topk(input, topk, input_ids, tid2eid,
                                 bias, route_scale, reduce_weight, expert_id)
|
|
|
|
def moe_gen_idx(expert_id: torch.Tensor,
                expert_num: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build MoE scatter/gather indices and per-expert token counts.

    Returns (expand_idx, combine_idx, token_count, cusum_token_count)
    from ``tmo.moe_gen_idx``.
    """
    return tmo.moe_gen_idx(expert_id, expert_num)
|
|
|
|
def moe_expand_input(input: torch.Tensor,
                     gather_idx: torch.Tensor,
                     cusum_token_count: Optional[torch.Tensor] = None,
                     start_expert_id: int = 0,
                     expert_size: int = 0) -> torch.Tensor:
    """Gather token rows into expert-contiguous order via ``tmo.moe_expand_input``."""
    return tmo.moe_expand_input(input,
                                gather_idx,
                                cusum_token_count,
                                start_expert_id,
                                expert_size)
|
|
|
|
def moe_active(input: torch.Tensor,
               act_mode: str,
               is_gated: bool,
               output: Optional[torch.Tensor] = None,
               bias: Optional[torch.Tensor] = None,
               cusum_token_count: Optional[torch.Tensor] = None,
               start_expert_id: int = 0,
               expert_size: int = 0) -> torch.Tensor:
    """Apply the MoE activation kernel (optionally gated, with bias) via TMO."""
    return tmo.moe_active(input,
                          act_mode,
                          is_gated,
                          output,
                          bias,
                          cusum_token_count,
                          start_expert_id,
                          expert_size)
|
|
|
|
def group_gemm(a: torch.Tensor,
               b: torch.Tensor,
               m_list: torch.Tensor,
               expand_idx: Optional[torch.Tensor],
               c: Optional[torch.Tensor],
               alpha: Optional[torch.Tensor],
               beta: Optional[torch.Tensor],
               max_m: int = 0,
               output: Optional[torch.Tensor] = None,
               ) -> torch.Tensor:
    """Grouped GEMM: one matmul per group, with per-group row counts ``m_list``."""
    return tmo.group_gemm(a, b, m_list, expand_idx, c,
                          alpha, beta, max_m, d=output)
|
|
|
|
def smooth_quant_group_gemm(a: torch.Tensor,
                            b: torch.Tensor,
                            m_list: torch.Tensor,
                            expand_idx: Optional[torch.Tensor],
                            c: Optional[torch.Tensor],
                            alpha: Optional[torch.Tensor],
                            beta: Optional[torch.Tensor],
                            a_scale: torch.Tensor,
                            b_scale: torch.Tensor,
                            dtype,
                            max_m: int = 0,
                            output: Optional[torch.Tensor] = None,
                            ) -> torch.Tensor:
    """Grouped smooth-quant GEMM over pre-quantized inputs.

    ``a_scale``/``b_scale`` dequantize the result into ``dtype``.
    """
    return tmo.smooth_quant_group_gemm(a, b, m_list, expand_idx,
                                       c, alpha, beta,
                                       a_scale, b_scale, dtype,
                                       max_m, d=output)
|
|
|
|
def moe_combine_result(input: torch.Tensor,
                       reduce_weight: torch.Tensor,
                       gather_ids: torch.Tensor,
                       residual: Optional[torch.Tensor],
                       cusum_token_count: Optional[torch.Tensor],
                       start_expert_id: int,
                       expert_size: int,
                       bias: Optional[torch.Tensor] = None,
                       output: Optional[torch.Tensor] = None,
                       ) -> torch.Tensor:
    """Combine per-expert outputs back into token order, applying routing
    weights and an optional residual/bias."""
    return tmo.moe_combine_result(input,
                                  reduce_weight,
                                  gather_ids,
                                  residual,
                                  cusum_token_count,
                                  start_expert_id,
                                  expert_size,
                                  bias,
                                  output=output)
|
|
|
|
def moe_quantize(x: torch.Tensor,
                 smooth: torch.Tensor,
                 zero: Optional[torch.Tensor] = None,
                 token_count: Optional[torch.Tensor] = None,
                 gather_index: Optional[torch.Tensor] = None,
                 gather_index_start_position: Optional[torch.Tensor] = None,
                 output: Optional[torch.Tensor] = None,
                 output_scale: Optional[torch.Tensor] = None,
                 dynamic_quant: bool = True,
                 act_mode: str = "none",
                 active_coef: float = 1.0,
                 is_gated: bool = False,
                 quant_type: torch.dtype = torch.int8
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
    """MoE-aware quantization (optionally fused with gather and activation).

    Returns (quantized output, scale) from ``tmo.moe_quantize``.
    """
    return tmo.moe_quantize(x, smooth, zero,
                            token_count, gather_index,
                            gather_index_start_position,
                            output, output_scale,
                            dynamic_quant, act_mode,
                            active_coef, is_gated, quant_type)
|
|
|
|
|
|
def dequant_from_paged_cache(key: torch.Tensor,
                             value: Optional[torch.Tensor],
                             key_cache: torch.Tensor,
                             value_cache: Optional[torch.Tensor],
                             key_cache_quant_scale: torch.Tensor,
                             value_cache_quant_scale: Optional[torch.Tensor],
                             context_lengths: torch.Tensor,
                             max_context_len: int,
                             context_seq_offset: Optional[torch.Tensor],
                             block_tables: torch.Tensor,
                             quant_mode: int = 0,
                             quant_bit: int = 8) -> None:
    """Dequantize k/v out of the paged cache into ``key``/``value`` in place."""
    tmo.dequant_from_paged_cache(key, value,
                                 key_cache, value_cache,
                                 key_cache_quant_scale,
                                 value_cache_quant_scale,
                                 context_lengths, max_context_len,
                                 context_seq_offset, block_tables,
                                 quant_mode, quant_bit)
|
|
|
|
def random_sample(
    probs: torch.Tensor,
    is_gumbel_max: bool,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample token ids from ``probs`` via TMO (optionally Gumbel-max)."""
    return tmo.random_sample(probs, is_gumbel_max, generators)
|
|
|
|
def rejection_sample(draft_token_ids: torch.Tensor,
                     num_draft_tokens: torch.Tensor,
                     cu_num_draft_tokens: torch.Tensor,
                     draft_probs: torch.Tensor,
                     target_probs: torch.Tensor,
                     bonus_token_ids: torch.Tensor,
                     uniform_rand: torch.Tensor,
                     uniform_probs: torch.Tensor,
                     max_spec_len: int,
                     high_acc: bool = True) -> torch.Tensor:
    """Speculative-decoding rejection sampling via ``tmo.rejection_sample``."""
    return tmo.rejection_sample(draft_token_ids,
                                num_draft_tokens,
                                cu_num_draft_tokens,
                                draft_probs,
                                target_probs,
                                bonus_token_ids,
                                uniform_rand,
                                uniform_probs,
                                max_spec_len,
                                high_acc)
|
|
|
|
def apply_topkp_v2(logits: torch.Tensor,
                   index_in: torch.Tensor,
                   temperature_list: torch.Tensor,
                   minp_list: torch.Tensor,
                   topk_list: torch.Tensor,
                   topp_list: torch.Tensor,
                   logits_out: Optional[torch.Tensor] = None,
                   sorted_logits_out: Optional[torch.Tensor] = None,
                   index_out: Optional[torch.Tensor] = None,
                   true_select_len: Optional[torch.Tensor] = None
                   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply per-request temperature/min-p/top-k/top-p filtering to logits."""
    return tmo.apply_topkp_v2(logits, index_in,
                              temperature_list, minp_list,
                              topk_list, topp_list,
                              logits_out, sorted_logits_out,
                              index_out, true_select_len)
|
|
|
|
|
|
def scaled_quantize(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    zero: Optional[torch.Tensor] = None,
    scale_ub: Optional[torch.Tensor] = None,
    quant_type: torch.dtype = torch.int8,
    quant_mode: str = "dynamic_per_token",
    act_mode: str = "none",
    active_coef: float = 1.0,
    is_gated: bool = False
) -> Tuple[torch.Tensor]:
    """
    Apply activation and quantization to the input tensor.

    Args:
        input (torch.Tensor): The tensor to be quantized, shape is (..., C); must be
            contiguous between dimensions 0 and -2.
        scale (Optional[torch.Tensor]): The scale multiplied into the input tensor.
            Shape is (C) or (1).
        zero (Optional[torch.Tensor]): Not supported, must pass None.
        scale_ub (Optional[torch.Tensor]): Upper bound on the output scale. Takes
            effect only if quant_type == torch.float8_e4m3fn and
            quant_mode == "dynamic_per_token".
        quant_type: Output data type, torch.int8 or torch.float8_e4m3fn.
            Defaults to torch.int8.
        quant_mode (str): One of "dynamic_per_token", "dynamic_per_tensor",
            "static_per_tensor", "static_per_channel".
        act_mode (str): Activation mode: "none", "gelu", "silu", or "swish".
        active_coef (float): Coefficient used by the swish activation. Default 1.0.
        is_gated (bool): Whether a gating mechanism is applied; only takes
            effect when act_mode is not "none".

    Returns:
        Tuple[torch.Tensor]: (output, output_scale) if quant_mode is dynamic,
        otherwise output only.
    """
    return tmo.scaled_quantize(input, scale, zero, scale_ub,
                               quant_type, quant_mode,
                               act_mode, active_coef, is_gated)
|
|
|
|
def scaled_matmul(a: torch.Tensor,
                  b: torch.Tensor,
                  a_scale: Optional[torch.Tensor],
                  b_scale: torch.Tensor,
                  output_dtype: torch.dtype,
                  bias: torch.Tensor = None,
                  c: torch.Tensor = None,
                  act_mode: str = "none",
                  quant_bit_size: int = 8,
                  alpha: float = 1.0,
                  beta: float = 1.0,
                  use_hp_active: bool = False,
                  a_quant_bit_size: int = 8,
                  a_calib: Optional[torch.Tensor] = None,
                  b_calib: Optional[torch.Tensor] = None,):
    """Quantized matrix multiplication of ``a`` and ``b`` via TMO.

    Args:
        a: Left operand of shape (M, K).
        b: Right operand; shape (N, K) when quant_bit_size == 8,
            (N, K//2) when quant_bit_size == 4.
        a_scale: Optional scales for ``a``, shape (M).
        b_scale: Scales for ``b``. With groupwise quantization the shape must
            be (N, group_num) and the dtype must match ``a``; otherwise the
            shape must be (N) with float dtype.
        output_dtype: Output data type, torch.half or torch.bfloat16.
        bias: Optional bias of shape (N).
        c: Optional addend of shape (M, N).
        act_mode: Activation algorithm, one of 'silu', 'gelu' or 'none'.
            Must be 'none' when groupwise quantization is used.
        quant_bit_size: Bit width of ``b``. Defaults to 8.
        alpha: Coefficient applied to the activated product. Defaults to 1.0.
        beta: Coefficient applied to ``c``. Defaults to 1.0.
        use_hp_active: When True, use the high-precision activation
            algorithm; otherwise the fastest one. Defaults to False.
        a_quant_bit_size: Bit width of ``a``. Defaults to 8.
        a_calib: Optional calibration tensor for ``a``, shape (M, 2).
        b_calib: Optional calibration tensor for ``b``, shape (M, 2).

    Type:
        a: int8, half, bfloat16, float8_e4m3fn, int4X2
        a_scale: float
        b: int8, float8_e4m3fn, int4X2
        b_scale: float, half, bfloat16
        bias / c: half, float, bfloat16
        output: half, bfloat16
        a_calib / b_calib: float

    Returns:
        A tensor of shape (M, N).
    """
    forwarded = (a, b, a_scale, b_scale, output_dtype, bias, c, act_mode,
                 quant_bit_size, alpha, beta, use_hp_active,
                 a_quant_bit_size, a_calib, b_calib)
    return tmo.scaled_matmul(*forwarded)
def fused_mla_kv(kv: torch.Tensor,
                 sin: torch.Tensor,
                 cos: torch.Tensor,
                 position_id: torch.Tensor,
                 gamma: torch.Tensor,
                 kv_cache: torch.Tensor,
                 kv_cache_scale: Optional[torch.Tensor],
                 slot_mapping: Optional[torch.Tensor],
                 cache_bs_id: Optional[torch.Tensor] = None,
                 cache_seq_offset: Optional[torch.Tensor] = None,
                 is_paged_cache: bool = True,
                 eps: float = 1e-5,
                 interleaved: bool = True):
    """Fused MLA KV preprocessing and KV-cache write via ``tmo.fused_mla_kv``.

    The quantization mode is chosen from ``kv_cache_scale``: when a scale
    tensor is supplied, "dynamic_per_token" is used; when it is None,
    "static_per_channel" is used.
    """
    if kv_cache_scale is None:
        mode = "static_per_channel"
    else:
        mode = "dynamic_per_token"
    return tmo.fused_mla_kv(
        kv, sin, cos, position_id, gamma, kv_cache, kv_cache_scale,
        slot_mapping, cache_bs_id, cache_seq_offset,
        quant_mode=mode,
        is_paged_cache=is_paged_cache,
        eps=eps,
        interleaved=interleaved,
    )
def fused_mla_q(q: torch.Tensor,
                gamma: torch.Tensor,
                smooth_quant_scale: torch.Tensor,
                weight_b: torch.Tensor,
                weight_b_scale: torch.Tensor,
                weight_c: torch.Tensor,
                sin: torch.Tensor,
                cos: torch.Tensor,
                position_id: torch.Tensor,
                output: Optional[torch.Tensor] = None,
                eps: float = 1e-6,
                interleaved: bool = True,
                output_quant_mode: str = 'none',
                output_scale: Optional[torch.Tensor] = None,
                output_norm: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Fused MLA query preprocessing via ``tmo.fused_mla_q``.

    ``store_norm`` is derived from ``output_norm``: the norm is stored only
    when a destination tensor for it is provided.
    """
    want_norm = output_norm is not None
    return tmo.fused_mla_q(
        q, gamma, smooth_quant_scale, weight_b, weight_b_scale, weight_c,
        sin, cos, position_id, output, eps, interleaved,
        output_quant_mode, output_scale,
        store_norm=want_norm,
        output_norm=output_norm,
    )
def gather_cache(
    kv_cache: List[torch.Tensor],  # [[1, num_blocks, num_kv_heads, block_size, head_size]
                                   #  [1, num_blocks, num_kv_heads, block_size] if kv_cache_dtype=int8]
    dst: torch.Tensor,             # [tot_tokens, entrys...]
    block_table: torch.Tensor,     # [batch, block_indices]
    cu_seq_lens: torch.Tensor,     # [batch+1]
    batch_size: int,
    seq_starts: torch.Tensor = None,  # Optional: [batch]
    kv_cache_dtype: str = 'auto',
) -> None:
    """Gather per-sequence tokens from the paged KV cache into ``dst``.

    Token positions are resolved through ``block_table`` using the cumulative
    lengths in ``cu_seq_lens``; when the cache is quantized (kv_cache_dtype
    != 'auto') the gathered values are dequantized with the per-token scales
    stored in ``kv_cache[1]``.

    Args:
        kv_cache: List holding the cache tensor (and, when quantized, its
            scale tensor); see the inline shape comments above.
        dst: Destination tensor of shape [tot_tokens, entrys...], filled
            in place.
        block_table: int32 tensor [batch, block_indices] mapping sequences
            to physical blocks.
        cu_seq_lens: int32 tensor [batch+1] of cumulative sequence lengths.
        batch_size: Number of sequences in the batch.
        seq_starts: Optional int32 tensor [batch] of per-sequence start
            offsets (in tokens) used to skip leading blocks.
        kv_cache_dtype: 'auto' for an unquantized cache; anything else
            selects the dequantizing path.
    """
    assert len(kv_cache) > 0 and kv_cache[0].numel() > 0, "kv cache can't be empty in gather_cache"
    src_cache = kv_cache[0][0]

    # Sanity-check devices and dtypes before touching any data.
    assert src_cache.device == dst.device == block_table.device == cu_seq_lens.device, \
        "All tensors must be on the same device"
    assert block_table.dtype == torch.int32, "block_table must be int32"
    assert cu_seq_lens.dtype == torch.int32, "cu_seq_lens must be int32"
    quant_kv_cache = kv_cache_dtype != 'auto'
    if not quant_kv_cache:
        assert src_cache.dtype == dst.dtype, "src_cache and dst must have the same dtype when no quantized"
    if seq_starts is not None:
        assert seq_starts.dtype == torch.int32, "seq_starts must be int32"
        assert seq_starts.device == src_cache.device, "seq_starts must be on the same device"

    num_blocks, num_kv_heads, block_size, head_size = src_cache.shape
    # MLA during decode degenerates to MQA with a single KV head, so the
    # head dimension can be folded away.
    assert num_kv_heads == 1, "mla force num_kv_heads to 1"
    per_block = src_cache.view(num_blocks, block_size, -1)
    entry_shape = per_block.shape[2:]

    tot_tokens = cu_seq_lens[-1]
    assert tot_tokens > 0, "tot_tokens should > 0"
    assert tot_tokens <= dst.shape[0], "tot_tokens should <= dst.shape[0]"
    assert cu_seq_lens.size(0) == batch_size + 1, "cu_seq_lens must have batch_size + 1 elements"

    lens = cu_seq_lens[1:] - cu_seq_lens[:-1]                 # [BATCH]
    blocks_needed = (lens + block_size - 1) // block_size     # ceil_div
    if seq_starts is not None:
        start_blocks = seq_starts // block_size
    else:
        start_blocks = torch.zeros(batch_size, dtype=torch.int32, device=src_cache.device)

    # Token-major flat view: index = block_id * block_size + offset_in_block.
    flat_src = per_block.view(num_blocks * block_size, *entry_shape)

    per_seq_indices = []
    for i in range(batch_size):
        n = lens[i]
        if n <= 0:
            continue
        blk_ids = block_table[i, start_blocks[i]:start_blocks[i] + blocks_needed[i]]
        tok = torch.arange(n, device=src_cache.device)
        per_seq_indices.append(blk_ids[tok // block_size] * block_size + tok % block_size)

    gather_idx = torch.cat(per_seq_indices)
    gathered = flat_src[gather_idx]
    if quant_kv_cache:
        # Dequantize with the per-token scales stored alongside the cache.
        scale_flat = kv_cache[1][0].view(num_blocks * block_size)
        gathered = gathered * scale_flat[gather_idx].unsqueeze(-1)

    dst[:tot_tokens].view(-1, *entry_shape).copy_(gathered.view(tot_tokens, *entry_shape))
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    """Merge two partial attention results into ``output`` (in place).

    Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005: the
    prefix and suffix partial outputs are combined with softmax-style weights
    derived from their log-sum-exp values.

    Args:
        output: Destination of shape [num_tokens, num_query_heads, head_size].
        prefix_output: Prefix partial output, same shape as ``output``.
        prefix_lse: Prefix log-sum-exp, shape [num_query_heads, num_tokens].
        suffix_output: Suffix partial output, same shape as ``output``.
        suffix_lse: Suffix log-sum-exp, same shape as ``prefix_lse``.
        output_lse: Optional destination for the merged log-sum-exp, same
            shape as ``prefix_lse``.
    """
    assert output.shape == prefix_output.shape == suffix_output.shape, \
        "Output and input tensors must have the same shape"
    assert prefix_lse.shape == suffix_lse.shape, \
        "Prefix and suffix LSE tensors must have the same shape"
    if output_lse is not None:
        assert output_lse.shape == prefix_lse.shape, \
            "Output LSE must have the same shape as input LSE tensors"

    # Map +inf LSE entries to -inf so an invalid partial contributes zero
    # weight instead of poisoning the merge.
    pos_inf = float('inf')
    neg_inf = float('-inf')
    lse_a = prefix_lse.masked_fill(prefix_lse == pos_inf, neg_inf)
    lse_b = suffix_lse.masked_fill(suffix_lse == pos_inf, neg_inf)

    # Subtract the running max before exponentiating for numerical stability.
    peak = torch.maximum(lse_a, lse_b)              # [num_query_heads, num_tokens]
    lse_a = lse_a - peak
    lse_b = lse_b - peak
    denom = torch.exp(lse_a) + torch.exp(lse_b)     # [num_query_heads, num_tokens]

    if output_lse is not None:
        output_lse.copy_(torch.log(denom) + peak)

    # Per-head, per-token blend weights, broadcast over head_size.
    w_a = (torch.exp(lse_a) / denom).unsqueeze(-1)  # [num_query_heads, num_tokens, 1]
    w_b = (torch.exp(lse_b) / denom).unsqueeze(-1)

    # Work in head-major layout so the weights broadcast, then restore.
    merged = (prefix_output.permute(1, 0, 2) * w_a
              + suffix_output.permute(1, 0, 2) * w_b)
    output.copy_(merged.permute(1, 0, 2))
def moe_all2all_create(dispatch_token_byte: int,
                       combine_token_byte: int,
                       max_expert_num: int,
                       max_token_num: int,
                       rank: int,
                       nrank: int) -> Tuple[int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create the CNCLEP handle and buffers for MOE all-to-all communication.

    Expected API call order (steps marked "once" are done a single time):
      1. moe_all2all_create(...)   -- once: obtain the handle and buffers.
      2. All-gather ``exchange_info`` across the ``nrank`` processes into
         ``all_exchange_info``     -- once.
      3. torch.distributed.barrier() -- once: ensure step 2 has finished.
      4. moe_all2all_init(...)     -- once: install ``all_exchange_info``.
      5. moe_all2all_dispatch(...) -- route tokens to their experts.
      6. moe_all2all_combine(...)  -- restore tokens to their origins.
      7. moe_all2all_destroy(...)  -- once: release the CNCLEP handle.

    Args:
        dispatch_token_byte: Byte size of one token in the dispatch phase.
        combine_token_byte: Byte size of one token in the combine phase.
        max_expert_num: Maximum number of participating experts.
        max_token_num: Maximum number of tokens to be processed.
        rank: Rank id of this process, in [0, nrank-1].
        nrank: Total number of processes in the distributed group.

    Returns:
        A tuple (handle, exchange_info_size, exchange_info, dispatch_send,
        dispatch_recv, combine_send, combine_recv):
          handle: Integer CNCLEP handle.
          exchange_info_size: Size of ``exchange_info``.
          exchange_info: int8 CPU tensor of shape [exchange_info_size].
          dispatch_send: int8 MLU tensor [max_token_num * dispatch_token_byte].
          dispatch_recv: int8 MLU tensor [nrank * max_token_num * dispatch_token_byte].
          combine_send: int8 MLU tensor [max_token_num * combine_token_byte].
          combine_recv: int8 MLU tensor [nrank * max_token_num * combine_token_byte].
    """
    forwarded = (dispatch_token_byte, combine_token_byte, max_expert_num,
                 max_token_num, rank, nrank)
    return tmo.moe_all2all_create(*forwarded)
def moe_all2all_init(handle: int,
                     all_exchange_info: torch.Tensor) -> None:
    """Install the gathered exchange info into the CNCLEP handle.

    Must run once, after ``all_exchange_info`` has been all-gathered from
    every rank (see ``moe_all2all_create`` for the full call sequence).
    """
    tmo.moe_all2all_init(handle, all_exchange_info)
def moe_all2all_destroy(handle: int) -> None:
    """Release the CNCLEP handle created by ``moe_all2all_create``."""
    tmo.moe_all2all_destroy(handle)
def moe_all2all_dispatch(handle: int,
                         token_byte: int,
                         token_num: int,
                         send_layout: torch.Tensor,
                         send_token_num: torch.Tensor,
                         recv_layout: torch.Tensor,
                         recv_token_num: torch.Tensor,
                         send_token: Optional[torch.Tensor] = None,
                         recv_token: Optional[torch.Tensor] = None,
                         ) -> None:
    """Route tokens to their designated experts through the CNCLEP handle.

    Thin pass-through to ``tmo.moe_all2all_dispatch``; all arguments are
    forwarded unchanged.
    """
    forwarded = (handle, token_byte, token_num, send_layout, send_token_num,
                 recv_layout, recv_token_num, send_token, recv_token)
    tmo.moe_all2all_dispatch(*forwarded)
def moe_all2all_combine(handle: int,
                        token_byte: int,
                        token_num: int,
                        send_src_layout: torch.Tensor,
                        send_dst_layout: torch.Tensor,
                        send_token: Optional[torch.Tensor] = None,
                        recv_token: Optional[torch.Tensor] = None,
                        ) -> None:
    """Restore dispatched tokens to their original locations.

    Thin pass-through to ``tmo.moe_all2all_combine``; all arguments are
    forwarded unchanged.
    """
    forwarded = (handle, token_byte, token_num, send_src_layout,
                 send_dst_layout, send_token, recv_token)
    tmo.moe_all2all_combine(*forwarded)
def gather_split(input: torch.Tensor,
                 gather_index: torch.Tensor,
                 valid_token_num: torch.Tensor,
                 output1: torch.Tensor,
                 output2: Optional[torch.Tensor] = None) -> None:
    """Gather rows of ``input`` by ``gather_index`` into the output tensor(s).

    Thin pass-through to ``tmo.gather_split``; no value is returned, so the
    results are expected in the provided output tensors.
    """
    forwarded = (input, gather_index, valid_token_num, output1, output2)
    tmo.gather_split(*forwarded)
def moe_all2all_gen_send_layout(token_count: torch.Tensor,
                                nrank: int) -> torch.Tensor:
    """Build the send layout tensor for the MOE all-to-all dispatch phase."""
    layout = tmo.moe_all2all_gen_send_layout(token_count, nrank)
    return layout
def moe_all2all_gen_gather_index(token_num: torch.Tensor, pad_num: int,
                                 return_cusum_token_count: bool = False):
    """Generate gather indices for the MOE all-to-all exchange.

    Returns (gather_by_expert_index, gather_by_rank_index, token_count,
    token_sum), extended with cusum_token_count when
    ``return_cusum_token_count`` is True.
    """
    if not return_cusum_token_count:
        expert_idx, rank_idx, counts, total = \
            tmo.moe_all2all_gen_gather_index(token_num, pad_num)
        return expert_idx, rank_idx, counts, total
    expert_idx, rank_idx, counts, total, cusum = \
        tmo.moe_all2all_gen_gather_index(token_num, pad_num,
                                         return_cusum_token_count=True)
    return expert_idx, rank_idx, counts, total, cusum
def reshape_from_cache(
    key: torch.Tensor,
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: Optional[torch.Tensor],
    context_lengths: torch.Tensor,
    max_context_len: int,
    context_seq_offset: Optional[torch.Tensor] = None,
    block_tables: Optional[torch.Tensor] = None,
    cache_seq_offset: Optional[torch.Tensor] = None,
) -> None:
    """Read K/V entries back out of the (optionally paged) KV cache.

    Thin pass-through to ``tmo.reshape_from_cache``; all arguments are
    forwarded unchanged and results are expected in the provided tensors.
    """
    params = {
        'key': key,
        'value': value,
        'key_cache': key_cache,
        'value_cache': value_cache,
        'context_lengths': context_lengths,
        'max_context_len': max_context_len,
        'context_seq_offset': context_seq_offset,
        'block_tables': block_tables,
        'cache_seq_offset': cache_seq_offset,
    }
    tmo.reshape_from_cache(**params)
def masked_indexer_select_paged_kv(query: torch.Tensor,
                                   k_cache: torch.Tensor,
                                   weights: torch.Tensor,
                                   kv_cache_block_table: torch.Tensor,
                                   cu_seq_q_lens: Optional[torch.Tensor],
                                   cu_seq_k_lens: Optional[torch.Tensor],
                                   k_context_lens: Optional[torch.Tensor],
                                   k_cache_block_table: Optional[torch.Tensor],
                                   is_prefill: bool,
                                   index_topk: int,
                                   kv_cache_block_size: int,
                                   softmax_scale: float,
                                   q_scale: Optional[torch.Tensor] = None,
                                   k_scale_cache: Optional[torch.Tensor] = None,
                                   sparse_block_table: Optional[torch.Tensor] = None,
                                   sparse_context_lens: Optional[torch.Tensor] = None):
    """Top-k indexer selection over a paged KV cache (prefill or decode).

    Thin wrapper over ``tmo.masked_indexer_select_paged_kv``. The cumulative
    length tensors ``cu_seq_q_lens``/``cu_seq_k_lens`` drive the prefill path
    while ``k_context_lens``/``k_cache_block_table`` drive the decode path;
    ``is_prefill`` selects between them.

    Returns:
        Whatever ``tmo.masked_indexer_select_paged_kv`` returns. The result
        was previously discarded, unlike in the sibling ``_prefill``/``_decode``
        wrappers; it is now propagated for consistency (backward-compatible:
        callers that ignored the old ``None`` return are unaffected).
    """
    return tmo.masked_indexer_select_paged_kv(query=query,
                                              k_cache=k_cache,
                                              weights=weights,
                                              kv_cache_block_table=kv_cache_block_table,
                                              cu_seq_q_lens=cu_seq_q_lens,
                                              cu_seq_k_lens=cu_seq_k_lens,
                                              k_context_lens=k_context_lens,
                                              k_cache_block_table=k_cache_block_table,
                                              is_prefill=is_prefill,
                                              index_topk=index_topk,
                                              kv_cache_block_size=kv_cache_block_size,
                                              softmax_scale=softmax_scale,
                                              q_scale=q_scale,
                                              k_scale_cache=k_scale_cache,
                                              sparse_block_table=sparse_block_table,
                                              sparse_context_lens=sparse_context_lens)
def masked_indexer_select_paged_kv_prefill(
    query: torch.Tensor,
    key_value: torch.Tensor,
    weights: torch.Tensor,
    kv_cache_block_table: torch.Tensor,
    cu_seq_q_lens: torch.Tensor,
    cu_seq_k_lens: torch.Tensor,
    index_topk: int,
    kv_cache_block_size: int,
    softmax_scale: float,
    q_scale: Optional[torch.Tensor] = None,
    k_scale_cache: Optional[torch.Tensor] = None,
    sparse_block_table: Optional[torch.Tensor] = None,
    sparse_context_lens: Optional[torch.Tensor] = None,
    kv_cache_block_table_offset: Optional[torch.Tensor] = None,
    compress_ratio: int = 1,
):
    """Prefill-path top-k indexer selection over a paged KV cache.

    Calls ``tmo.masked_indexer_select_paged_kv`` with ``is_prefill=True``:
    cumulative query/key lengths describe the packed sequences, and the
    decode-only arguments (``k_context_lens``/``k_cache_block_table``) are
    left unset.
    """
    params = {
        'query': query,
        'k_cache': key_value,
        'weights': weights,
        'kv_cache_block_table': kv_cache_block_table,
        'cu_seq_q_lens': cu_seq_q_lens,
        'cu_seq_k_lens': cu_seq_k_lens,
        'k_context_lens': None,
        'k_cache_block_table': None,
        'is_prefill': True,
        'index_topk': index_topk,
        'kv_cache_block_size': kv_cache_block_size,
        'softmax_scale': softmax_scale,
        'q_scale': q_scale,
        'k_scale_cache': k_scale_cache,
        'sparse_block_table': sparse_block_table,
        'sparse_context_lens': sparse_context_lens,
        'kv_cache_block_table_offset': kv_cache_block_table_offset,
        'compress_ratio': compress_ratio,
        'is_score_float': True,
    }
    return tmo.masked_indexer_select_paged_kv(**params)
def masked_indexer_select_paged_kv_decode(
    query: torch.Tensor,
    k_cache: torch.Tensor,
    weights: torch.Tensor,
    kv_cache_block_table: torch.Tensor,
    k_context_lens: Optional[torch.Tensor],
    k_cache_block_table: Optional[torch.Tensor],
    index_topk: int,
    kv_cache_block_size: int,
    softmax_scale: float,
    q_scale: Optional[torch.Tensor] = None,
    k_scale_cache: Optional[torch.Tensor] = None,
    sparse_block_table: Optional[torch.Tensor] = None,
    sparse_context_lens: Optional[torch.Tensor] = None,
    kv_cache_block_table_offset: Optional[torch.Tensor] = None,
    compress_ratio: int = 1,
):
    """Decode-path top-k indexer selection over a paged KV cache.

    Calls ``tmo.masked_indexer_select_paged_kv`` with ``is_prefill=False``:
    per-sequence context lengths drive the selection instead of the
    cumulative query/key length tensors used on the prefill path.

    Fix: removed the unused ``query_len`` local and the stale commented-out
    ``k_context_lens`` rescaling line.
    """
    return tmo.masked_indexer_select_paged_kv(
        query=query,
        k_cache=k_cache,
        weights=weights,
        kv_cache_block_table=kv_cache_block_table,
        cu_seq_q_lens=None,
        cu_seq_k_lens=None,
        k_context_lens=k_context_lens,
        k_cache_block_table=k_cache_block_table,
        is_prefill=False,
        index_topk=index_topk,
        kv_cache_block_size=kv_cache_block_size,
        softmax_scale=softmax_scale,
        q_scale=q_scale,
        k_scale_cache=k_scale_cache,
        sparse_block_table=sparse_block_table,
        sparse_context_lens=sparse_context_lens,
        kv_cache_block_table_offset=kv_cache_block_table_offset,
        compress_ratio=compress_ratio,
        is_score_float=True,
    )
def concat_block_table(
    first_block_table: torch.Tensor,
    first_context_lens: torch.Tensor,
    second_block_table: torch.Tensor,
    second_context_lens: torch.Tensor,
    new_block_table: Optional[torch.Tensor] = None,
    new_context_lens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Concatenate two block tables row-wise (per sequence) via TMO.

    Math (per sequence i):
        new_context_lens = first_context_lens + second_context_lens
        new_block_table[i, :first_context_lens[i]] = first_block_table[i, :first_context_lens[i]]
        new_block_table[i, first_context_lens[i]:first_context_lens[i]+second_context_lens[i]] = \
            second_block_table[i, :second_context_lens[i]]

    Args:
        first_block_table: First table, shape [total_seq, first_max_blkn].
        first_context_lens: Context lens of the first table, shape [total_seq].
        second_block_table: Second table, shape [total_seq, second_max_blkn].
        second_context_lens: Context lens of the second table, shape [total_seq].
        new_block_table: Optional preallocated result table of shape
            [total_seq, max_new_block_number]; when given, its width must be
            large enough for the concatenated rows. Default: None.
        new_context_lens: Optional preallocated result lens of shape
            [total_seq]. Default: None.

    Returns:
        (new_block_table, new_context_lens), where new_context_lens equals
        first_context_lens + second_context_lens.

    Type:
        INT32
    """
    forwarded = (first_block_table, first_context_lens,
                 second_block_table, second_context_lens,
                 new_block_table, new_context_lens)
    return tmo.concat_block_table(*forwarded)
def fused_mhc_post(
    x: torch.Tensor,           # (N, D) float|bf16
    residual: torch.Tensor,    # (N, HC, D) float|bf16
    post: torch.Tensor,        # (N, HC), always float
    comb: torch.Tensor,        # (N, HC, HC), always float
    compute_rms: bool,
    eps: float,
    output: torch.Tensor = None,  # (N, HC, D), same dtype as input
    output_rms = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused multi-head-combine post step via ``tmo.fused_mhc_post``.

    Math:
        output = post * x + (comb * residual).sum(dim=1)
        output_rms = rsqrt(x.square().mean(dim=-1))

    Args:
        x: Input of shape [N, D].
        residual: Residual stack of shape [N, HC, D].
        post: Per-head post scales of shape [N, HC].
        comb: Combination weights of shape [N, HC, HC].
        compute_rms: Whether to also compute output_rms.
        eps: Normalization epsilon.
        output: Optional preallocated output of shape [N, HC, D].
        output_rms: Optional preallocated rms output of shape [N].

    Returns:
        (output, output_rms); output_rms is None when compute_rms is False.

    Limitation:
        D must be 4096 and HC must be 4.
    """
    result = tmo.fused_mhc_post(x, residual, post, comb, compute_rms, eps,
                                output, output_rms)
    # NOTE(review): when compute_rms is True the TMO op appears to yield the
    # (output, output_rms) pair itself -- confirm against the TMO signature.
    if compute_rms:
        return result
    return result, None
def fused_compress_multi_kv(kv: torch.Tensor,           # (BS, D) float|bf16
                            score: torch.Tensor,        # (BS, D) float|bf16
                            kv_state: torch.Tensor,     # (max_B, coff * R, D) float
                            score_state: torch.Tensor,  # (max_B, coff * R, D) float
                            batch_ids: torch.Tensor,    # (B,) int32
                            cu_seqlens: torch.Tensor,   # (B,) int32
                            ape: torch.Tensor,          # (R, D) float
                            max_seqlen:int,
                            overlap: bool,
                            compressed_kv: torch.Tensor  # (BS, head_dim) float|bf16
                            ):
    """Compress variable-length multi-token KV/score via ``tmo.fused_compress_multi_kv``.

    No value is returned; results are expected in the provided state/output
    tensors (kv_state, score_state, compressed_kv).
    """
    params = {
        'kv': kv,
        'score': score,
        'kv_state': kv_state,
        'score_state': score_state,
        'cu_seqlens': cu_seqlens,
        'batch_ids': batch_ids,
        'ape': ape,
        'max_seqlen': max_seqlen,
        'overlap': overlap,
        'compressed_kv': compressed_kv,
    }
    tmo.fused_compress_multi_kv(**params)
def fused_compress_single_kv(
    kv: torch.Tensor,          # (T, D) float|bf16
    score: torch.Tensor,       # (T, D) float|bf16
    position: torch.Tensor,    # (B,) int32
    ape: torch.Tensor,         # (ratio, D) float|bf16
    kv_state: torch.Tensor,    # (B, R, D) float|bf16
    score_state: torch.Tensor, # (B, R, D) float|bf16
    gamma: torch.Tensor,       # (d)
    sin: torch.Tensor,         # (-1, rope_dim)
    cos: torch.Tensor,         # (-1, rope_dim)
    hadamard_matrix: Optional[torch.Tensor],  # (d, d)
    slot_mapping: torch.Tensor,               # (B,) int32
    kv_cache: torch.Tensor,                   # (-1, BLKS, head_dim) bf16|int8|fp8
    kv_cache_scale: Optional[torch.Tensor],   # (-1, BLKS) float
    eps: float,
    overlap: bool,
    rotate: bool,
    state_idx: torch.Tensor,
    cu_query_len: torch.Tensor | None = None,
):
    """Compress single-step KV into the compression states and KV cache.

    Args:
        kv: Token-major KV tensor of shape (token_num, D); unsqueezed to
            (token_num, 1, D) before the TMO call.
        score: Score tensor, same shape handling as ``kv``.
        position: Per-sequence positions, shape (B,), int32.
        ape: Absolute positional embedding table, shape (ratio, D).
        kv_state: Compression KV state, shape (max_B, R, D); updated in place.
        score_state: Compression score state, shape (max_B, R, D); updated in place.
        gamma: Normalization weight of shape (head_dim).
        sin: Rotary sine table of shape (table_len, rope_dim).
        cos: Rotary cosine table of shape (table_len, rope_dim).
        hadamard_matrix: Optional (head_dim, head_dim) rotation matrix;
            only forwarded when ``rotate`` is True.
        slot_mapping: Cache slots per sequence, shape (B,), int32.
        kv_cache: Paged cache of shape (cache_len, block_size, head_dim); a
            4-D (paged_num, 1, block_size, head_dim) cache is squeezed first.
        kv_cache_scale: Optional per-token cache scales, shape (cache_len, block_size).
        eps: The eps of normalization.
        overlap: Whether to overlap.
        rotate: Whether to apply the Hadamard rotation.
        state_idx: Per-sequence state indices, forwarded as ``state_ids``.
        cu_query_len: Currently unused (see note below).

    Type:
        kv/score/gamma/sin/cos/hadamard_matrix: BF16, FP32 (all same as kv)
        position/slot_mapping: INT32
        ape/kv_state/score_state/kv_cache_scale: FP32
        kv_cache: BF16, FP32

    Returns:
        Whatever ``tmo.fused_compress_single_kv`` returns; outputs are
        in place: kv_state, score_state, kv_cache, kv_cache_scale.

    Note:
        coff = overlap + 1; D = coff * head_dim; R = coff * ratio.

    Fix: removed the unused locals ``token_num``/``coff_dim``/``bsz``.
    """
    # TODO: force user_tmo = 0 after supporting mtp.
    # NOTE(review): `cu_query_len` is accepted but not forwarded to the TMO
    # op -- confirm whether it is still required (pending MTP support).
    kv = kv.unsqueeze(1)
    score = score.unsqueeze(1)
    if kv_cache.dim() == 4:
        # A 4-D paged cache carries a singleton head dim; fold it away.
        paged_num, head_num, block_size, head_dim = kv_cache.shape
        assert head_num == 1
        kv_cache = kv_cache.view(paged_num, block_size, head_dim)
    return tmo.fused_compress_single_kv(
        kv=kv,
        score=score,
        position=position,
        state_ids=state_idx,
        ape=ape,
        kv_state=kv_state,
        score_state=score_state,
        gamma=gamma,
        sin=sin,
        cos=cos,
        hadamard_matrix=hadamard_matrix if rotate else None,
        slot_mapping=slot_mapping,
        kv_cache=kv_cache,
        kv_cache_scale=kv_cache_scale,
        eps=eps,
        overlap=overlap,
    )
def convertBlockTable(block_table, blks, incseq):
    """Expand a block table when each logical block spans ``blks`` physical blocks.

    Every entry ``b`` is replaced by the run ``b * blks + incseq`` (with
    ``incseq`` expected to hold the ``blks`` per-slot offsets), and the result
    is flattened to 1-D. With ``blks == 1`` the table is returned unchanged.
    """
    if blks == 1:
        return block_table
    tiled = block_table.unsqueeze(1).repeat(1, blks)
    return (tiled * blks + incseq).flatten()
def get_window_block_tables(window_size : int,
                            block_size : int,  # blocksize of block_table
                            seq_k_lens: torch.Tensor,
                            query_start_loc: torch.Tensor,
                            block_table: Optional[torch.Tensor]=None,          # shape (batch, max_blocks)
                            window_block_tables:Optional[torch.Tensor]=None,   # shape (total_seq, max_blocks)
                            window_context_lens:Optional[torch.Tensor]=None):  # shape (total_seq)
    """Derive sliding-window block tables from the full block table via TMO.

    No value is returned; results are expected in ``window_block_tables`` and
    ``window_context_lens``.
    """
    params = {
        'window_block_tables': window_block_tables,
        'window_context_lens': window_context_lens,
        'seq_k_lens': seq_k_lens,
        'query_start_loc': query_start_loc,
        'block_table': block_table,
        'block_size': block_size,
        'window_size': window_size,
    }
    tmo.get_window_block_tables(**params)
def get_compress_block_tables(ratio: int,
                              block_size: int,
                              seq_k_lens: torch.Tensor,          # k lens before compression, shape (batch)
                              query_start_loc: torch.Tensor,     # shape (batch+1)
                              offset: torch.Tensor,              # shape (batch)
                              block_table: torch.Tensor,         # shape (batch, max_blocks)
                              compress_block_tables: torch.Tensor,  # shape (total_seq, max_blocks)
                              compress_context_lens: torch.Tensor): # shape (total_seq)
    """Derive compressed-KV block tables from the full block table via TMO.

    No value is returned; results are expected in ``compress_block_tables``
    and ``compress_context_lens``.
    """
    params = {
        'compress_block_tables': compress_block_tables,
        'compress_context_lens': compress_context_lens,
        'seq_k_lens': seq_k_lens,
        'query_start_loc': query_start_loc,
        'offset': offset,
        'block_table': block_table,
        'block_size': block_size,
        'ratio': ratio,
    }
    tmo.get_compress_block_tables(**params)
def hc_split_sinkhorn(mixes: torch.Tensor,
                      hc_scale: torch.Tensor,
                      hc_base: torch.Tensor,
                      pre_scale: Optional[torch.Tensor] = None,
                      hc_mult: int = 4,
                      sinkhorn_iter: int = 20,
                      eps: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split the HC mixing tensor with Sinkhorn iterations via TMO.

    Thin pass-through to ``tmo.hc_split_sinkhorn``; all arguments are
    forwarded unchanged.
    """
    params = {
        'mixes': mixes,
        'hc_scale': hc_scale,
        'hc_base': hc_base,
        'pre_scale': pre_scale,
        'hc_mult': hc_mult,
        'sinkhorn_iter': sinkhorn_iter,
        'eps': eps,
    }
    return tmo.hc_split_sinkhorn(**params)
def fused_indexer_q(q: torch.Tensor,
                    w_q: torch.Tensor,
                    sin: torch.Tensor,
                    cos: torch.Tensor,
                    position_id: torch.Tensor,
                    output: Optional[torch.Tensor] = None,
                    hadamard_matrix: Optional[torch.Tensor] = None,
                    w_q_scale: Optional[torch.Tensor] = None,
                    output_quant_mode: str = 'none',
                    output_scale: Optional[torch.Tensor] = None,
                    interleaved: bool = True,
                    rope_at_front: bool = True):
    """Fused indexer query preprocessing via ``tmo.fused_indexer_q``.

    Thin pass-through; all arguments are forwarded unchanged.
    """
    params = {
        'q': q,
        'w_q': w_q,
        'sin': sin,
        'cos': cos,
        'position_id': position_id,
        'output': output,
        'hadamard_matrix': hadamard_matrix,
        'w_q_scale': w_q_scale,
        'output_quant_mode': output_quant_mode,
        'output_scale': output_scale,
        'interleaved': interleaved,
        'rope_at_front': rope_at_front,
    }
    return tmo.fused_indexer_q(**params)
def fused_mla_q_v2(
    input_q: torch.Tensor,
    gamma: torch.Tensor,
    smooth_quant_scale: Optional[torch.Tensor],
    weight_b: torch.Tensor,
    weight_b_scale: Optional[torch.Tensor],
    sin: torch.Tensor,
    cos: torch.Tensor,
    position_id: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    interleaved: bool = True,
    store_norm: bool = False,
    output_norm: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """
    This function applies MLA (Multi-head Latent Attention) v2 Query (Q) preprocessing.
    The fusion logic includes: RMSNorm -> Quant(Optional) -> MatMul -> RMSNorm -> RoPE.

    Math:
        qr = rmsnorm(input_q, gamma, eps)
        if quant:
            qr, q_scale = per_token_quant(qr, smooth_quant_scale)
        q = matmul(qr, q_scale, weight_b, weight_b_scale)
        q = q.reshape(batch, seq, n_local_heads, head_dim)
        # Per-head RMS normalization (the second RMSNorm of the fusion).
        q = q * rsqrt(q.square().mean(-1, keepdim=True) + eps)
        out = apply_rotary_embedding(q, sin, cos, position_id, interleaved)

    Args:
        input_q (torch.Tensor):
            The input latent query tensor. Shape is (batch, seq, q_lora_rank).
        gamma (torch.Tensor):
            The scaling parameter for the initial RMSNorm. Shape is (q_lora_rank).
        smooth_quant_scale (Optional[torch.Tensor]):
            Scale tensor for SmoothQuant migration. Can be None. Shape is (q_lora_rank).
        weight_b (torch.Tensor):
            The Q-projection weight tensor. Shape is (n_local_heads, head_dim, q_lora_rank).
        weight_b_scale (Optional[torch.Tensor]):
            The per-channel quantization scales for weight_b. Shape is (n_local_heads, head_dim).
        sin (torch.Tensor):
            Rotary embedding sine table. Shape is (max_rotary_seq_len, rotary_head_dim).
        cos (torch.Tensor):
            Rotary embedding cosine table. Shape is (max_rotary_seq_len, rotary_head_dim).
        position_id (torch.Tensor):
            Indices for the RoPE tables. Shape is (batch,).
        output (Optional[torch.Tensor]):
            Optional output tensor for the final processed Q. Shape is (batch, seq, n_local_heads, head_dim).
        eps (float):
            Small constant for RMSNorm numerical stability. Default: 1e-6.
        interleaved (bool):
            If True, apply interleaved rotary embedding, otherwise folded. Default: True.
        store_norm (bool):
            If True, the intermediate RMSNorm result (pre-MatMul) will be returned. Default: False.
        output_norm (Optional[torch.Tensor]):
            Optional tensor to store the intermediate RMSNorm result. Shape: (batch, seq, q_lora_rank).

    Type:
        input_q, gamma, sin, cos: bfloat16.
        weight_b: int8 in the quantized path, otherwise the same dtype as
            input_q. NOTE(review): the original doc read "int8, same as
            input_q", which is contradictory — confirm against the TMO kernel.
        weight_b_scale, smooth_quant_scale: float32.
        position_id: int32.
        output: same as input_q.

    Return:
        Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
            - If store_norm=False: output
            - If store_norm=True: (..., output_norm) is appended to the return.
    """
    # Thin pass-through: all computation happens inside the TMO fused kernel.
    return tmo.fused_mla_q_v2(
        input_q=input_q,
        gamma=gamma,
        smooth_quant_scale=smooth_quant_scale,
        weight_b=weight_b,
        weight_b_scale=weight_b_scale,
        sin=sin,
        cos=cos,
        position_id=position_id,
        output=output,
        eps=eps,
        interleaved=interleaved,
        store_norm=store_norm,
        output_norm=output_norm,
    )
def update_compressor_states(
    kv_state,  # (max_batch, (overlap+1)*ratio + K, dim)
    score_state,  # (max_batch, (overlap+1)*ratio + K, dim)
    accept_tokens: torch.Tensor,  # (bsz,)
    batch_to_kv_state: torch.Tensor,  # (bsz,)
    positions: torch.Tensor,  # (bsz,)
    cu_query_len: torch.Tensor,  # (bsz+1,)
    overlap: bool,
    K: int
):
    """Roll over per-sequence KV-compressor state after a decode step.

    For each sequence whose position advanced past a compression boundary,
    shift the still-pending (not-yet-compressed) tail of its state window
    from offset ``ratio`` down to offset 0, in place, in both ``kv_state``
    and ``score_state``.

    Args:
        kv_state: compressor KV buffer; second dim holds
            (overlap+1)*ratio + K entries per state slot.
        score_state: compressor score buffer with the same layout as
            ``kv_state`` (assumed — shapes are only given by the
            signature comments; TODO confirm against the caller).
        accept_tokens: number of tokens accepted this step per sequence.
        batch_to_kv_state: maps batch index -> state-slot index.
        positions: flat per-token positions; indexed by ``cu_query_len``
            to recover each sequence's first-token position.
        cu_query_len: cumulative query lengths (prefix sums), so
            ``cu_query_len[i]`` is the flat offset of sequence i's first
            query token.
        overlap: whether the compressor keeps an extra overlapping window
            of ``ratio`` entries.
        K: number of trailing entries in the state that are not part of
            the (overlap+1)*ratio window.
    """
    bsz = batch_to_kv_state.numel()
    # Recover the compression ratio from the state layout:
    # state dim-1 == (overlap+1)*ratio + K  (bool `overlap` counts as 0/1).
    ratio = (kv_state.size(1) - K) // (overlap + 1)
    # Position of each sequence's first query token before this step,
    # and its position after accepting this step's tokens.
    start_positions = positions[cu_query_len[:bsz]]
    end_positions = start_positions + accept_tokens

    for i in range(bsz):
        start_pos = start_positions[i]
        end_pos = end_positions[i]
        # Skip if sequence len does not exceed coff * ratio
        # (coff = 2 with overlap, 1 without — the window is not full yet).
        if (overlap and end_pos < 2 * ratio) or (not overlap and end_pos < ratio):
            continue

        # Skip if the compression condition is not met: no ratio boundary
        # was crossed this step, and we did not start exactly on one.
        if (start_pos // ratio) == (end_pos // ratio) and start_pos % ratio != 0:
            continue

        state_idx = batch_to_kv_state[i]

        # Length of the pending tail to keep at the front of the window.
        if overlap:
            length = end_pos - start_pos + start_pos % ratio
        else:
            length = end_pos % ratio
        # Pending entries live at [ratio, ratio + length) and are shifted
        # down to [0, length).
        start = ratio
        end = start + length

        if length == 0:
            continue

        # clone() guards against overlapping src/dst ranges in the in-place copy.
        kv_state[state_idx, :length] = kv_state[state_idx, start:end].clone()
        score_state[state_idx, :length] = score_state[state_idx, start:end].clone()