from __future__ import annotations
from typing import List, Optional, Tuple, Union
from enum import Enum
import torch
from torch import Generator
from torch._C._distributed_c10d import ReduceOp
from torch_vacc._vacc_libs import _torch_vacc
def rms_norm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, output: Optional[torch.Tensor] = None
) -> torch.Tensor:
    r"""Root mean square (RMS) normalization of the input over its last dimension.
    vacc fused kernel for rms_norm; supports a three-dimensional input and a
    single-dimension weight. Off-device inputs fall back to a reference implementation.
    Args:
        input: input tensor with shape (batch, seq_len, hidden_size)
        weight: weight tensor with shape (hidden_size,)
        eps (float): small value to avoid division by zero. Default: 1e-6
        output: optional preallocated output tensor (used only on the vacc path)
    Returns:
        Tensor: tensor after applying rms_norm
    """
# assert input.dim() == 3, "rms_norm only support the input with dim=3"
if input.device.type == "vacc":
return torch.ops.vacc.rms_norm_func(input, weight, eps, output)
input_dtype = input.dtype
input = input.to(torch.float32)
variance = input.pow(2).mean(-1, keepdim=True)
rsigma = torch.rsqrt(variance + eps)
normalized_states = input * rsigma
return weight * normalized_states.to(input_dtype)
rms_norm.apply = rms_norm
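# Illustrative sketch only (not part of the public API): exercises the non-vacc fallback
# path of rms_norm above and checks it against a hand-written RMSNorm.
def _example_rms_norm_reference() -> None:
    x = torch.randn(2, 4, 8)  # (batch, seq_len, hidden_size)
    w = torch.ones(8)  # (hidden_size,)
    out = rms_norm(x, w, eps=1e-6)
    ref = w * (x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6))
    assert torch.allclose(out, ref, atol=1e-5)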
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=x1.ndim - 1)
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
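# Illustrative sketch only (not part of the public API): shows how the two helpers above
# reorder the last dimension ('neox' rotates the two halves, 'gptj' rotates adjacent pairs).
def _example_rotation_layouts() -> None:
    x = torch.arange(4.0)  # [0., 1., 2., 3.]
    assert torch.equal(rotate_half(x), torch.tensor([-2.0, -3.0, 0.0, 1.0]))
    assert torch.equal(rotate_every_two(x), torch.tensor([-1.0, 0.0, -3.0, 2.0]))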
def RotaryPosEmbedding(
q: torch.Tensor,
k: torch.Tensor,
cos: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
offset: int = 0,
mode: str = "neox",
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply rotary positional embedding to the input tensors q/k in `sbhd` or s(b*h)d format, where
        s: sequence length
        b: batch size
        h: head num
        d: dim of each head
    Note: if cos/sin are None, vacc performs RoPE in the classical mode, generating cos/sin with base=10000
    Args:
        q (Tensor): Input tensor of shape [s, b*h, d]
        k (Tensor): Input tensor of shape [s, b*h, d]
            where s is sequence length; b is batch size; h is head num; d is head dim
        cos (optional Tensor): Cached cosine of the rotary positional embedding,
            with shape [s, 1, d] or [s, d] or [s, 1, d/2] or [s, d/2]
        sin (optional Tensor): Cached sine of the rotary positional embedding,
            with shape [s, 1, d] or [s, d] or [s, 1, d/2] or [s, d/2]
        offset (int): offset of the position in the cos/sin cache, default 0
        mode (str): 'neox': rotate half, 'gptj': rotate every two
    Returns:
        Tuple[Tensor, Tensor]: the q/k tensors after applying RoPE
    """
assert (
q.dim() == 3 and k.dim() == 3
), f"the dim of q/k should be 3 but get q:{q.dim()}, k:{k.dim()}"
assert mode in ["neox", "gptj"], "only support rope mode 'neox' or 'gptj'"
assert q.dtype == k.dtype, "the dtype should be same"
    assert (cos is None and sin is None) or (
        isinstance(cos, torch.Tensor) and isinstance(sin, torch.Tensor)
    ), "cos and sin must both be None or both be tensors"
assert q.device.type == "vacc" and k.device.type == "vacc"
mode_ = 0 if mode == "gptj" else 1
if cos is not None and cos.size(0) != q.size(0):
cos = cos[offset: q.shape[0] + offset, ...]
if sin is not None and sin.size(0) != q.size(0):
sin = sin[offset: q.shape[0] + offset, ...]
# repeat last dim as same size with input tensor
# if cos is not None and cos.numel() != 0:
# assert cos.size(-1) == q.size(-1) or cos.size(-1) * 2 == q.size(-1)
# assert cos.dim() == 2 or cos.dim() == 3
# if cos.size(-1) * 2 == q.size(-1):
# if mode_ == 0:
# cos = cos.repeat_interleave(2, dim=-1)
# sin = sin.repeat_interleave(2, dim=-1)
# else:
# cos = torch.cat([cos, cos], dim=-1)
# sin = torch.cat([sin, sin], dim=-1)
if cos is not None and cos.numel() != 0:
assert cos.size(-1) == q.size(-1) or cos.size(-1) * 2 == q.size(-1)
assert cos.dim() == 2 or cos.dim() == 3
if cos.dim() == 2:
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
return torch.ops.vacc.RotaryPosEmbedding_func(q, k, cos, sin, offset, mode_)
RotaryPosEmbedding.apply = RotaryPosEmbedding
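# Usage sketch (comments only; the fused op expects q/k on a vacc device):
#   q = torch.randn(s, b * h, d, dtype=torch.float16, device="vacc")
#   k = torch.randn(s, b * h, d, dtype=torch.float16, device="vacc")
#   cos, sin = ...  # cached tables of shape [s, d] or [s, 1, d] (or the d/2 variants)
#   q_rot, k_rot = RotaryPosEmbedding(q, k, cos, sin, offset=0, mode="neox")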
def scaled_dot_product_attention(
# same with torch define
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.5,
is_causal: bool = False,
    scale: Optional[float] = None,
# extend
is_train: bool = True,
recompute: bool = False,
flash_attention: bool = False,
sm_scale: float = -1,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    r"""Apply the attention operation to a q tensor of fixed shape [sq, b*h, d] and
    k/v tensors of fixed shape [sk, b*h, d]; the output has shape [sq, b*h, d].
    Supports float16, bfloat16, float32.
        sq: sequence length of query
        sk: sequence length of key, value
        b: batch size
        h: head num
        d: dim of each head
    Args:
        query (Tensor): Input tensor of shape [sq, b*h, d]
        key (Tensor): Input tensor of shape [sk, b*h, d]
        value (Tensor): Input tensor of shape [sk, b*h, d]
        attn_mask (Tensor): boolean mask tensor of shape [1, sq, sk] or [sq, sk]
        dropout_p (float): the probability of dropout
        is_causal (bool): accelerates compute when the mask is causal
        scale (float): not used; the default is 1/sqrt(dim)
        is_train (bool): train mode or eval mode
        recompute (bool): whether to recompute to reduce memory usage; only valid
            when is_train=True
        flash_attention (bool): use flash attention, which supports long sequences
    Returns:
        Tensor: the tensor after self attention
    """
assert (
query.dtype == key.dtype and key.dtype == value.dtype
), "types of qkv should be same"
assert query.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    if attn_mask is None:
attn_mask = torch.Tensor()
out = torch.ops.vacc.scaled_dot_product_attention(
query,
key,
value,
attn_mask,
dropout_p,
is_train,
recompute,
is_causal,
flash_attention,
sm_scale,
)
return out[0]
scaled_dot_product_attention.apply = scaled_dot_product_attention
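# Illustrative sketch only (not part of the public API): a plain-torch reference for the
# math the fused kernel above is documented to perform (no mask, no dropout), handy for
# checking shapes; it does not call the vacc op.
def _example_sdpa_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: [sq, b*h, d], k/v: [sk, b*h, d] -> out: [sq, b*h, d]
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("qnd,knd->nqk", q, k) * scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("nqk,knd->qnd", probs, v)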
def swiglu(
x: torch.Tensor,
) -> torch.Tensor:
    r"""Performs:
        x = torch.chunk(x, 2, dim=-1)
        return F.silu(x[0]) * x[1]
    Args:
        x: Input tensor, supports float/float16/bfloat16, 3 dims
    Return:
        Tensor: the output of swiglu
    """
return torch.ops.vacc.swiglu(x)
swiglu.apply = swiglu
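# Illustrative sketch only (not part of the public API): the plain-torch equivalent of
# the fused swiglu above, matching the formula quoted in its docstring.
def _example_swiglu_reference(x: torch.Tensor) -> torch.Tensor:
    import torch.nn.functional as F
    x0, x1 = torch.chunk(x, 2, dim=-1)
    return F.silu(x0) * x1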
def scaled_dot_product_attention_cp_forward(
# same with torch define
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.5,
is_causal: bool = False,
# extend
is_train: bool = True,
):
    r"""Apply the attention operation to a q tensor of fixed shape [sq, b*h, d] and
    k/v tensors of fixed shape [sk, b*h, d]; the output has shape [sq, b*h, d].
        sq: sequence length of query
        sk: sequence length of key, value
        b: batch size
        h: head num
        d: dim of each head
    Args:
        query (Tensor): Input fp16/bf16 tensor of shape [sq, b*h, d]
        key (Tensor): Input fp16/bf16 tensor of shape [sk, b*h, d]
        value (Tensor): Input fp16/bf16 tensor of shape [sk, b*h, d]
        attn_mask (Tensor): boolean mask tensor of shape [1, sq, sk]
        dropout_p (float): the probability of dropout
        is_causal (bool): accelerates compute when the mask is causal
        is_train (bool): train mode or eval mode
    Returns: list of tensors, size=4
        [attention result,
         max of QK^T with shape (b*h, sq, d),
         sum of exp(QK^T) with shape (b*h, sq, d),
         seed (used in backward)]
    """
assert (
query.dtype == key.dtype and key.dtype == value.dtype
), "types of qkv should be same"
assert query.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    assert (
        is_causal or attn_mask is not None
    ), "attn_mask should be valid when is_causal is False"
if attn_mask is None:
attn_mask = torch.Tensor()
else:
assert attn_mask.size(-2) == query.size(0) and attn_mask.size(-1) == key.size(
0
), "attn_mask size should be (..., sq, sk)"
out = _torch_vacc.scaled_dot_product_attention_cp_forward(
query,
key,
value,
attn_mask,
dropout_p,
is_train,
is_causal,
)
return out
def scaled_dot_product_attention_cp_backward(
grad_output: torch.Tensor,
attn_out: torch.Tensor,
max_of_row: torch.Tensor,
sum_of_row: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
seed: torch.Tensor,
dropout_p: float = 0.5,
is_causal: bool = False,
is_train: bool = True,
):
assert (
query.dtype == key.dtype and key.dtype == value.dtype
), "types of qkv should be same"
assert query.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    assert (
        is_causal or attn_mask is not None
    ), "attn_mask should be valid when is_causal is False"
    if attn_mask is None:
        attn_mask = torch.Tensor()
else:
assert attn_mask.size(-2) == query.size(0) and attn_mask.size(-1) == key.size(
0
), "attn_mask size should be (..., sq, sk)"
out = _torch_vacc.scaled_dot_product_attention_cp_backward(
grad_output,
attn_out,
max_of_row,
sum_of_row,
query,
key,
value,
attn_mask,
dropout_p,
is_train,
is_causal,
seed,
)
return out
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_table: torch.Tensor,
seq_len: torch.Tensor,
sm_scale: float = -1,
out: Optional[torch.Tensor] = None,
):
    r"""Apply the attention operation to a q tensor of fixed shape [batch, h, d] and
    key_cache/value_cache of fixed shape [nb, bs, h, d]; the output has shape [batch, h, d].
        batch: token batch
        nb: num_blocks
        bs: block_size
        h: head num
        d: dim of each head
    Args:
        query (Tensor): Input fp16/bf16 tensor of shape [batch, h, d]
        key_cache (Tensor): Input fp16/bf16 tensor of shape [nb, bs, h, d]
        value_cache (Tensor): Input fp16/bf16 tensor of shape [nb, bs, h, d]
        block_table (Tensor): k/v block mapping of the cache
        seq_len (Tensor): sequence length of each batch
        out (optional Tensor): output tensor if given, otherwise a new tensor is returned
    Returns:
        Tensor: the tensor after paged attention
    """
return _torch_vacc.paged_attention(
query, key_cache, value_cache, block_table, seq_len, out, sm_scale
)
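# Usage sketch (comments only; all tensors live on a vacc device):
#   query:       [batch, h, d]    one decode token per sequence
#   key_cache:   [nb, bs, h, d]   nb = num_blocks, bs = block_size
#   value_cache: [nb, bs, h, d]
#   block_table: per-sequence mapping from logical to physical cache blocks
#   seq_len:     current length of each sequence
#   out = paged_attention(query, key_cache, value_cache, block_table, seq_len)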
def reshape_and_cache_attention(
src: torch.Tensor,
cached: torch.Tensor,
block_mapping: torch.Tensor,
):
torch.ops.vacc.reshape_and_cache_attention(src, cached, block_mapping)
def concat_and_cache_attention(
src: torch.Tensor,
src1: torch.Tensor,
cached: torch.Tensor,
block_mapping: torch.Tensor,
):
_torch_vacc.concat_and_cache_attention(src, src1, cached, block_mapping)
def w8a8_block_fp8_matmul(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: Optional[torch.Tensor],
weight_scale: Optional[torch.Tensor],
block_size: List[int],
output: Optional[torch.Tensor] = None,
**kwargs,
):
return _torch_vacc.w8a8_block_fp8_matmul(
input,
weight,
None,
weight_scale,
block_size,
output,
)
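# NOTE: w8a8_block_fp8_linear below is a thin wrapper over the same block-wise fp8
# matmul: it forwards the transposed weight and weight_scale and swaps the block_size
# order, presumably so callers can pass weights in linear (out_features, in_features) layout.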
def w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: Optional[torch.Tensor],
weight_scale: Optional[torch.Tensor],
block_size: List[int],
output: Optional[torch.Tensor] = None,
**kwargs,
):
return _torch_vacc.w8a8_block_fp8_matmul(
input, weight.T, None, weight_scale.T, [block_size[1], block_size[0]], output
)
def moe_expert_token_group_reassign(
topk_idx: torch.Tensor,
topk_val: torch.Tensor,
expert_num_: int,
gp_size_: int = 16,
gp_num_align_: int = 4,
):
return _torch_vacc.moe_expert_token_group_reassign(
topk_idx, topk_val, expert_num_, gp_size_, gp_num_align_
)
def fused_experts(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = True,
w13_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a13_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
decode_with_batch: bool = False,
output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
warning_message = "[fused_experts]:vacc only support fp8 weights now"
assert a13_scale is None and a2_scale is None, warning_message
assert use_fp8_w8a8, f"{warning_message}, but use_fp8_w8a8 is {use_fp8_w8a8}"
assert (
w13_scale is not None and w2_scale is not None
), f"{warning_message}, but w13_weight_scale is {w13_scale}, w2_weight_scale is {w2_scale}"
assert (
block_shape is not None
), f"{warning_message}, but block_shape is {block_shape}"
# assert (
# not decode_with_batch or hidden_states.size(0) <= 4
# ), "[fused_experts]:vacc only support batch <= 4 when decode"
if hidden_states.device.type == "vacc":
# topk weights dtype should be same with hidden_states
topk_weights = topk_weights.to(hidden_states.dtype)
# vacc device use int32 for experts_id
topk_ids = topk_ids.to(torch.int32)
hidden_dims, inter_dims = w13_weight.shape[1], w13_weight.shape[2]
hidden_blocks, inter_blocks = w13_scale.shape[1], w13_scale.shape[2]
block_size0, block_size1 = (
hidden_dims // hidden_blocks,
inter_dims // inter_blocks,
)
# assert (
# block_size0 == block_size1
# ), "quant block shape now support size0 == size1"
return _torch_vacc.fused_experts(
hidden_states,
w13_weight,
w2_weight,
topk_weights,
topk_ids,
use_fp8_w8a8,
w13_scale,
w2_scale,
a13_scale,
a2_scale,
[block_size0, block_size1],
decode_with_batch,
output_opt,
)
from .custom_ops_cpu import fused_experts as fused_experts_default_method
return fused_experts_default_method(
hidden_states,
w13_weight,
w2_weight,
topk_weights,
topk_ids,
use_fp8_w8a8,
w13_scale,
w2_scale,
a13_scale,
a2_scale,
block_shape,
decode_with_batch,
)
# NOTE: w13, w2 using linear format
def fused_mlp_fp8(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
use_fp8_w8a8: bool = True,
w13_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a13_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape_w13: Optional[List[int]] = None,
block_shape_w2: Optional[List[int]] = None,
output: Optional[torch.Tensor] = None,
):
assert a13_scale is None
assert a2_scale is None
assert w13_scale is not None
assert w2_scale is not None
assert block_shape_w13 is not None
assert block_shape_w2 is not None
if hidden_states.device.type == "vacc":
return _torch_vacc.fused_mlp(
hidden_states,
w13_weight,
w2_weight,
use_fp8_w8a8,
w13_scale,
w2_scale,
a13_scale,
a2_scale,
block_shape_w13,
block_shape_w2,
output,
)
from .custom_ops_cpu import fused_mlp_mm_fp8 as fused_mlp_mm_fp8_default_method
return fused_mlp_mm_fp8_default_method(
hidden_states,
w13_weight.T,
w2_weight.T,
use_fp8_w8a8,
w13_scale.T,
w2_scale.T,
a13_scale,
a2_scale,
list(block_shape_w13)[::-1],
list(block_shape_w2)[::-1],
)
def fused_mlp_mm_fp8(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
use_fp8_w8a8: bool = True,
w13_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a13_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape_w13: Optional[List[int]] = None,
block_shape_w2: Optional[List[int]] = None,
output: Optional[torch.Tensor] = None,
):
return fused_mlp_fp8(
hidden_states,
w13_weight.T,
w2_weight.T,
use_fp8_w8a8,
w13_scale.T,
w2_scale.T,
a13_scale,
a2_scale,
list(block_shape_w13)[::-1],
list(block_shape_w2)[::-1],
output,
)
def fused_moe_preprocess(
gating_output,
bias,
num_expert_group=8,
num_limited_group=4,
):
return _torch_vacc.fused_moe_preprocess(
gating_output, bias, num_expert_group, num_limited_group
)
def fused_residual_rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
epsilon: float = 1e-6,
output: Optional[torch.Tensor] = None,
residual_out: Optional[torch.Tensor] = None,
):
# out = _torch_vacc.fused_residual_rmsnorm(input, weight, residual, epsilon, inplace)
# if len(out) == 1:
# return out[0]
# return out
# TODO: VNNL support optional residual
input = input.contiguous()
weight = weight.contiguous()
if residual is None:
return rms_norm(input, weight, epsilon)
else:
return torch.ops.vacc.fused_residual_rmsnorm(
input,
weight,
residual,
epsilon,
output,
residual_out,
)
def parallel_embedding(
input: torch.Tensor,
weight: torch.Tensor,
org_vocab_start_index: int,
org_vocab_end_index: int,
num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int,
output: Optional[torch.Tensor] = None,
):
return _torch_vacc.parallel_embedding(
input,
weight,
org_vocab_start_index,
org_vocab_end_index,
num_org_vocab_padding,
added_vocab_start_index,
added_vocab_end_index,
output,
)
def fused_mla(
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
hidden_states_norm_weight: torch.Tensor,
q_a_proj_weight: torch.Tensor,
q_a_proj_weight_scale_inv: torch.Tensor,
q_a_layernorm_weight: torch.Tensor,
W_Q: torch.Tensor,
W_UK: torch.Tensor,
W_QR: torch.Tensor,
kv_a_proj_weight_scale_inv: torch.Tensor,
kv_a_proj_weight: torch.Tensor,
kv_a_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
W_UV: torch.Tensor,
o_proj_weight_scale_inv: torch.Tensor,
o_proj_weight: torch.Tensor,
q_a_proj_blocksize: Tuple[int] | List[int],
kv_a_proj_blocksize: Tuple[int] | List[int],
o_proj_blocksize: Tuple[int] | List[int],
seq_lens: Tuple[int] | List[int],
sm_scale: float,
head_num: int,
):
# TODO: CHECK
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fused_mla(
hidden_states,
residual,
hidden_states_norm_weight,
q_a_proj_weight,
q_a_proj_weight_scale_inv,
q_a_layernorm_weight,
W_Q,
W_UK,
W_QR,
kv_a_proj_weight_scale_inv,
kv_a_proj_weight,
kv_a_layernorm_weight,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
W_UV,
o_proj_weight_scale_inv,
o_proj_weight,
q_a_proj_blocksize,
kv_a_proj_blocksize,
o_proj_blocksize,
seq_lens,
sm_scale,
head_num,
)
if out_single:
return out[0]
return out
def fused_mla_v2(
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
hidden_states_norm_weight: torch.Tensor,
q_a_proj_weight: torch.Tensor,
q_a_proj_weight_scale_inv: torch.Tensor,
q_a_layernorm_weight: torch.Tensor,
w_q: torch.Tensor,
w_q_scale: torch.Tensor,
w_uk: torch.Tensor,
w_uk_scale: torch.Tensor,
w_qr: torch.Tensor,
w_qr_scale: torch.Tensor,
kv_a_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
block_group_size: int,
w_uv: torch.Tensor,
w_uv_scale: torch.Tensor,
o_proj_weight: torch.Tensor,
o_proj_weight_scale_inv: torch.Tensor,
seq_lens: Tuple[int] | List[int],
sm_scale: float,
head_num: int,
flash_attention: bool=False,
):
    # TODO: CHECK
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fused_mla_v2(
hidden_states,
residual,
hidden_states_norm_weight,
q_a_proj_weight,
q_a_proj_weight_scale_inv,
q_a_layernorm_weight,
w_q,
w_q_scale,
w_uk,
w_uk_scale,
w_qr,
w_qr_scale,
kv_a_layernorm_weight,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
w_uv,
w_uv_scale,
o_proj_weight,
o_proj_weight_scale_inv,
seq_lens,
sm_scale,
head_num,
flash_attention,
)
if out_single:
return out[0]
return out
def fused_mla_prefill_stage0(
hidden_states: torch.Tensor,
residual: torch.Tensor,
hidden_states_norm_weight: torch.Tensor,
qkv_a_proj_weight: torch.Tensor,
qkv_a_proj_weight_scale_inv: torch.Tensor,
):
    # TODO: CHECK
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fused_mla_prefill_stage0(
hidden_states,
residual,
hidden_states_norm_weight,
qkv_a_proj_weight,
qkv_a_proj_weight_scale_inv,
)
if out_single:
return out[0]
return out
def fused_mla_prefill_stage1(
qkv_a: torch.Tensor,
q_a_layernorm_weight: torch.Tensor,
q_proj_weight: torch.Tensor,
q_proj_weight_scale_inv: torch.Tensor,
kv_a_layernorm_weight: torch.Tensor,
kv_b_proj_weight: torch.Tensor,
kv_b_proj_weight_scale_inv: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
o_proj_weight: torch.Tensor,
o_proj_weight_scale_inv: torch.Tensor,
seq_lens_num: List[int],
sm_scale: float,
num_head: int,
mla_out_tensor: Optional[torch.Tensor] = None
):
out = _torch_vacc.fused_mla_prefill_stage1(
qkv_a,
q_a_layernorm_weight,
q_proj_weight,
q_proj_weight_scale_inv,
kv_a_layernorm_weight,
kv_b_proj_weight,
kv_b_proj_weight_scale_inv,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
o_proj_weight,
o_proj_weight_scale_inv,
seq_lens_num,
sm_scale,
num_head,
mla_out_tensor
)
return out
def fused_mla_allreduce(
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
hidden_states_norm_weight: torch.Tensor,
q_a_proj_weight: torch.Tensor,
q_a_proj_weight_scale_inv: torch.Tensor,
q_a_layernorm_weight: torch.Tensor,
W_Q: torch.Tensor,
W_UK: torch.Tensor,
W_QR: torch.Tensor,
kv_a_proj_weight_scale_inv: torch.Tensor,
kv_a_proj_weight: torch.Tensor,
kv_a_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
W_UV: torch.Tensor,
o_proj_weight_scale_inv: torch.Tensor,
o_proj_weight: torch.Tensor,
q_a_proj_blocksize: Tuple[int] | List[int],
kv_a_proj_blocksize: Tuple[int] | List[int],
o_proj_blocksize: Tuple[int] | List[int],
seq_lens: Tuple[int] | List[int],
sm_scale: float,
head_num: int,
red_op_type: int,
world_size: int,
rank: int,
root_rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
# TODO: CHECK
assert red_op_type == 0, "all_reduce only support red_op_type=0"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fused_mla_allreduce(
hidden_states,
residual,
hidden_states_norm_weight,
q_a_proj_weight,
q_a_proj_weight_scale_inv,
q_a_layernorm_weight,
W_Q,
W_UK,
W_QR,
kv_a_proj_weight_scale_inv,
kv_a_proj_weight,
kv_a_layernorm_weight,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
W_UV,
o_proj_weight_scale_inv,
o_proj_weight,
q_a_proj_blocksize,
kv_a_proj_blocksize,
o_proj_blocksize,
seq_lens,
sm_scale,
head_num,
red_op_type,
world_size,
rank,
root_rank,
group_id,
dev_info,
)
if out_single:
return out[0]
return out
def fused_mla_allreduce_v2(
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
hidden_states_norm_weight: torch.Tensor,
q_a_proj_weight: torch.Tensor,
q_a_proj_weight_scale_inv: torch.Tensor,
q_a_layernorm_weight: torch.Tensor,
w_q: torch.Tensor,
w_q_scale: torch.Tensor,
w_uk: torch.Tensor,
w_uk_scale: torch.Tensor,
w_qr: torch.Tensor,
w_qr_scale: torch.Tensor,
kv_a_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
block_group_size: int,
w_uv: torch.Tensor,
w_uv_scale: torch.Tensor,
o_proj_weight: torch.Tensor,
o_proj_weight_scale_inv: torch.Tensor,
seq_lens: Tuple[int] | List[int],
sm_scale: float,
head_num: int,
flash_attention: bool,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
# TODO: CHECK
assert world_size > 0, "fused_mla_allreduce_v2 only support world_size > 0"
assert rank >= 0, "fused_mla_allreduce_v2 only support rank >= 0"
if not dev_info:
dev_info = [i | (i << 16) for i in range(world_size)]
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fused_mla_allreduce_v2(
hidden_states,
residual,
hidden_states_norm_weight,
q_a_proj_weight,
q_a_proj_weight_scale_inv,
q_a_layernorm_weight,
w_q,
w_q_scale,
w_uk,
w_uk_scale,
w_qr,
w_qr_scale,
kv_a_layernorm_weight,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
w_uv,
w_uv_scale,
o_proj_weight,
o_proj_weight_scale_inv,
seq_lens,
sm_scale,
head_num,
flash_attention,
world_size,
rank,
group_id,
dev_info,
)
if out_single:
return out[0]
return out
def fused_mla_prefill_stage0_allreduce(
hidden_states: torch.Tensor,
residual: torch.Tensor,
hidden_states_norm_weight: torch.Tensor,
qkv_a_proj_weight: torch.Tensor,
qkv_a_proj_weight_scale_inv: torch.Tensor,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
    # TODO: CHECK
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fused_mla_prefill_stage0_allreduce(
hidden_states,
residual,
hidden_states_norm_weight,
qkv_a_proj_weight,
qkv_a_proj_weight_scale_inv,
world_size,
rank,
group_id,
dev_info,
)
if out_single:
return out[0]
return out
def all_reduce(
input: torch.Tensor,
rank: int,
world_size: int,
group_id: int,
dev_info: List[int],
red_op_type: int = 0,
):
assert input.device.type == "vacc", "all_reduce only support VACC"
assert red_op_type == 0, "all_reduce only support red_op_type=0"
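    # Default dev_info: one entry per rank with the rank index packed into both the low
    # and high 16 bits; the other collectives and fused *_allreduce ops use the same default.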
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.all_reduce(
input, rank, world_size, 0, group_id, dev_info, red_op_type
)
def all_gather(
input: torch.Tensor,
rank: int,
world_size: int,
group_id: int,
dev_info: List[int],
):
assert input.device.type == "vacc", "all_gather only support VACC"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.all_gather(input, rank, world_size, 0, group_id, dev_info, 0)
def broadcast(
input: torch.Tensor,
rank: int,
world_size: int,
root_rank: int,
group_id: int,
dev_info: List[int],
):
assert input.device.type == "vacc", "broadcast only support VACC"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.all_gather(
input, rank, world_size, root_rank, group_id, dev_info, 1
)
def fused_mlp_moe_with_rmsnorm(
hidden_states: torch.Tensor,
rms_residual: torch.Tensor,
rms_weight: torch.Tensor,
mlp_weight_13: torch.Tensor,
mlp_weight_2: torch.Tensor,
mlp_weight_scale_13: torch.Tensor,
mlp_weight_scale_2: torch.Tensor,
moe_weight_13: torch.Tensor,
moe_weight_2: torch.Tensor,
moe_weight_scale_13: torch.Tensor,
moe_weight_scale_2: torch.Tensor,
mm_weight: torch.Tensor,
moe_bias: torch.Tensor,
mlp_block_size_w13: List[int] | Tuple[int],
mlp_block_size_w2: List[int] | Tuple[int],
moe_block_size_w13: List[int] | Tuple[int],
moe_block_size_w2: List[int] | Tuple[int],
):
return _torch_vacc.fused_mlp_moe_with_rmsnorm(
hidden_states,
rms_residual,
rms_weight,
mlp_weight_13,
mlp_weight_2,
mlp_weight_scale_13,
mlp_weight_scale_2,
moe_weight_13,
moe_weight_2,
moe_weight_scale_13,
moe_weight_scale_2,
mm_weight,
moe_bias,
mlp_block_size_w13,
mlp_block_size_w2,
moe_block_size_w13,
moe_block_size_w2,
)
def fused_mlp_with_rmsnorm(
hidden_states: torch.Tensor,
rms_residual: torch.Tensor,
rms_weight: torch.Tensor,
mlp_weight_13: torch.Tensor,
mlp_weight_2: torch.Tensor,
mlp_weight_scale_13: torch.Tensor,
mlp_weight_scale_2: torch.Tensor,
mlp_block_size_w13: List[int] | Tuple[int],
mlp_block_size_w2: List[int] | Tuple[int],
):
return _torch_vacc.fused_mlp_with_rmsnorm(
hidden_states,
rms_residual,
rms_weight,
mlp_weight_13,
mlp_weight_2,
mlp_weight_scale_13,
mlp_weight_scale_2,
mlp_block_size_w13,
mlp_block_size_w2,
)
def fuse_moe_decode_v2_allreduce(
hidden_states: torch.Tensor,
rms_residual: torch.Tensor,
rms_weight: torch.Tensor,
mlp_weight_13: torch.Tensor,
mlp_weight_2: torch.Tensor,
mlp_weight_scale_13: torch.Tensor,
mlp_weight_scale_2: torch.Tensor,
moe_weight_13: torch.Tensor,
moe_weight_2: torch.Tensor,
moe_weight_scale_13: torch.Tensor,
moe_weight_scale_2: torch.Tensor,
mm_weight: torch.Tensor,
moe_bias: torch.Tensor,
mlp_block_size_w13: List[int] | Tuple[int],
mlp_block_size_w2: List[int] | Tuple[int],
moe_block_size_w13: List[int] | Tuple[int],
moe_block_size_w2: List[int] | Tuple[int],
red_op_type: int,
world_size: int,
rank: int,
root_rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
assert (
hidden_states.device.type == "vacc"
), "fuse_moe_decode_v2_allreduce only support VACC"
assert red_op_type == 0, "fuse_moe_decode_v2_allreduce only support red_op_type=0"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fuse_moe_decode_v2_allreduce(
hidden_states,
rms_residual,
rms_weight,
mlp_weight_13,
mlp_weight_2,
mlp_weight_scale_13,
mlp_weight_scale_2,
moe_weight_13,
moe_weight_2,
moe_weight_scale_13,
moe_weight_scale_2,
mm_weight,
moe_bias,
mlp_block_size_w13,
mlp_block_size_w2,
moe_block_size_w13,
moe_block_size_w2,
red_op_type,
world_size,
rank,
root_rank,
group_id,
dev_info,
)
def topk_topp(
logits: torch.Tensor,
p: torch.Tensor,
k: torch.Tensor,
):
return _torch_vacc.topk_topp(logits, p, k)
def fused_mlp_allreduce(
hidden_states: torch.Tensor,
rms_residual: torch.Tensor,
rms_weight: torch.Tensor,
mlp_weight_13: torch.Tensor,
mlp_weight_2: torch.Tensor,
mlp_weight_scale_13: torch.Tensor,
mlp_weight_scale_2: torch.Tensor,
mlp_block_size_w13: List[int] | Tuple[int],
mlp_block_size_w2: List[int] | Tuple[int],
red_op_type: int,
world_size: int,
rank: int,
root_rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
assert hidden_states.device.type == "vacc", "fused_mlp_allreduce only support VACC"
assert red_op_type == 0, "fused_mlp_allreduce only support red_op_type=0"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fused_mlp_allreduce(
hidden_states,
rms_residual,
rms_weight,
mlp_weight_13,
mlp_weight_2,
mlp_weight_scale_13,
mlp_weight_scale_2,
mlp_block_size_w13,
mlp_block_size_w2,
red_op_type,
world_size,
rank,
root_rank,
group_id,
dev_info,
)
def mla_matmul_scale(
input: torch.Tensor, weight: torch.Tensor, scale: float, align_seq_len: int = 1024
):
return _torch_vacc.mla_matmul_scale(
input,
weight,
scale,
align_seq_len,
)
def mla_matmul(
input: torch.Tensor,
weight: torch.Tensor,
):
return _torch_vacc.mla_matmul(
input,
weight,
)
def ds3_sampler(
src, p, k, temperatures, exponential_enable, generator: Optional[Generator] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
return _torch_vacc.ds3_sampler(
src, p, k, temperatures, exponential_enable, generator
)
def sampler_v1(
src, p, k, temperatures, all_greedy, all_random, generators: dict[int, Optional[torch.Generator]] = {}
) -> Tuple[torch.Tensor, torch.Tensor]:
    if not isinstance(generators, dict):
        raise TypeError(f"generators must be a dictionary, got {type(generators)}")
return _torch_vacc.sampler_v1(
src, p, k, temperatures, all_greedy, all_random, generators
)
def apply_penalties(
src_logits, src_tokens, buf_bin_counts, vocab_size, num_tokens,
frequency_penalties, presence_penalties, is_first_calculation
) -> Tuple[torch.Tensor]:
    return _torch_vacc.apply_penalties(
        src_logits, src_tokens, buf_bin_counts, vocab_size, num_tokens,
        frequency_penalties, presence_penalties, is_first_calculation)
def rejection_sampler(
target_with_bonus_probs, bonus_token_ids, draft_probs, draft_token_ids, gen_seed, generator: Optional[Generator] = None
):
return _torch_vacc.rejection_sampler(
target_with_bonus_probs, bonus_token_ids, draft_probs, draft_token_ids, gen_seed, generator
)
def rejection_sampler_update_hidden_states(
hidden_states, accepted_index
):
return _torch_vacc.rejection_sampler_update_hidden_states(
hidden_states, accepted_index
)
def rejection_sampler_v1(
target_logits, draft_token_ids, bonus_token_ids, temperature, top_p, top_k, all_greedy, all_random, generators: dict[int, Optional[torch.Generator]]
):
return _torch_vacc.rejection_sampler_v1(
target_logits, draft_token_ids, bonus_token_ids, temperature, top_p, top_k, all_greedy, all_random, generators
)
def fused_matmul_allgather(
input: torch.Tensor,
mat2: torch.Tensor,
world_size: int,
rank: int,
group_id: int,
    dev_info: Optional[List[int]] = None,
) -> torch.Tensor:
assert input.device.type == "vacc", "fused_matmul_allgather only support VACC"
assert mat2.device.type == "vacc", "fused_matmul_allgather only support VACC"
assert 2 == input.ndim, "fused_matmul_allgather: 'input' must be 2D tensor"
assert 2 == mat2.ndim, "fused_matmul_allgather: 'mat2' must be 2D tensor"
assert (
input.shape[-1] == mat2.shape[0]
), "fused_matmul_allgather: dim1 of 'input' must be equal to dim0 of 'mat2'"
assert world_size > 0, "world_size must be greater than 0"
if dev_info is None or 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fused_matmul_allgather(
input,
mat2,
world_size,
rank,
group_id,
dev_info,
)
def fuse_moe_prefill_stage0(
hidden_states,
rms_residual,
rms_weight,
mlp_weight_13,
mlp_weight_2,
mlp_weight_scale_13,
mlp_weight_scale_2,
mm_weight,
moe_bias,
mlp_block_size_w13,
mlp_block_size_w2,
rms_hidden_state_opt: Optional[torch.Tensor] = None,
mlp_hidden_state_opt: Optional[torch.Tensor] = None,
topk_ids_opt: Optional[torch.Tensor] = None,
topk_weight_opt: Optional[torch.Tensor] = None,
):
return _torch_vacc.fuse_moe_prefill_stage0(
hidden_states,
rms_residual,
rms_weight,
mlp_weight_13,
mlp_weight_2,
mlp_weight_scale_13,
mlp_weight_scale_2,
mm_weight,
moe_bias,
mlp_block_size_w13,
mlp_block_size_w2,
rms_hidden_state_opt,
mlp_hidden_state_opt,
topk_ids_opt,
topk_weight_opt,
)
def fuse_mla_mlp_v2_allreduce_decode(
# input
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
# mla weight
hidden_states_norm_weight: torch.Tensor,
q_a_proj_weight: torch.Tensor,
q_a_proj_weight_scale_inv: torch.Tensor,
q_a_layernorm_weight: torch.Tensor,
w_q: torch.Tensor,
w_q_scale: torch.Tensor,
w_uk: torch.Tensor,
w_uk_scale: torch.Tensor,
w_qr: torch.Tensor,
w_qr_scale: torch.Tensor,
kv_a_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
block_group_size: int,
w_uv: torch.Tensor,
w_uv_scale: torch.Tensor,
o_proj_weight: torch.Tensor,
o_proj_weight_scale_inv: torch.Tensor,
# mla params
seq_lens: Tuple[int] | List[int],
sm_scale: float,
head_num: int,
flash_attention: bool,
# mlp weight
rms_weight: torch.Tensor,
mlp_weight_13: torch.Tensor,
mlp_weight_2: torch.Tensor,
mlp_weight_scale_13: torch.Tensor,
mlp_weight_scale_2: torch.Tensor,
# mlp params
mlp_block_size_w13: List[int] | Tuple[int],
mlp_block_size_w2: List[int] | Tuple[int],
# vccl info
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
# TODO: CHECK
assert (
hidden_states.device.type == "vacc"
), "fuse_mla_mlp_v2_allreduce_decode only support VACC"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fuse_mla_mlp_v2_allreduce_decode(
hidden_states,
residual,
hidden_states_norm_weight,
q_a_proj_weight,
q_a_proj_weight_scale_inv,
q_a_layernorm_weight,
w_q,
w_q_scale,
w_uk,
w_uk_scale,
w_qr,
w_qr_scale,
kv_a_layernorm_weight,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
w_uv,
w_uv_scale,
o_proj_weight,
o_proj_weight_scale_inv,
seq_lens,
sm_scale,
head_num,
flash_attention,
rms_weight,
mlp_weight_13,
mlp_weight_2,
mlp_weight_scale_13,
mlp_weight_scale_2,
mlp_block_size_w13,
mlp_block_size_w2,
world_size,
rank,
group_id,
dev_info,
)
# if out_single:
# return out[0]
return out
def fuse_mla_moe_v2_allreduce_decode(
# input
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
# mla weight
hidden_states_norm_weight: torch.Tensor,
q_a_proj_weight: torch.Tensor,
q_a_proj_weight_scale_inv: torch.Tensor,
q_a_layernorm_weight: torch.Tensor,
w_q: torch.Tensor,
w_q_scale: torch.Tensor,
w_uk: torch.Tensor,
w_uk_scale: torch.Tensor,
w_qr: torch.Tensor,
w_qr_scale: torch.Tensor,
kv_a_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
block_group_size: int,
w_uv: torch.Tensor,
w_uv_scale: torch.Tensor,
o_proj_weight: torch.Tensor,
o_proj_weight_scale_inv: torch.Tensor,
# mla params
seq_lens: Tuple[int] | List[int],
sm_scale: float,
head_num: int,
flash_attention: bool,
# moe weight
rms_weight: torch.Tensor,
mlp_weight_13: torch.Tensor,
mlp_weight_2: torch.Tensor,
mlp_weight_scale_13: torch.Tensor,
mlp_weight_scale_2: torch.Tensor,
moe_weight_13: torch.Tensor,
moe_weight_2: torch.Tensor,
moe_weight_scale_13: torch.Tensor,
moe_weight_scale_2: torch.Tensor,
mm_weight: torch.Tensor,
moe_bias: torch.Tensor,
# moe params
mlp_block_size_w13: Tuple[int] | List[int],
mlp_block_size_w2: Tuple[int] | List[int],
moe_block_size_w13: Tuple[int] | List[int],
moe_block_size_w2: Tuple[int] | List[int],
# vccl info
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
# TODO: CHECK
assert (
hidden_states.device.type == "vacc"
), "fuse_mla_moe_v2_allreduce_decode only support VACC"
# assert red_op_type == 0, "all_reduce only support red_op_type=0"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fuse_mla_moe_v2_allreduce_decode(
hidden_states,
residual,
hidden_states_norm_weight,
q_a_proj_weight,
q_a_proj_weight_scale_inv,
q_a_layernorm_weight,
w_q,
w_q_scale,
w_uk,
w_uk_scale,
w_qr,
w_qr_scale,
kv_a_layernorm_weight,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
w_uv,
w_uv_scale,
o_proj_weight,
o_proj_weight_scale_inv,
seq_lens,
sm_scale,
head_num,
flash_attention,
rms_weight,
mlp_weight_13,
mlp_weight_2,
mlp_weight_scale_13,
mlp_weight_scale_2,
moe_weight_13,
moe_weight_2,
moe_weight_scale_13,
moe_weight_scale_2,
mm_weight,
moe_bias,
mlp_block_size_w13,
mlp_block_size_w2,
moe_block_size_w13,
moe_block_size_w2,
world_size,
rank,
group_id,
dev_info,
)
# if out_single:
# return out[0]
return out
# ! Register fake (meta) implementations for the C++-implemented custom ops.
@torch.library.register_fake("vacc::RotaryPosEmbedding_func")
def _(
q: torch.Tensor,
k: torch.Tensor,
cos: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
offset: int = 0,
mode: str = "neox",
):
return [torch.empty_like(q), torch.empty_like(k)]
@torch.library.register_fake("vacc::reshape_and_cache_attention")
def _(
src: torch.Tensor,
cached: torch.Tensor,
block_mapping: torch.Tensor,
):
pass
@torch.library.register_fake("vacc::scaled_dot_product_attention")
def _(
query,
key,
value,
attn_mask,
dropout_p,
is_train,
recompute,
is_causal,
flash_attention,
sm_scale
):
    return [
        torch.empty(
            size=(query.size(0), query.size(1), key.size(2)),
            device=query.device,
            dtype=query.dtype,
        )
    ]
@torch.library.register_fake("vacc::rms_norm_func")
def _(
input: torch.Tensor, weight: torch.Tensor, eps: float, output: Optional[torch.Tensor] = None
) -> torch.Tensor:
return torch.empty_like(input)
@torch.library.register_fake("vacc::fused_residual_rmsnorm")
def _(
input: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
epsilon: float = 1e-6,
output: Optional[torch.Tensor] = None,
residual_out: Optional[torch.Tensor] = None,
):
return [torch.empty_like(input), torch.empty_like(input)]
@torch.library.register_fake("vacc::swiglu")
def _(
self: torch.Tensor
):
shape = list(self.shape)
shape[-1] = shape[-1] // 2
return torch.empty(size=shape, dtype=self.dtype, device=self.device)
def fuse_mla_mlp_v2_allreduce_decode_layers(
# input
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
# mla weight
hidden_states_norm_weight: List[torch.Tensor],
q_a_proj_weight: List[torch.Tensor],
q_a_proj_weight_scale_inv: List[torch.Tensor],
q_a_layernorm_weight: List[torch.Tensor],
w_q: List[torch.Tensor],
w_q_scale: List[torch.Tensor],
w_uk: List[torch.Tensor],
w_uk_scale: List[torch.Tensor],
w_qr: List[torch.Tensor],
w_qr_scale: List[torch.Tensor],
kv_a_layernorm_weight: List[torch.Tensor],
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: List[torch.Tensor],
block_tables: torch.Tensor,
block_group_size: int,
w_uv: List[torch.Tensor],
w_uv_scale: List[torch.Tensor],
o_proj_weight: List[torch.Tensor],
o_proj_weight_scale_inv: List[torch.Tensor],
# mla params
seq_lens: Tuple[int] | List[int],
sm_scale: float,
head_num: int,
flash_attention: bool,
# mlp weight
rms_weight: List[torch.Tensor],
mlp_weight_13: List[torch.Tensor],
mlp_weight_2: List[torch.Tensor],
mlp_weight_scale_13: List[torch.Tensor],
mlp_weight_scale_2: List[torch.Tensor],
# mlp params
mlp_block_size_w13: List[int] | Tuple[int],
mlp_block_size_w2: List[int] | Tuple[int],
# vccl info
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
# TODO: CHECK
assert (
hidden_states.device.type == "vacc"
), "fuse_mla_mlp_v2_allreduce_decode_layers only support VACC"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
if residual is None:
residual = torch.Tensor()
out = _torch_vacc.fuse_mla_mlp_v2_allreduce_decode_layers(
hidden_states,
residual,
hidden_states_norm_weight,
q_a_proj_weight,
q_a_proj_weight_scale_inv,
q_a_layernorm_weight,
w_q,
w_q_scale,
w_uk,
w_uk_scale,
w_qr,
w_qr_scale,
kv_a_layernorm_weight,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
w_uv,
w_uv_scale,
o_proj_weight,
o_proj_weight_scale_inv,
seq_lens,
sm_scale,
head_num,
flash_attention,
rms_weight,
mlp_weight_13,
mlp_weight_2,
mlp_weight_scale_13,
mlp_weight_scale_2,
mlp_block_size_w13,
mlp_block_size_w2,
world_size,
rank,
group_id,
dev_info,
)
return out
def fuse_mla_mlp_v2_allreduce_decode_layers_v2(
# input
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
# mla weight
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: List[torch.Tensor],
block_tables: torch.Tensor,
block_group_size: int,
# mla params
seq_lens: Tuple[int] | List[int],
sm_scale: float,
head_num: int,
flash_attention: bool,
# mlp weight
# mlp params
mlp_block_size_w13: List[int] | Tuple[int],
mlp_block_size_w2: List[int] | Tuple[int],
# vccl info
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
# TODO: CHECK
assert (
hidden_states.device.type == "vacc"
), "fuse_mla_mlp_v2_allreduce_decode_layers_v2 only support VACC"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
if residual is None:
residual = torch.Tensor()
out = _torch_vacc.fuse_mla_mlp_v2_allreduce_decode_layers(
hidden_states,
residual,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
seq_lens,
sm_scale,
head_num,
flash_attention,
mlp_block_size_w13,
mlp_block_size_w2,
world_size,
rank,
group_id,
dev_info,
)
return out
def fuse_mla_moe_v2_allreduce_decode_layers(
# input
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
# mla weight
hidden_states_norm_weight: List[torch.Tensor],
q_a_proj_weight: List[torch.Tensor],
q_a_proj_weight_scale_inv: List[torch.Tensor],
q_a_layernorm_weight: List[torch.Tensor],
w_q: List[torch.Tensor],
w_q_scale: List[torch.Tensor],
w_uk: List[torch.Tensor],
w_uk_scale: List[torch.Tensor],
w_qr: List[torch.Tensor],
w_qr_scale: List[torch.Tensor],
kv_a_layernorm_weight: List[torch.Tensor],
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: List[torch.Tensor],
block_tables: torch.Tensor,
block_group_size: int,
w_uv: List[torch.Tensor],
w_uv_scale: List[torch.Tensor],
o_proj_weight: List[torch.Tensor],
o_proj_weight_scale_inv: List[torch.Tensor],
# mla params
seq_lens: Tuple[int] | List[int],
sm_scale: float,
head_num: int,
flash_attention: bool,
# moe weight
rms_weight: List[torch.Tensor],
mlp_weight_13: List[torch.Tensor],
mlp_weight_2: List[torch.Tensor],
mlp_weight_scale_13: List[torch.Tensor],
mlp_weight_scale_2: List[torch.Tensor],
moe_weight_13: List[torch.Tensor],
moe_weight_2: List[torch.Tensor],
moe_weight_scale_13: List[torch.Tensor],
moe_weight_scale_2: List[torch.Tensor],
mm_weight: List[torch.Tensor],
moe_bias: List[torch.Tensor],
# moe params
mlp_block_size_w13: Tuple[int] | List[int],
mlp_block_size_w2: Tuple[int] | List[int],
moe_block_size_w13: Tuple[int] | List[int],
moe_block_size_w2: Tuple[int] | List[int],
# vccl info
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
# TODO: CHECK
assert (
hidden_states.device.type == "vacc"
), "fuse_mla_moe_v2_allreduce_decode_layers only support VACC"
# assert red_op_type == 0, "all_reduce only support red_op_type=0"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fuse_mla_moe_v2_allreduce_decode_layers(
hidden_states,
residual,
hidden_states_norm_weight,
q_a_proj_weight,
q_a_proj_weight_scale_inv,
q_a_layernorm_weight,
w_q,
w_q_scale,
w_uk,
w_uk_scale,
w_qr,
w_qr_scale,
kv_a_layernorm_weight,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
w_uv,
w_uv_scale,
o_proj_weight,
o_proj_weight_scale_inv,
seq_lens,
sm_scale,
head_num,
flash_attention,
rms_weight,
mlp_weight_13,
mlp_weight_2,
mlp_weight_scale_13,
mlp_weight_scale_2,
moe_weight_13,
moe_weight_2,
moe_weight_scale_13,
moe_weight_scale_2,
mm_weight,
moe_bias,
mlp_block_size_w13,
mlp_block_size_w2,
moe_block_size_w13,
moe_block_size_w2,
world_size,
rank,
group_id,
dev_info,
)
if out_single:
return out[0]
return out
def fuse_mla_moe_v2_allreduce_decode_layers_v2(
# input
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
# mla weight
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: List[torch.Tensor],
block_tables: torch.Tensor,
block_group_size: int,
# mla params
seq_lens: Tuple[int] | List[int],
sm_scale: float,
head_num: int,
flash_attention: bool,
# moe params
mlp_block_size_w13: Tuple[int] | List[int],
mlp_block_size_w2: Tuple[int] | List[int],
moe_block_size_w13: Tuple[int] | List[int],
moe_block_size_w2: Tuple[int] | List[int],
# vccl info
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
# TODO: CHECK
assert (
hidden_states.device.type == "vacc"
), "fuse_mla_moe_v2_allreduce_decode_layers_v2 only support VACC"
# assert red_op_type == 0, "all_reduce only support red_op_type=0"
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
# print("come to v2")
out = _torch_vacc.fuse_mla_moe_v2_allreduce_decode_layers_v2(
hidden_states,
residual,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
seq_lens,
sm_scale,
head_num,
flash_attention,
mlp_block_size_w13,
mlp_block_size_w2,
moe_block_size_w13,
moe_block_size_w2,
world_size,
rank,
group_id,
dev_info,
)
if out_single:
return out[0]
return out
def fuse_mlp_qwen_int4(
hidden_states: torch.Tensor,
weight13: torch.Tensor,
weight2: torch.Tensor,
scale13: torch.Tensor,
scale2: torch.Tensor,
zero13: torch.Tensor,
zero2: torch.Tensor,
block13: List[int] | Tuple[int],
block2: List[int] | Tuple[int],
    engine_mode: int = 0,  # 0: auto, 1: dlc, 2: dsp
output_opt: Optional[torch.Tensor] = None
):
assert engine_mode in [0, 1, 2]
return _torch_vacc.fuse_mlp_qwen_int4(
hidden_states,
weight13,
weight2,
scale13,
scale2,
zero13,
zero2,
block13,
block2,
engine_mode,
output_opt
)
def fuse_mlp_qwen_int4_reduce(
hidden_states: torch.Tensor,
weight13: torch.Tensor,
weight2: torch.Tensor,
scale13: torch.Tensor,
scale2: torch.Tensor,
zero13: torch.Tensor,
zero2: torch.Tensor,
block13: List[int] | Tuple[int],
block2: List[int] | Tuple[int],
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
output_opt: Optional[torch.Tensor] = None
):
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fuse_mlp_qwen_int4_reduce(
hidden_states,
weight13,
weight2,
scale13,
scale2,
zero13,
zero2,
block13,
block2,
world_size,
rank,
group_id,
dev_info,
output_opt
)
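# NOTE: the fp8 and fp16/bf16 qwen MLP wrappers below dispatch to the same
# _torch_vacc.fuse_mlp_qwen_int4 / fuse_mlp_qwen_int4_reduce bindings; the fp16/bf16
# variants pass empty tensors and zero block sizes for the unused quantization arguments.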
def fuse_mlp_qwen_fp8(
hidden_states: torch.Tensor,
weight13: torch.Tensor,
weight2: torch.Tensor,
scale13: torch.Tensor,
scale2: torch.Tensor,
zero13: torch.Tensor,
zero2: torch.Tensor,
block13: List[int] | Tuple[int],
block2: List[int] | Tuple[int],
    engine_mode: int = 0,  # 0: auto, 1: dlc, 2: dsp
output_opt: Optional[torch.Tensor] = None
):
assert engine_mode in [0, 1, 2]
return _torch_vacc.fuse_mlp_qwen_int4(
hidden_states,
weight13,
weight2,
scale13,
scale2,
zero13,
zero2,
block13,
block2,
engine_mode,
output_opt
)
def fuse_mlp_qwen_fp8_reduce(
hidden_states: torch.Tensor,
weight13: torch.Tensor,
weight2: torch.Tensor,
scale13: torch.Tensor,
scale2: torch.Tensor,
zero13: torch.Tensor,
zero2: torch.Tensor,
block13: List[int] | Tuple[int],
block2: List[int] | Tuple[int],
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
output_opt: Optional[torch.Tensor] = None
):
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fuse_mlp_qwen_int4_reduce(
hidden_states,
weight13,
weight2,
scale13,
scale2,
zero13,
zero2,
block13,
block2,
world_size,
rank,
group_id,
dev_info,
output_opt
)
def fuse_mlp_qwen_fp16_bf16(
hidden_states: torch.Tensor,
weight13: torch.Tensor,
weight2: torch.Tensor,
output_opt: Optional[torch.Tensor] = None
):
return _torch_vacc.fuse_mlp_qwen_int4(
hidden_states,
weight13,
weight2,
torch.Tensor(),
torch.Tensor(),
torch.Tensor(),
torch.Tensor(),
(0,0),
(0,0),
0,
output_opt
)
def fuse_mlp_qwen_fp16_bf16_reduce(
hidden_states: torch.Tensor,
weight13: torch.Tensor,
weight2: torch.Tensor,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
output_opt: Optional[torch.Tensor] = None
):
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fuse_mlp_qwen_int4_reduce(
hidden_states,
weight13,
weight2,
torch.Tensor(),
torch.Tensor(),
torch.Tensor(),
torch.Tensor(),
(0,0),
(0,0),
world_size,
rank,
group_id,
dev_info,
output_opt
)
def w4a8_block_int4_matmul(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
dequant_block: List[int] | Tuple[int],
output_opt: Optional[torch.Tensor] = None,
):
# TODO: CHECK
out = _torch_vacc.w4a8_block_int4_matmul(
input,
weight,
weight_scale,
dequant_block,
output_opt,
)
return out
def fuse_atten_qwen3(
hidden_states: torch.Tensor,
residual: torch.Tensor,
hidden_states_norm_weight: torch.Tensor,
qkv_proj_weight: torch.Tensor,
qkv_proj_weight_scale: torch.Tensor,
qkv_proj_bias: torch.Tensor,
qkv_proj_qzeros: torch.Tensor,
q_layernorm_weight: torch.Tensor,
k_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
block_group_size: int,
o_proj_weight: torch.Tensor,
o_proj_weight_scale: torch.Tensor,
o_proj_bias: torch.Tensor,
o_proj_qzeros: torch.Tensor,
seq_lens: Tuple[int] | List[int],
sm_scale: float,
num_attention_heads: int,
num_key_value_heads: int,
flash_attention: bool,
is_decode: bool,
reduce_result: bool,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
output_opt: Optional[torch.Tensor] = None,
res_opt: Optional[torch.Tensor] = None,
):
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
none2empty = lambda tensor: torch.Tensor() if tensor is None else tensor
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fuse_atten_qwen3(
hidden_states,
residual,
hidden_states_norm_weight,
qkv_proj_weight,
qkv_proj_weight_scale,
none2empty(qkv_proj_bias),
none2empty(qkv_proj_qzeros),
q_layernorm_weight,
k_layernorm_weight,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
o_proj_weight,
o_proj_weight_scale,
none2empty(o_proj_bias),
none2empty(o_proj_qzeros),
seq_lens,
sm_scale,
num_attention_heads,
num_key_value_heads,
flash_attention,
is_decode,
reduce_result,
world_size,
rank,
group_id,
dev_info,
output_opt,
res_opt,
)
if out_single:
return out[0]
return out
def fuse_atten_vit(
hidden_states: torch.Tensor,
hidden_states_norm_weight: torch.Tensor,
hidden_states_norm_bias: torch.Tensor,
qkv_proj_weight: torch.Tensor,
qkv_proj_bias: torch.Tensor,
sin_cache: torch.Tensor,
cos_cache: torch.Tensor,
o_proj_weight: torch.Tensor,
o_proj_bias: torch.Tensor,
seq_lens: Tuple[int] | List[int],
sm_scale: float,
num_attention_heads: int,
flash_attention: bool,
reduce_result: bool,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
output_opt: Optional[torch.Tensor] = None,
):
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
none2empty = lambda tensor: torch.Tensor() if tensor is None else tensor
out = _torch_vacc.fuse_atten_vit(
hidden_states,
hidden_states_norm_weight,
hidden_states_norm_bias,
qkv_proj_weight,
none2empty(qkv_proj_bias),
sin_cache,
cos_cache,
o_proj_weight,
none2empty(o_proj_bias),
seq_lens,
sm_scale,
num_attention_heads,
flash_attention,
reduce_result,
world_size,
rank,
group_id,
dev_info,
output_opt
)
return out
def mrope_get_sin_cos(
cos_cache: torch.Tensor,
sin_cache: torch.Tensor,
positions: torch.Tensor,
mrope_section: List[int] | Tuple[int],
mrope_interleaved: bool
):
return _torch_vacc.mrope_get_sin_cos(
cos_cache,
sin_cache,
positions,
mrope_section,
mrope_interleaved
)
# NOTE: DSP version for m <= 8; for m > 8, call w4a8_block_int4_matmul instead.
def w4a8_block_int4_linear(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
pack_factor: int = 8,
output_opt: Optional[torch.Tensor] = None,
):
# TODO: CHECK
out = _torch_vacc.w4a8_block_int4_linear(
input,
weight,
input_scale,
weight_scale,
bias,
pack_factor,
output_opt,
)
return out
def fuse_atten_qwen2(
history_states: torch.Tensor,
residual: torch.Tensor,
hidden_states_norm_weight: torch.Tensor,
qkv_proj_weight: torch.Tensor,
qkv_proj_weight_scale_inv: torch.Tensor,
qkv_proj_bias: torch.Tensor,
qkv_proj_qzeros: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
block_group_size: int,
o_proj_weight: torch.Tensor,
o_proj_weight_scale_inv: torch.Tensor,
o_proj_bias: torch.Tensor,
o_proj_qzeros: torch.Tensor,
seq_lens_num: Tuple[int] | List[int],
sm_scale: float,
num_attention_heads: int,
num_key_value_heads: int,
flash_attentiton: bool,
is_decode: bool,
reduce_result: bool,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
out_single = False
if residual is None:
out_single = True
residual = torch.Tensor()
out = _torch_vacc.fuse_atten_qwen2(
history_states,
residual,
hidden_states_norm_weight,
qkv_proj_weight,
qkv_proj_weight_scale_inv,
qkv_proj_bias,
qkv_proj_qzeros,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
o_proj_weight,
o_proj_weight_scale_inv,
o_proj_bias,
o_proj_qzeros,
seq_lens_num,
sm_scale,
num_attention_heads,
num_key_value_heads,
flash_attentiton,
is_decode,
reduce_result,
world_size,
rank,
group_id,
dev_info,
)
if out_single:
return out[0]
return out
def qwen3_fuse_attention_moe_decode(
# attention
hidden_states: torch.Tensor,
residual: torch.Tensor,
hidden_states_norm_weight: torch.Tensor,
qkv_proj_weight: torch.Tensor,
qkv_proj_weight_scale_inv: torch.Tensor,
qkv_proj_bias: torch.Tensor,
qkv_proj_qzeros: torch.Tensor,
q_layernorm_weight: torch.Tensor,
k_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
block_group_size: int,
o_proj_weight: torch.Tensor,
o_proj_weight_scale_inv: torch.Tensor,
o_proj_bias: torch.Tensor,
o_proj_qzeros: torch.Tensor,
seq_lens_num: Tuple[int] | List[int],
sm_scale: float,
num_attention_heads: int,
num_key_value_heads: int,
flash_attentiton: bool,
is_decode: bool,
reduce_result: bool,
# moe
rms_weight: torch.Tensor,
moe_weight_13: torch.Tensor,
moe_weight_2: torch.Tensor,
moe_weight_13_dequat: torch.Tensor,
moe_weight_2_dequant: torch.Tensor,
gate_weight: torch.Tensor,
block_size_13: Tuple[int] | List[int],
block_size_2: Tuple[int] | List[int],
# dist
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
):
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
none2empty = lambda tensor: torch.Tensor() if tensor is None else tensor
return _torch_vacc.qwen3_fuse_attention_moe_decode(
hidden_states,
none2empty(residual),
hidden_states_norm_weight,
qkv_proj_weight,
qkv_proj_weight_scale_inv,
none2empty(qkv_proj_bias),
none2empty(qkv_proj_qzeros),
q_layernorm_weight,
k_layernorm_weight,
sin_cache,
cos_cache,
slot_mapping,
kv_cache,
block_tables,
block_group_size,
o_proj_weight,
o_proj_weight_scale_inv,
none2empty(o_proj_bias),
none2empty(o_proj_qzeros),
seq_lens_num,
sm_scale,
num_attention_heads,
num_key_value_heads,
flash_attentiton,
is_decode,
reduce_result,
rms_weight,
moe_weight_13,
moe_weight_2,
moe_weight_13_dequat,
moe_weight_2_dequant,
gate_weight,
block_size_13,
block_size_2,
world_size,
rank,
group_id,
dev_info,
)
def fuse_mtp_stage0(
inputs_embeds: torch.Tensor,
previous_hidden_states: torch.Tensor,
positions: torch.Tensor,
enorm_wegiht: torch.Tensor,
hnorm_wegint: torch.Tensor,
epsilon: float,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
output: Optional[torch.Tensor] = None,
):
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fuse_mtp_stage0(
inputs_embeds,
previous_hidden_states,
positions,
enorm_wegiht,
hnorm_wegint,
epsilon,
world_size,
rank,
group_id,
dev_info,
output,
)
def fuse_mtp_allreduce(
inputs_embeds: torch.Tensor,
previous_hidden_states: torch.Tensor,
positions: torch.Tensor,
enorm_wegiht: torch.Tensor,
hnorm_wegint: torch.Tensor,
linear_weight: torch.Tensor,
epsilon: float,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
output: Optional[torch.Tensor] = None,
):
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fuse_mtp_allreduce(
inputs_embeds,
previous_hidden_states,
positions,
enorm_wegiht,
hnorm_wegint,
linear_weight,
epsilon,
world_size,
rank,
group_id,
dev_info,
output,
)
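# NOTE: roll_out follows torch.roll-style (shifts, dims) semantics; scalar shifts/dims
# are normalized to single-element lists before dispatching to the vacc kernel.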
def roll_out(
self: torch.Tensor,
shifts: List[int] | Tuple[int] | int,
dims: List[int] | Tuple[int] | int = [],
output: Optional[torch.Tensor] = None,
):
if isinstance(dims, int):
dims = [dims]
if isinstance(shifts, int):
shifts = [shifts]
return _torch_vacc.roll_out(
self,
shifts,
dims,
output,
)
def fused_experts_int4_prefill(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w13_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a13_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
w13_block_shape: Optional[List[int]] = None,
w2_block_shape: Optional[List[int]] = None,
output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    warning_message = "[fused_experts_int4_prefill]:vacc only supports block-quantized weights (weight scales required)"
assert a13_scale is None and a2_scale is None, warning_message
assert (
w13_scale is not None and w2_scale is not None
), f"{warning_message}, but w13_weight_scale is {w13_scale}, w2_weight_scale is {w2_scale}"
assert (
w13_block_shape is not None
), f"{warning_message}, but block_shape is {w13_block_shape}"
# assert (
# not decode_with_batch or hidden_states.size(0) <= 4
# ), "[fused_experts]:vacc only support batch <= 4 when decode"
# topk weights dtype should be same with hidden_states
topk_weights = topk_weights.to(hidden_states.dtype)
# vacc device use int32 for experts_id
topk_ids = topk_ids.to(torch.int32)
# assert (
# block_size0 == block_size1
# ), "quant block shape now support size0 == size1"
return _torch_vacc.fused_experts_int4_prefill(
hidden_states,
w13_weight,
w2_weight,
topk_weights,
topk_ids,
w13_scale,
w2_scale,
a13_scale,
a2_scale,
w13_block_shape,
w2_block_shape,
output_opt,
)
def fuse_bge_embedding_stage1(
input_embeds: torch.Tensor,
positions_ids: torch.Tensor,
positions_embeddings_weight: torch.Tensor,
token_type_ids: torch.Tensor,
token_type_embeddings_weight: torch.Tensor,
layernorm_weight: torch.Tensor,
layernorm_bias: torch.Tensor,
epsilon: float,
output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return _torch_vacc.fuse_bge_embedding_stage1(
input_embeds,
positions_ids,
positions_embeddings_weight,
token_type_ids,
token_type_embeddings_weight,
layernorm_weight,
layernorm_bias,
epsilon,
output_opt,
)
def l2_norm(
input: torch.Tensor,
epsilon: float,
output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return _torch_vacc.l2_norm(
input,
epsilon,
output_opt,
)
class BERT_ATTN_STAGE(Enum):
# with reduce, for seqLen <= 2k
FullStage = 0
# without reduce, call reduce outer
AttnOutStage = 1
InterOutStage = 2
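# Usage sketch (comments only): with BERT_ATTN_STAGE.FullStage the fused op performs the
# all-reduce itself (intended for seq_len <= 2k); with AttnOutStage / InterOutStage the
# caller runs the reduction afterwards.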
def fused_attn_bert_allreduce(
hidden_states: torch.Tensor,
qkv_weight: Optional[torch.Tensor] = None,
qkv_bias: Optional[torch.Tensor] = None,
self_weight: Optional[torch.Tensor] = None,
self_bias: Optional[torch.Tensor] = None,
self_norm_weight: Optional[torch.Tensor] = None,
self_norm_bias: Optional[torch.Tensor] = None,
intermediate_weight: Optional[torch.Tensor] = None,
intermediate_bias: Optional[torch.Tensor] = None,
output_weight: Optional[torch.Tensor] = None,
output_bias: Optional[torch.Tensor] = None,
output_norm_weight: Optional[torch.Tensor] = None,
output_norm_bias: Optional[torch.Tensor] = None,
dense_out: Optional[torch.Tensor] = None,
seqs: List[int] | Tuple[int] = [],
vnnlBertKind: BERT_ATTN_STAGE = BERT_ATTN_STAGE.FullStage,
sm_scale: float = 1.0,
num_q_heads: int = 1,
num_kv_heads: int = 1,
flash_attention: bool = False,
reduce_result: bool = False,
world_size: int = 1,
rank: int = 0,
group_id: int = 0,
dev_info: List[int] | Tuple[int] = [],
):
if 0 == len(dev_info):
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fused_attn_bert_allreduce(
hidden_states,
qkv_weight,
qkv_bias,
self_weight,
self_bias,
self_norm_weight,
self_norm_bias,
intermediate_weight,
intermediate_bias,
output_weight,
output_bias,
output_norm_weight,
output_norm_bias,
dense_out,
seqs,
vnnlBertKind.value,
sm_scale,
num_q_heads,
num_kv_heads,
flash_attention,
reduce_result,
world_size,
rank,
group_id,
dev_info,
)
def fuse_mlp_vision(
src: torch.Tensor,
weights_13: torch.Tensor,
weights_2: torch.Tensor,
weights_13_bias: Optional[torch.Tensor] = None,
weights_2_bias: Optional[torch.Tensor] = None,
act_type: int = 0,
output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if weights_13_bias is None:
weights_13_bias = torch.Tensor()
if weights_2_bias is None:
weights_2_bias = torch.Tensor()
return _torch_vacc.fuse_mlp_vision(
src,
weights_13,
weights_2,
weights_13_bias,
weights_2_bias,
act_type,
output_opt,
)
def patch_merger_vision(
src: torch.Tensor,
weights_13: torch.Tensor,
weights_2: torch.Tensor,
weights_13_bias: Optional[torch.Tensor] = None,
weights_2_bias: Optional[torch.Tensor] = None,
act_type: int = 0,
output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if weights_13_bias is None:
weights_13_bias = torch.Tensor()
if weights_2_bias is None:
weights_2_bias = torch.Tensor()
return _torch_vacc.patch_merger_vision(
src,
weights_13,
weights_2,
weights_13_bias,
weights_2_bias,
act_type,
output_opt,
)