from __future__ import annotations
|
|
from typing import List, Optional, Tuple, Union
|
|
from enum import Enum
|
|
|
|
import torch
|
|
from torch import Generator
|
|
from torch._C._distributed_c10d import ReduceOp
|
|
from torch_vacc._vacc_libs import _torch_vacc
|
|
|
|
|
|
def rms_norm(
|
|
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, output=None
|
|
) -> torch.Tensor:
|
|
r"""Root mean square(RMS) normalization of inputs over last dimension.
|
|
|
|
vacc fused kernal for rms_norm, only for three dimensions of input and sigle dimemsion weight
|
|
|
|
Args:
|
|
input: input tensor with shape (batch, seq_len, hidden size)
|
|
weight: weight tensor with shape (hidden size,)
|
|
eps (float): small value to avoid division by zero. Default: 1e-6
|
|
|
|
Returns:
|
|
Tensor: tensor after applying rms_norm
|
|
"""
|
|
# assert input.dim() == 3, "rms_norm only support the input with dim=3"
|
|
if input.device.type == "vacc":
|
|
return torch.ops.vacc.rms_norm_func(input, weight, eps, output)
|
|
input_dtype = input.dtype
|
|
input = input.to(torch.float32)
|
|
variance = input.pow(2).mean(-1, keepdim=True)
|
|
rsigma = torch.rsqrt(variance + eps)
|
|
normalized_states = input * rsigma
|
|
return weight * normalized_states.to(input_dtype)
|
|
|
|
|
|
rms_norm.apply = rms_norm
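# Illustrative usage sketch for rms_norm (kept as a comment so importing this module has no
# side effects). The "vacc" device and the concrete shapes below are assumptions of this
# example; on non-vacc devices the pure-PyTorch fallback above is used.
#
#   hidden = torch.randn(2, 128, 4096, dtype=torch.float16, device="vacc")
#   gamma = torch.ones(4096, dtype=torch.float16, device="vacc")
#   normalized = rms_norm(hidden, gamma, eps=1e-6)   # shape (2, 128, 4096)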
|
|
|
|
|
|
def rotate_half(x):
|
|
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
|
return torch.cat((-x2, x1), dim=x1.ndim - 1)
|
|
|
|
|
|
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
|
|
x1 = x[..., ::2]
|
|
x2 = x[..., 1::2]
|
|
x = torch.stack((-x2, x1), dim=-1)
|
|
return x.flatten(-2)
|
|
|
|
|
|
def RotaryPosEmbedding(
|
|
q: torch.Tensor,
|
|
k: torch.Tensor,
|
|
cos: Optional[torch.Tensor] = None,
|
|
sin: Optional[torch.Tensor] = None,
|
|
offset: int = 0,
|
|
mode: str = "neox",
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
r"""Performs: Apply rotary positional embedding to input tensor q/k in `sbhd` or s(b*h)d format, where
|
|
s: sequence length
|
|
b: batch size
|
|
h: head num
|
|
d: dim of each head
|
|
|
|
Note: if cos/sin=None, vacc do RoPE in classical mode, generating cos/sin with arg 'base=10000'
|
|
|
|
Args:
|
|
q (Tensor): Input tensor T is of shape [s, b*h, d]
|
|
k (Tensor): Input tensor T is of shape [s, b*h, d]
|
|
where s is sequence lenth; b is batch size; h is heads; d is head dims
|
|
cos optional(Tensor): Cached cosine of the rotary positional embedding tensor
|
|
with shape [s, 1, d] or [s, d] or [s, 1, d/2] or [s, d/2]
|
|
sin optional(Tensor): Cached sine of the rotary positional embedding tensor.
|
|
with shape [s, 1, d] or [s, d] or [s, 1, d/2] or [s, d/2]
|
|
offset (int): offset of position at cos/sin cache, default=0
|
|
mode (str): 'nexo': rotate half, 'gptj': rotate every two
|
|
|
|
Returns:
|
|
Tuple[Tensor, Tensor]: The tuple of q/k tensor after applying RoPE
|
|
"""
|
|
assert (
|
|
q.dim() == 3 and k.dim() == 3
|
|
), f"the dim of q/k should be 3 but get q:{q.dim()}, k:{k.dim()}"
|
|
assert mode in ["neox", "gptj"], "only support rope mode 'neox' or 'gptj'"
|
|
assert q.dtype == k.dtype, "the dtype should be same"
|
|
    assert (cos is None and sin is None) or (
        isinstance(cos, torch.Tensor) and isinstance(sin, torch.Tensor)
    ), "cos and sin must both be None or both be tensors"
|
|
assert q.device.type == "vacc" and k.device.type == "vacc"
|
|
mode_ = 0 if mode == "gptj" else 1
|
|
|
|
if cos is not None and cos.size(0) != q.size(0):
|
|
cos = cos[offset: q.shape[0] + offset, ...]
|
|
if sin is not None and sin.size(0) != q.size(0):
|
|
sin = sin[offset: q.shape[0] + offset, ...]
|
|
|
|
# repeat last dim as same size with input tensor
|
|
# if cos is not None and cos.numel() != 0:
|
|
# assert cos.size(-1) == q.size(-1) or cos.size(-1) * 2 == q.size(-1)
|
|
# assert cos.dim() == 2 or cos.dim() == 3
|
|
# if cos.size(-1) * 2 == q.size(-1):
|
|
# if mode_ == 0:
|
|
# cos = cos.repeat_interleave(2, dim=-1)
|
|
# sin = sin.repeat_interleave(2, dim=-1)
|
|
# else:
|
|
# cos = torch.cat([cos, cos], dim=-1)
|
|
# sin = torch.cat([sin, sin], dim=-1)
|
|
|
|
if cos is not None and cos.numel() != 0:
|
|
assert cos.size(-1) == q.size(-1) or cos.size(-1) * 2 == q.size(-1)
|
|
assert cos.dim() == 2 or cos.dim() == 3
|
|
if cos.dim() == 2:
|
|
cos = cos.unsqueeze(-2)
|
|
sin = sin.unsqueeze(-2)
|
|
|
|
return torch.ops.vacc.RotaryPosEmbedding_func(q, k, cos, sin, offset, mode_)
|
|
|
|
|
|
RotaryPosEmbedding.apply = RotaryPosEmbedding
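# Minimal pure-PyTorch sketch of the rotation the fused kernel is expected to apply to a
# single tensor, assuming cos/sin are already expanded to the input's last-dim size and are
# broadcastable over the head dimension. It mirrors the conventional RoPE formula and the
# rotate_half/rotate_every_two helpers above; it is illustrative, not the VACC kernel itself.
def _rope_reference(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mode: str = "neox"
) -> torch.Tensor:
    # 'neox' pairs the two halves of the head dim, 'gptj' pairs adjacent elements.
    rotate = rotate_half if mode == "neox" else rotate_every_two
    return t * cos + rotate(t) * sin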
|
|
|
|
|
|
def scaled_dot_product_attention(
|
|
# same with torch define
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
attn_mask: torch.Tensor = None,
|
|
dropout_p: float = 0.5,
|
|
is_causal: bool = False,
|
|
scale: float = None,
|
|
# extend
|
|
is_train: bool = True,
|
|
recompute: bool = False,
|
|
flash_attention: bool = False,
|
|
sm_scale: float = -1,
|
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
|
r"""Performs: Apply attention operation for q tensor with fixed shape[sq, b*h, d],
|
|
kv with fixed shape[sk, b*h, d], while the output is [sq, b*h, d].
|
|
support float16, bfloat16, float32
|
|
sq: sequence length of query
|
|
sk: sequence length of key, value
|
|
b: batch size
|
|
h: head num
|
|
d: dim of each head
|
|
|
|
Args:
|
|
query (Tensor): Input tensor T is of shape [sq, b*h, d]
|
|
key (Tensor): Input tensor T is of shape [sk, b*h, d]
|
|
value (Tensor): Input tensor T is of shape [sk, b*h, d]
|
|
attn_mask (Tensor): masked bool tensor of shape [1, sq, sk] or [sq, sk]
|
|
dropout_p (float): the probability of dropout
|
|
is_causal (bool): accelerate compute when mask is causal type
|
|
scale (float): not use, default is 1/sqrt(dim)
|
|
is_train (bool): train mode or eval mode
|
|
recompute (bool): whether to recompute for reducing memory usage, is valid
|
|
when is_train=True
|
|
flash_attention (bool): using flash attention, that cat support large sequence
|
|
|
|
Returns:
|
|
Tensor: the tensor after self attention
|
|
"""
|
|
assert (
|
|
query.dtype == key.dtype and key.dtype == value.dtype
|
|
), "types of qkv should be same"
|
|
assert query.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
|
|
|
|
    if attn_mask is None:
        attn_mask = torch.Tensor()
|
|
|
|
out = torch.ops.vacc.scaled_dot_product_attention(
|
|
query,
|
|
key,
|
|
value,
|
|
attn_mask,
|
|
dropout_p,
|
|
is_train,
|
|
recompute,
|
|
is_causal,
|
|
flash_attention,
|
|
sm_scale,
|
|
)
|
|
return out[0]
|
|
|
|
|
|
scaled_dot_product_attention.apply = scaled_dot_product_attention
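# Illustrative call sketch (comment only). Shapes follow the docstring above: q is
# [sq, b*h, d] and k/v are [sk, b*h, d]. The "vacc" device and the eval-mode flags below
# are assumptions of this example.
#
#   q = torch.randn(1024, 8 * 16, 128, dtype=torch.float16, device="vacc")
#   k = torch.randn(1024, 8 * 16, 128, dtype=torch.float16, device="vacc")
#   v = torch.randn(1024, 8 * 16, 128, dtype=torch.float16, device="vacc")
#   out = scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True,
#                                      is_train=False, flash_attention=True)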
|
|
|
|
|
|
def swiglu(
|
|
x: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
r"""Perferms:
|
|
x = torch.chunk(x, 2, dim=-1)
|
|
return F.silu(x[0]) * x[1]
|
|
|
|
Args:
|
|
x: Input tensor, support float/float16/bfloat16, 3 dims
|
|
|
|
Return:
|
|
Tensor: the out of swiglu
|
|
"""
|
|
|
|
return torch.ops.vacc.swiglu(x)
|
|
|
|
|
|
swiglu.apply = swiglu
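# Pure-PyTorch reference for what the fused swiglu kernel computes, taken directly from the
# docstring above. Useful for CPU checks; not the VACC implementation.
def _swiglu_reference(x: torch.Tensor) -> torch.Tensor:
    a, b = torch.chunk(x, 2, dim=-1)
    return torch.nn.functional.silu(a) * b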
|
|
|
|
|
|
def scaled_dot_product_attention_cp_forward(
|
|
# same with torch define
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
attn_mask: torch.Tensor = None,
|
|
dropout_p: float = 0.5,
|
|
is_causal: bool = False,
|
|
# extend
|
|
is_train: bool = True,
|
|
):
|
|
r"""Performs: Apply attention operation for q tensor with fixed shape[sq, b*h, d],
|
|
kv with fixed shape[sk, b*h, d], while the output is [sq, b*h, d].
|
|
sq: sequence length of query
|
|
sk: sequence length of key, value
|
|
b: batch size
|
|
h: head num
|
|
d: dim of each head
|
|
|
|
Args:
|
|
query (Tensor): Input fp16/bfp16 tensor T is of shape [sq, b*h, d]
|
|
key (Tensor): Input fp16/bfp16 tensor T is of shape [sk, b*h, d]
|
|
value (Tensor): Input fp16/bfp16 tensor T is of shape [sk, b*h, d]
|
|
attn_mask (Tensor): masked bool tensor of shape [1, sq, sk]
|
|
dropout_p (float): the probability of dropout
|
|
is_causal (bool): accelerate compute when mask is causal type
|
|
scale (float): not use, default is 1/sqrt(dim)
|
|
is_train (bool): train mode or eval mode
|
|
|
|
Returns: list of tensor, size=4
|
|
[attention result,
|
|
max of QK^t with shape of (b*h, sq, d)],
|
|
sum of exp(QK^t) with shape of (b*h, sq, d),
|
|
seed(used in backward)]
|
|
"""
|
|
assert (
|
|
query.dtype == key.dtype and key.dtype == value.dtype
|
|
), "types of qkv should be same"
|
|
assert query.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
|
|
    assert (
        is_causal or attn_mask is not None
    ), "attn_mask should be valid when is_causal is False"
|
|
|
|
if attn_mask is None:
|
|
attn_mask = torch.Tensor()
|
|
else:
|
|
assert attn_mask.size(-2) == query.size(0) and attn_mask.size(-1) == key.size(
|
|
0
|
|
), "attn_mask size should be (..., sq, sk)"
|
|
out = _torch_vacc.scaled_dot_product_attention_cp_forward(
|
|
query,
|
|
key,
|
|
value,
|
|
attn_mask,
|
|
dropout_p,
|
|
is_train,
|
|
is_causal,
|
|
)
|
|
return out
|
|
|
|
|
|
def scaled_dot_product_attention_cp_backward(
|
|
grad_output: torch.Tensor,
|
|
attn_out: torch.Tensor,
|
|
max_of_row: torch.Tensor,
|
|
sum_of_row: torch.Tensor,
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
attn_mask: Optional[torch.Tensor],
|
|
seed: torch.Tensor,
|
|
dropout_p: float = 0.5,
|
|
is_causal: bool = False,
|
|
is_train: bool = True,
|
|
):
|
|
assert (
|
|
query.dtype == key.dtype and key.dtype == value.dtype
|
|
), "types of qkv should be same"
|
|
assert query.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
|
|
    assert (
        is_causal or attn_mask is not None
    ), "attn_mask should be valid when is_causal is False"

    if attn_mask is None:
        attn_mask = torch.Tensor()
|
|
else:
|
|
assert attn_mask.size(-2) == query.size(0) and attn_mask.size(-1) == key.size(
|
|
0
|
|
), "attn_mask size should be (..., sq, sk)"
|
|
|
|
out = _torch_vacc.scaled_dot_product_attention_cp_backward(
|
|
grad_output,
|
|
attn_out,
|
|
max_of_row,
|
|
sum_of_row,
|
|
query,
|
|
key,
|
|
value,
|
|
attn_mask,
|
|
dropout_p,
|
|
is_train,
|
|
is_causal,
|
|
seed,
|
|
)
|
|
return out
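# How the two functions above fit together (illustrative comment only; the variable names
# and the structure of the backward result are assumptions of this sketch, the forward
# unpacking follows its docstring):
#
#   attn_out, row_max, row_sum, seed = scaled_dot_product_attention_cp_forward(
#       q, k, v, attn_mask, dropout_p, is_causal, is_train)
#   grads = scaled_dot_product_attention_cp_backward(
#       grad_out, attn_out, row_max, row_sum, q, k, v, attn_mask, seed,
#       dropout_p, is_causal, is_train)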
|
|
|
|
|
|
def paged_attention(
|
|
query: torch.Tensor,
|
|
key_cache: torch.Tensor,
|
|
value_cache: torch.Tensor,
|
|
block_table: torch.Tensor,
|
|
seq_len: torch.Tensor,
|
|
sm_scale: float = -1,
|
|
out: Optional[torch.Tensor] = None,
|
|
):
|
|
r"""Performs: Apply attention operation for q tensor with fixed shape[batch, b*h, d],
|
|
key_cache/value_cache with fixed shape[nb, bs, h, d], while the output is [batch, h, d].
|
|
batch: token batch
|
|
nb: num_blocks
|
|
bs: block_size
|
|
h: heads num
|
|
d: dim of each head
|
|
|
|
Args:
|
|
query (Tensor): Input fp16/bfp16 tensor T is of shape [batch, h, d]
|
|
key_cache (Tensor): Input fp16/bfp16 tensor T is of shape [nb, bs, h, d]
|
|
value_cache (Tensor): Input fp16/bfp16 tensor T is of shape [nb, bs, h, d]
|
|
block_table (Tensor): k/v map of cache
|
|
seq_len (Tensor): sequence lenth of each batch
|
|
out (optional Tensor): output tensor if given, otherwise return a new tensor
|
|
|
|
Returns:
|
|
Tensor: the tensor after page attention
|
|
"""
|
|
|
|
return _torch_vacc.paged_attention(
|
|
query, key_cache, value_cache, block_table, seq_len, out, sm_scale
|
|
)
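# Shape sketch for paged_attention (comment only, derived from the docstring above):
#
#   query:       [batch, h, d]
#   key_cache:   [nb, bs, h, d]   nb = num_blocks, bs = block_size
#   value_cache: [nb, bs, h, d]
#   block_table: per-sequence mapping from logical positions to cache blocks
#   seq_len:     sequence length of each batch entry
#   output:      [batch, h, d]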
|
|
|
|
|
|
def reshape_and_cache_attention(
|
|
src: torch.Tensor,
|
|
cached: torch.Tensor,
|
|
block_mapping: torch.Tensor,
|
|
):
|
|
torch.ops.vacc.reshape_and_cache_attention(src, cached, block_mapping)
|
|
|
|
|
|
def concat_and_cache_attention(
|
|
src: torch.Tensor,
|
|
src1: torch.Tensor,
|
|
cached: torch.Tensor,
|
|
block_mapping: torch.Tensor,
|
|
):
|
|
_torch_vacc.concat_and_cache_attention(src, src1, cached, block_mapping)
|
|
|
|
|
|
def w8a8_block_fp8_matmul(
|
|
input: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
input_scale: Optional[torch.Tensor],
|
|
weight_scale: Optional[torch.Tensor],
|
|
block_size: List[int],
|
|
output: Optional[torch.Tensor] = None,
|
|
**kwargs,
|
|
):
|
|
return _torch_vacc.w8a8_block_fp8_matmul(
|
|
input,
|
|
weight,
|
|
None,
|
|
weight_scale,
|
|
block_size,
|
|
output,
|
|
)
|
|
|
|
|
|
def w8a8_block_fp8_linear(
|
|
input: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
input_scale: Optional[torch.Tensor],
|
|
weight_scale: Optional[torch.Tensor],
|
|
block_size: List[int],
|
|
output: Optional[torch.Tensor] = None,
|
|
**kwargs,
|
|
):
|
|
return _torch_vacc.w8a8_block_fp8_matmul(
|
|
input, weight.T, None, weight_scale.T, [block_size[1], block_size[0]], output
|
|
)
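# Relationship between the two fp8 block-quantized entry points above (descriptive comment):
# w8a8_block_fp8_linear forwards to the same kernel as w8a8_block_fp8_matmul, but with the
# weight and weight_scale transposed and the block_size pair swapped, i.e. it accepts weights
# stored in linear (output-major) layout. Note that input_scale is accepted by both wrappers
# but not forwarded to the kernel in either of them.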
|
|
|
|
|
|
def moe_expert_token_group_reassign(
|
|
topk_idx: torch.Tensor,
|
|
topk_val: torch.Tensor,
|
|
expert_num_: int,
|
|
gp_size_: int = 16,
|
|
gp_num_align_: int = 4,
|
|
):
|
|
return _torch_vacc.moe_expert_token_group_reassign(
|
|
topk_idx, topk_val, expert_num_, gp_size_, gp_num_align_
|
|
)
|
|
|
|
|
|
def fused_experts(
|
|
hidden_states: torch.Tensor,
|
|
w13_weight: torch.Tensor,
|
|
w2_weight: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
use_fp8_w8a8: bool = True,
|
|
w13_scale: Optional[torch.Tensor] = None,
|
|
w2_scale: Optional[torch.Tensor] = None,
|
|
a13_scale: Optional[torch.Tensor] = None,
|
|
a2_scale: Optional[torch.Tensor] = None,
|
|
block_shape: Optional[List[int]] = None,
|
|
decode_with_batch: bool = False,
|
|
output_opt: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
warning_message = "[fused_experts]:vacc only support fp8 weights now"
|
|
assert a13_scale is None and a2_scale is None, warning_message
|
|
assert use_fp8_w8a8, f"{warning_message}, but use_fp8_w8a8 is {use_fp8_w8a8}"
|
|
assert (
|
|
w13_scale is not None and w2_scale is not None
|
|
), f"{warning_message}, but w13_weight_scale is {w13_scale}, w2_weight_scale is {w2_scale}"
|
|
assert (
|
|
block_shape is not None
|
|
), f"{warning_message}, but block_shape is {block_shape}"
|
|
# assert (
|
|
# not decode_with_batch or hidden_states.size(0) <= 4
|
|
# ), "[fused_experts]:vacc only support batch <= 4 when decode"
|
|
|
|
if hidden_states.device.type == "vacc":
|
|
# topk weights dtype should be same with hidden_states
|
|
topk_weights = topk_weights.to(hidden_states.dtype)
|
|
# vacc device use int32 for experts_id
|
|
topk_ids = topk_ids.to(torch.int32)
|
|
hidden_dims, inter_dims = w13_weight.shape[1], w13_weight.shape[2]
|
|
hidden_blocks, inter_blocks = w13_scale.shape[1], w13_scale.shape[2]
|
|
block_size0, block_size1 = (
|
|
hidden_dims // hidden_blocks,
|
|
inter_dims // inter_blocks,
|
|
)
|
|
# assert (
|
|
# block_size0 == block_size1
|
|
# ), "quant block shape now support size0 == size1"
|
|
return _torch_vacc.fused_experts(
|
|
hidden_states,
|
|
w13_weight,
|
|
w2_weight,
|
|
topk_weights,
|
|
topk_ids,
|
|
use_fp8_w8a8,
|
|
w13_scale,
|
|
w2_scale,
|
|
a13_scale,
|
|
a2_scale,
|
|
[block_size0, block_size1],
|
|
decode_with_batch,
|
|
output_opt,
|
|
)
|
|
|
|
from .custom_ops_cpu import fused_experts as fused_experts_default_method
|
|
|
|
return fused_experts_default_method(
|
|
hidden_states,
|
|
w13_weight,
|
|
w2_weight,
|
|
topk_weights,
|
|
topk_ids,
|
|
use_fp8_w8a8,
|
|
w13_scale,
|
|
w2_scale,
|
|
a13_scale,
|
|
a2_scale,
|
|
block_shape,
|
|
decode_with_batch,
|
|
)
|
|
|
|
|
|
# NOTE: w13 and w2 use the linear weight layout
|
|
def fused_mlp_fp8(
|
|
hidden_states: torch.Tensor,
|
|
w13_weight: torch.Tensor,
|
|
w2_weight: torch.Tensor,
|
|
use_fp8_w8a8: bool = True,
|
|
w13_scale: Optional[torch.Tensor] = None,
|
|
w2_scale: Optional[torch.Tensor] = None,
|
|
a13_scale: Optional[torch.Tensor] = None,
|
|
a2_scale: Optional[torch.Tensor] = None,
|
|
block_shape_w13: Optional[List[int]] = None,
|
|
block_shape_w2: Optional[List[int]] = None,
|
|
output: Optional[torch.Tensor] = None,
|
|
):
|
|
assert a13_scale is None
|
|
assert a2_scale is None
|
|
assert w13_scale is not None
|
|
assert w2_scale is not None
|
|
assert block_shape_w13 is not None
|
|
assert block_shape_w2 is not None
|
|
|
|
if hidden_states.device.type == "vacc":
|
|
return _torch_vacc.fused_mlp(
|
|
hidden_states,
|
|
w13_weight,
|
|
w2_weight,
|
|
use_fp8_w8a8,
|
|
w13_scale,
|
|
w2_scale,
|
|
a13_scale,
|
|
a2_scale,
|
|
block_shape_w13,
|
|
block_shape_w2,
|
|
output,
|
|
)
|
|
|
|
from .custom_ops_cpu import fused_mlp_mm_fp8 as fused_mlp_mm_fp8_default_method
|
|
|
|
return fused_mlp_mm_fp8_default_method(
|
|
hidden_states,
|
|
w13_weight.T,
|
|
w2_weight.T,
|
|
use_fp8_w8a8,
|
|
w13_scale.T,
|
|
w2_scale.T,
|
|
a13_scale,
|
|
a2_scale,
|
|
list(block_shape_w13)[::-1],
|
|
list(block_shape_w2)[::-1],
|
|
)
|
|
|
|
|
|
def fused_mlp_mm_fp8(
|
|
hidden_states: torch.Tensor,
|
|
w13_weight: torch.Tensor,
|
|
w2_weight: torch.Tensor,
|
|
use_fp8_w8a8: bool = True,
|
|
w13_scale: Optional[torch.Tensor] = None,
|
|
w2_scale: Optional[torch.Tensor] = None,
|
|
a13_scale: Optional[torch.Tensor] = None,
|
|
a2_scale: Optional[torch.Tensor] = None,
|
|
block_shape_w13: Optional[List[int]] = None,
|
|
block_shape_w2: Optional[List[int]] = None,
|
|
output: Optional[torch.Tensor] = None,
|
|
):
|
|
return fused_mlp_fp8(
|
|
hidden_states,
|
|
w13_weight.T,
|
|
w2_weight.T,
|
|
use_fp8_w8a8,
|
|
w13_scale.T,
|
|
w2_scale.T,
|
|
a13_scale,
|
|
a2_scale,
|
|
list(block_shape_w13)[::-1],
|
|
list(block_shape_w2)[::-1],
|
|
output,
|
|
)
|
|
|
|
|
|
def fused_moe_preprocess(
|
|
gating_output,
|
|
bias,
|
|
num_expert_group=8,
|
|
num_limited_group=4,
|
|
):
|
|
return _torch_vacc.fused_moe_preprocess(
|
|
gating_output, bias, num_expert_group, num_limited_group
|
|
)
|
|
|
|
|
|
def fused_residual_rmsnorm(
|
|
input: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
residual: Optional[torch.Tensor] = None,
|
|
epsilon: float = 1e-6,
|
|
output: Optional[torch.Tensor] = None,
|
|
residual_out: Optional[torch.Tensor] = None,
|
|
):
|
|
# out = _torch_vacc.fused_residual_rmsnorm(input, weight, residual, epsilon, inplace)
|
|
# if len(out) == 1:
|
|
# return out[0]
|
|
# return out
|
|
|
|
# TODO: VNNL support optional residual
|
|
input = input.contiguous()
|
|
weight = weight.contiguous()
|
|
if residual is None:
|
|
return rms_norm(input, weight, epsilon)
|
|
else:
|
|
return torch.ops.vacc.fused_residual_rmsnorm(
|
|
input,
|
|
weight,
|
|
residual,
|
|
epsilon,
|
|
output,
|
|
residual_out,
|
|
)
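# Usage note (descriptive comment): when residual is None this falls back to the plain
# rms_norm above and returns a single tensor; when a residual is given, the fused op is
# invoked and, per the fake registration further below, a pair of tensors (normalized
# output, updated residual) is expected.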
|
|
|
|
|
|
def parallel_embedding(
|
|
input: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
org_vocab_start_index: int,
|
|
org_vocab_end_index: int,
|
|
num_org_vocab_padding: int,
|
|
added_vocab_start_index: int,
|
|
added_vocab_end_index: int,
|
|
output: Optional[torch.Tensor] = None,
|
|
):
|
|
return _torch_vacc.parallel_embedding(
|
|
input,
|
|
weight,
|
|
org_vocab_start_index,
|
|
org_vocab_end_index,
|
|
num_org_vocab_padding,
|
|
added_vocab_start_index,
|
|
added_vocab_end_index,
|
|
output,
|
|
)
|
|
|
|
|
|
def fused_mla(
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
hidden_states_norm_weight: torch.Tensor,
|
|
q_a_proj_weight: torch.Tensor,
|
|
q_a_proj_weight_scale_inv: torch.Tensor,
|
|
q_a_layernorm_weight: torch.Tensor,
|
|
W_Q: torch.Tensor,
|
|
W_UK: torch.Tensor,
|
|
W_QR: torch.Tensor,
|
|
kv_a_proj_weight_scale_inv: torch.Tensor,
|
|
kv_a_proj_weight: torch.Tensor,
|
|
kv_a_layernorm_weight: torch.Tensor,
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: torch.Tensor,
|
|
block_tables: torch.Tensor,
|
|
W_UV: torch.Tensor,
|
|
o_proj_weight_scale_inv: torch.Tensor,
|
|
o_proj_weight: torch.Tensor,
|
|
q_a_proj_blocksize: Tuple[int] | List[int],
|
|
kv_a_proj_blocksize: Tuple[int] | List[int],
|
|
o_proj_blocksize: Tuple[int] | List[int],
|
|
seq_lens: Tuple[int] | List[int],
|
|
sm_scale: float,
|
|
head_num: int,
|
|
):
|
|
# TODO: CHECK
|
|
out_single = False
|
|
if residual is None:
|
|
out_single = True
|
|
residual = torch.Tensor()
|
|
|
|
out = _torch_vacc.fused_mla(
|
|
hidden_states,
|
|
residual,
|
|
hidden_states_norm_weight,
|
|
q_a_proj_weight,
|
|
q_a_proj_weight_scale_inv,
|
|
q_a_layernorm_weight,
|
|
W_Q,
|
|
W_UK,
|
|
W_QR,
|
|
kv_a_proj_weight_scale_inv,
|
|
kv_a_proj_weight,
|
|
kv_a_layernorm_weight,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
block_tables,
|
|
W_UV,
|
|
o_proj_weight_scale_inv,
|
|
o_proj_weight,
|
|
q_a_proj_blocksize,
|
|
kv_a_proj_blocksize,
|
|
o_proj_blocksize,
|
|
seq_lens,
|
|
sm_scale,
|
|
head_num,
|
|
)
|
|
if out_single:
|
|
return out[0]
|
|
return out
|
|
|
|
|
|
def fused_mla_v2(
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
hidden_states_norm_weight: torch.Tensor,
|
|
q_a_proj_weight: torch.Tensor,
|
|
q_a_proj_weight_scale_inv: torch.Tensor,
|
|
q_a_layernorm_weight: torch.Tensor,
|
|
w_q: torch.Tensor,
|
|
w_q_scale: torch.Tensor,
|
|
w_uk: torch.Tensor,
|
|
w_uk_scale: torch.Tensor,
|
|
w_qr: torch.Tensor,
|
|
w_qr_scale: torch.Tensor,
|
|
kv_a_layernorm_weight: torch.Tensor,
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: torch.Tensor,
|
|
block_tables: torch.Tensor,
|
|
block_group_size: int,
|
|
w_uv: torch.Tensor,
|
|
w_uv_scale: torch.Tensor,
|
|
o_proj_weight: torch.Tensor,
|
|
o_proj_weight_scale_inv: torch.Tensor,
|
|
seq_lens: Tuple[int] | List[int],
|
|
sm_scale: float,
|
|
head_num: int,
|
|
    flash_attention: bool = False,
|
|
):
|
|
    # TODO: CHECK
|
|
out_single = False
|
|
if residual is None:
|
|
out_single = True
|
|
residual = torch.Tensor()
|
|
|
|
out = _torch_vacc.fused_mla_v2(
|
|
hidden_states,
|
|
residual,
|
|
hidden_states_norm_weight,
|
|
q_a_proj_weight,
|
|
q_a_proj_weight_scale_inv,
|
|
q_a_layernorm_weight,
|
|
w_q,
|
|
w_q_scale,
|
|
w_uk,
|
|
w_uk_scale,
|
|
w_qr,
|
|
w_qr_scale,
|
|
kv_a_layernorm_weight,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
block_tables,
|
|
block_group_size,
|
|
w_uv,
|
|
w_uv_scale,
|
|
o_proj_weight,
|
|
o_proj_weight_scale_inv,
|
|
seq_lens,
|
|
sm_scale,
|
|
head_num,
|
|
flash_attention,
|
|
)
|
|
if out_single:
|
|
return out[0]
|
|
return out
|
|
|
|
|
|
def fused_mla_prefill_stage0(
|
|
hidden_states: torch.Tensor,
|
|
residual: torch.Tensor,
|
|
hidden_states_norm_weight: torch.Tensor,
|
|
qkv_a_proj_weight: torch.Tensor,
|
|
qkv_a_proj_weight_scale_inv: torch.Tensor,
|
|
):
|
|
    # TODO: CHECK
|
|
out_single = False
|
|
if residual is None:
|
|
out_single = True
|
|
residual = torch.Tensor()
|
|
out = _torch_vacc.fused_mla_prefill_stage0(
|
|
hidden_states,
|
|
residual,
|
|
hidden_states_norm_weight,
|
|
qkv_a_proj_weight,
|
|
qkv_a_proj_weight_scale_inv,
|
|
)
|
|
|
|
if out_single:
|
|
return out[0]
|
|
return out
|
|
|
|
|
|
def fused_mla_prefill_stage1(
|
|
qkv_a: torch.Tensor,
|
|
q_a_layernorm_weight: torch.Tensor,
|
|
q_proj_weight: torch.Tensor,
|
|
q_proj_weight_scale_inv: torch.Tensor,
|
|
kv_a_layernorm_weight: torch.Tensor,
|
|
kv_b_proj_weight: torch.Tensor,
|
|
kv_b_proj_weight_scale_inv: torch.Tensor,
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: torch.Tensor,
|
|
o_proj_weight: torch.Tensor,
|
|
o_proj_weight_scale_inv: torch.Tensor,
|
|
seq_lens_num: List[int],
|
|
sm_scale: float,
|
|
num_head: int,
|
|
mla_out_tensor: Optional[torch.Tensor] = None
|
|
):
|
|
out = _torch_vacc.fused_mla_prefill_stage1(
|
|
qkv_a,
|
|
q_a_layernorm_weight,
|
|
q_proj_weight,
|
|
q_proj_weight_scale_inv,
|
|
kv_a_layernorm_weight,
|
|
kv_b_proj_weight,
|
|
kv_b_proj_weight_scale_inv,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
o_proj_weight,
|
|
o_proj_weight_scale_inv,
|
|
seq_lens_num,
|
|
sm_scale,
|
|
num_head,
|
|
mla_out_tensor
|
|
)
|
|
|
|
return out
|
|
|
|
|
|
def fused_mla_allreduce(
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
hidden_states_norm_weight: torch.Tensor,
|
|
q_a_proj_weight: torch.Tensor,
|
|
q_a_proj_weight_scale_inv: torch.Tensor,
|
|
q_a_layernorm_weight: torch.Tensor,
|
|
W_Q: torch.Tensor,
|
|
W_UK: torch.Tensor,
|
|
W_QR: torch.Tensor,
|
|
kv_a_proj_weight_scale_inv: torch.Tensor,
|
|
kv_a_proj_weight: torch.Tensor,
|
|
kv_a_layernorm_weight: torch.Tensor,
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: torch.Tensor,
|
|
block_tables: torch.Tensor,
|
|
W_UV: torch.Tensor,
|
|
o_proj_weight_scale_inv: torch.Tensor,
|
|
o_proj_weight: torch.Tensor,
|
|
q_a_proj_blocksize: Tuple[int] | List[int],
|
|
kv_a_proj_blocksize: Tuple[int] | List[int],
|
|
o_proj_blocksize: Tuple[int] | List[int],
|
|
seq_lens: Tuple[int] | List[int],
|
|
sm_scale: float,
|
|
head_num: int,
|
|
red_op_type: int,
|
|
world_size: int,
|
|
rank: int,
|
|
root_rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
# TODO: CHECK
|
|
assert red_op_type == 0, "all_reduce only support red_op_type=0"
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
out_single = False
|
|
if residual is None:
|
|
out_single = True
|
|
residual = torch.Tensor()
|
|
|
|
out = _torch_vacc.fused_mla_allreduce(
|
|
hidden_states,
|
|
residual,
|
|
hidden_states_norm_weight,
|
|
q_a_proj_weight,
|
|
q_a_proj_weight_scale_inv,
|
|
q_a_layernorm_weight,
|
|
W_Q,
|
|
W_UK,
|
|
W_QR,
|
|
kv_a_proj_weight_scale_inv,
|
|
kv_a_proj_weight,
|
|
kv_a_layernorm_weight,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
block_tables,
|
|
W_UV,
|
|
o_proj_weight_scale_inv,
|
|
o_proj_weight,
|
|
q_a_proj_blocksize,
|
|
kv_a_proj_blocksize,
|
|
o_proj_blocksize,
|
|
seq_lens,
|
|
sm_scale,
|
|
head_num,
|
|
red_op_type,
|
|
world_size,
|
|
rank,
|
|
root_rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
if out_single:
|
|
return out[0]
|
|
return out
|
|
|
|
|
|
def fused_mla_allreduce_v2(
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
hidden_states_norm_weight: torch.Tensor,
|
|
q_a_proj_weight: torch.Tensor,
|
|
q_a_proj_weight_scale_inv: torch.Tensor,
|
|
q_a_layernorm_weight: torch.Tensor,
|
|
w_q: torch.Tensor,
|
|
w_q_scale: torch.Tensor,
|
|
w_uk: torch.Tensor,
|
|
w_uk_scale: torch.Tensor,
|
|
w_qr: torch.Tensor,
|
|
w_qr_scale: torch.Tensor,
|
|
kv_a_layernorm_weight: torch.Tensor,
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: torch.Tensor,
|
|
block_tables: torch.Tensor,
|
|
block_group_size: int,
|
|
w_uv: torch.Tensor,
|
|
w_uv_scale: torch.Tensor,
|
|
o_proj_weight: torch.Tensor,
|
|
o_proj_weight_scale_inv: torch.Tensor,
|
|
seq_lens: Tuple[int] | List[int],
|
|
sm_scale: float,
|
|
head_num: int,
|
|
flash_attention: bool,
|
|
world_size: int,
|
|
rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
# TODO: CHECK
|
|
assert world_size > 0, "fused_mla_allreduce_v2 only support world_size > 0"
|
|
assert rank >= 0, "fused_mla_allreduce_v2 only support rank >= 0"
|
|
if not dev_info:
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
out_single = False
|
|
if residual is None:
|
|
out_single = True
|
|
residual = torch.Tensor()
|
|
|
|
out = _torch_vacc.fused_mla_allreduce_v2(
|
|
hidden_states,
|
|
residual,
|
|
hidden_states_norm_weight,
|
|
q_a_proj_weight,
|
|
q_a_proj_weight_scale_inv,
|
|
q_a_layernorm_weight,
|
|
w_q,
|
|
w_q_scale,
|
|
w_uk,
|
|
w_uk_scale,
|
|
w_qr,
|
|
w_qr_scale,
|
|
kv_a_layernorm_weight,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
block_tables,
|
|
block_group_size,
|
|
w_uv,
|
|
w_uv_scale,
|
|
o_proj_weight,
|
|
o_proj_weight_scale_inv,
|
|
seq_lens,
|
|
sm_scale,
|
|
head_num,
|
|
flash_attention,
|
|
world_size,
|
|
rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
if out_single:
|
|
return out[0]
|
|
return out
|
|
|
|
|
|
def fused_mla_prefill_stage0_allreduce(
|
|
hidden_states: torch.Tensor,
|
|
residual: torch.Tensor,
|
|
hidden_states_norm_weight: torch.Tensor,
|
|
qkv_a_proj_weight: torch.Tensor,
|
|
qkv_a_proj_weight_scale_inv: torch.Tensor,
|
|
world_size: int,
|
|
rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
    # TODO: CHECK
|
|
out_single = False
|
|
if residual is None:
|
|
out_single = True
|
|
residual = torch.Tensor()
|
|
out = _torch_vacc.fused_mla_prefill_stage0_allreduce(
|
|
hidden_states,
|
|
residual,
|
|
hidden_states_norm_weight,
|
|
qkv_a_proj_weight,
|
|
qkv_a_proj_weight_scale_inv,
|
|
world_size,
|
|
rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
|
|
if out_single:
|
|
return out[0]
|
|
return out
|
|
|
|
|
|
def all_reduce(
|
|
input: torch.Tensor,
|
|
rank: int,
|
|
world_size: int,
|
|
group_id: int,
|
|
dev_info: List[int],
|
|
red_op_type: int = 0,
|
|
):
|
|
assert input.device.type == "vacc", "all_reduce only support VACC"
|
|
assert red_op_type == 0, "all_reduce only support red_op_type=0"
|
|
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
return _torch_vacc.all_reduce(
|
|
input, rank, world_size, 0, group_id, dev_info, red_op_type
|
|
)
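# Note on the dev_info default shared by the collective wrappers in this file: when an empty
# list is passed, it is filled with [i | (i << 16) for i in range(world_size)], i.e. each
# 32-bit entry repeats the rank index in its low and high 16 bits. The exact meaning of the
# two half-words is an assumption here; consult the VCCL documentation for the real encoding.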
|
|
|
|
|
|
def all_gather(
|
|
input: torch.Tensor,
|
|
rank: int,
|
|
world_size: int,
|
|
group_id: int,
|
|
dev_info: List[int],
|
|
):
|
|
assert input.device.type == "vacc", "all_gather only support VACC"
|
|
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
return _torch_vacc.all_gather(input, rank, world_size, 0, group_id, dev_info, 0)
|
|
|
|
|
|
def broadcast(
|
|
input: torch.Tensor,
|
|
rank: int,
|
|
world_size: int,
|
|
root_rank: int,
|
|
group_id: int,
|
|
dev_info: List[int],
|
|
):
|
|
assert input.device.type == "vacc", "broadcast only support VACC"
|
|
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
return _torch_vacc.all_gather(
|
|
input, rank, world_size, root_rank, group_id, dev_info, 1
|
|
)
|
|
|
|
|
|
def fused_mlp_moe_with_rmsnorm(
|
|
hidden_states: torch.Tensor,
|
|
rms_residual: torch.Tensor,
|
|
rms_weight: torch.Tensor,
|
|
mlp_weight_13: torch.Tensor,
|
|
mlp_weight_2: torch.Tensor,
|
|
mlp_weight_scale_13: torch.Tensor,
|
|
mlp_weight_scale_2: torch.Tensor,
|
|
moe_weight_13: torch.Tensor,
|
|
moe_weight_2: torch.Tensor,
|
|
moe_weight_scale_13: torch.Tensor,
|
|
moe_weight_scale_2: torch.Tensor,
|
|
mm_weight: torch.Tensor,
|
|
moe_bias: torch.Tensor,
|
|
mlp_block_size_w13: List[int] | Tuple[int],
|
|
mlp_block_size_w2: List[int] | Tuple[int],
|
|
moe_block_size_w13: List[int] | Tuple[int],
|
|
moe_block_size_w2: List[int] | Tuple[int],
|
|
):
|
|
return _torch_vacc.fused_mlp_moe_with_rmsnorm(
|
|
hidden_states,
|
|
rms_residual,
|
|
rms_weight,
|
|
mlp_weight_13,
|
|
mlp_weight_2,
|
|
mlp_weight_scale_13,
|
|
mlp_weight_scale_2,
|
|
moe_weight_13,
|
|
moe_weight_2,
|
|
moe_weight_scale_13,
|
|
moe_weight_scale_2,
|
|
mm_weight,
|
|
moe_bias,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
moe_block_size_w13,
|
|
moe_block_size_w2,
|
|
)
|
|
|
|
|
|
def fused_mlp_with_rmsnorm(
|
|
hidden_states: torch.Tensor,
|
|
rms_residual: torch.Tensor,
|
|
rms_weight: torch.Tensor,
|
|
mlp_weight_13: torch.Tensor,
|
|
mlp_weight_2: torch.Tensor,
|
|
mlp_weight_scale_13: torch.Tensor,
|
|
mlp_weight_scale_2: torch.Tensor,
|
|
mlp_block_size_w13: List[int] | Tuple[int],
|
|
mlp_block_size_w2: List[int] | Tuple[int],
|
|
):
|
|
return _torch_vacc.fused_mlp_with_rmsnorm(
|
|
hidden_states,
|
|
rms_residual,
|
|
rms_weight,
|
|
mlp_weight_13,
|
|
mlp_weight_2,
|
|
mlp_weight_scale_13,
|
|
mlp_weight_scale_2,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
)
|
|
|
|
|
|
def fuse_moe_decode_v2_allreduce(
|
|
hidden_states: torch.Tensor,
|
|
rms_residual: torch.Tensor,
|
|
rms_weight: torch.Tensor,
|
|
mlp_weight_13: torch.Tensor,
|
|
mlp_weight_2: torch.Tensor,
|
|
mlp_weight_scale_13: torch.Tensor,
|
|
mlp_weight_scale_2: torch.Tensor,
|
|
moe_weight_13: torch.Tensor,
|
|
moe_weight_2: torch.Tensor,
|
|
moe_weight_scale_13: torch.Tensor,
|
|
moe_weight_scale_2: torch.Tensor,
|
|
mm_weight: torch.Tensor,
|
|
moe_bias: torch.Tensor,
|
|
mlp_block_size_w13: List[int] | Tuple[int],
|
|
mlp_block_size_w2: List[int] | Tuple[int],
|
|
moe_block_size_w13: List[int] | Tuple[int],
|
|
moe_block_size_w2: List[int] | Tuple[int],
|
|
red_op_type: int,
|
|
world_size: int,
|
|
rank: int,
|
|
root_rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
assert (
|
|
hidden_states.device.type == "vacc"
|
|
), "fuse_moe_decode_v2_allreduce only support VACC"
|
|
assert red_op_type == 0, "fuse_moe_decode_v2_allreduce only support red_op_type=0"
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
return _torch_vacc.fuse_moe_decode_v2_allreduce(
|
|
hidden_states,
|
|
rms_residual,
|
|
rms_weight,
|
|
mlp_weight_13,
|
|
mlp_weight_2,
|
|
mlp_weight_scale_13,
|
|
mlp_weight_scale_2,
|
|
moe_weight_13,
|
|
moe_weight_2,
|
|
moe_weight_scale_13,
|
|
moe_weight_scale_2,
|
|
mm_weight,
|
|
moe_bias,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
moe_block_size_w13,
|
|
moe_block_size_w2,
|
|
red_op_type,
|
|
world_size,
|
|
rank,
|
|
root_rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
|
|
|
|
def topk_topp(
|
|
logits: torch.Tensor,
|
|
p: torch.Tensor,
|
|
k: torch.Tensor,
|
|
):
|
|
return _torch_vacc.topk_topp(logits, p, k)
|
|
|
|
|
|
def fused_mlp_allreduce(
|
|
hidden_states: torch.Tensor,
|
|
rms_residual: torch.Tensor,
|
|
rms_weight: torch.Tensor,
|
|
mlp_weight_13: torch.Tensor,
|
|
mlp_weight_2: torch.Tensor,
|
|
mlp_weight_scale_13: torch.Tensor,
|
|
mlp_weight_scale_2: torch.Tensor,
|
|
mlp_block_size_w13: List[int] | Tuple[int],
|
|
mlp_block_size_w2: List[int] | Tuple[int],
|
|
red_op_type: int,
|
|
world_size: int,
|
|
rank: int,
|
|
root_rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
assert hidden_states.device.type == "vacc", "fused_mlp_allreduce only support VACC"
|
|
assert red_op_type == 0, "fused_mlp_allreduce only support red_op_type=0"
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
return _torch_vacc.fused_mlp_allreduce(
|
|
hidden_states,
|
|
rms_residual,
|
|
rms_weight,
|
|
mlp_weight_13,
|
|
mlp_weight_2,
|
|
mlp_weight_scale_13,
|
|
mlp_weight_scale_2,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
red_op_type,
|
|
world_size,
|
|
rank,
|
|
root_rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
|
|
|
|
def mla_matmul_scale(
|
|
input: torch.Tensor, weight: torch.Tensor, scale: float, align_seq_len: int = 1024
|
|
):
|
|
return _torch_vacc.mla_matmul_scale(
|
|
input,
|
|
weight,
|
|
scale,
|
|
align_seq_len,
|
|
)
|
|
|
|
|
|
def mla_matmul(
|
|
input: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
):
|
|
return _torch_vacc.mla_matmul(
|
|
input,
|
|
weight,
|
|
)
|
|
|
|
|
|
def ds3_sampler(
|
|
src, p, k, temperatures, exponential_enable, generator: Optional[Generator] = None
|
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
|
return _torch_vacc.ds3_sampler(
|
|
src, p, k, temperatures, exponential_enable, generator
|
|
)
|
|
|
|
def sampler_v1(
|
|
src, p, k, temperatures, all_greedy, all_random, generators: dict[int, Optional[torch.Generator]] = {}
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
if not isinstance(generators, dict):
|
|
raise TypeError(f"generator must be a dictionary, got {type(generators)}")
|
|
return _torch_vacc.sampler_v1(
|
|
src, p, k, temperatures, all_greedy, all_random, generators
|
|
)
|
|
|
|
def apply_penalties(
|
|
src_logits, src_tokens, buf_bin_counts, vocab_size, num_tokens,
|
|
frequency_penalties, presence_penalties, is_first_calculation
|
|
) -> Tuple[torch.Tensor]:
|
|
|
|
return _torch_vacc.apply_penalties(src_logits, src_tokens, buf_bin_counts, vocab_size, num_tokens, frequency_penalties, presence_penalties, is_first_calculation)
|
|
|
|
def rejection_sampler(
|
|
target_with_bonus_probs, bonus_token_ids, draft_probs, draft_token_ids, gen_seed, generator: Optional[Generator] = None
|
|
):
|
|
return _torch_vacc.rejection_sampler(
|
|
target_with_bonus_probs, bonus_token_ids, draft_probs, draft_token_ids, gen_seed, generator
|
|
)
|
|
|
|
def rejection_sampler_update_hidden_states(
|
|
hidden_states, accepted_index
|
|
):
|
|
return _torch_vacc.rejection_sampler_update_hidden_states(
|
|
hidden_states, accepted_index
|
|
)
|
|
|
|
def rejection_sampler_v1(
|
|
target_logits, draft_token_ids, bonus_token_ids, temperature, top_p, top_k, all_greedy, all_random, generators: dict[int, Optional[torch.Generator]]
|
|
):
|
|
return _torch_vacc.rejection_sampler_v1(
|
|
target_logits, draft_token_ids, bonus_token_ids, temperature, top_p, top_k, all_greedy, all_random, generators
|
|
)
|
|
|
|
def fused_matmul_allgather(
|
|
input: torch.Tensor,
|
|
mat2: torch.Tensor,
|
|
world_size: int,
|
|
rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] = None,
|
|
) -> torch.Tensor:
|
|
assert input.device.type == "vacc", "fused_matmul_allgather only support VACC"
|
|
assert mat2.device.type == "vacc", "fused_matmul_allgather only support VACC"
|
|
assert 2 == input.ndim, "fused_matmul_allgather: 'input' must be 2D tensor"
|
|
assert 2 == mat2.ndim, "fused_matmul_allgather: 'mat2' must be 2D tensor"
|
|
assert (
|
|
input.shape[-1] == mat2.shape[0]
|
|
), "fused_matmul_allgather: dim1 of 'input' must be equal to dim0 of 'mat2'"
|
|
assert world_size > 0, "world_size must be greater than 0"
|
|
|
|
if dev_info is None or 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
return _torch_vacc.fused_matmul_allgather(
|
|
input,
|
|
mat2,
|
|
world_size,
|
|
rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
|
|
|
|
def fuse_moe_prefill_stage0(
|
|
hidden_states,
|
|
rms_residual,
|
|
rms_weight,
|
|
mlp_weight_13,
|
|
mlp_weight_2,
|
|
mlp_weight_scale_13,
|
|
mlp_weight_scale_2,
|
|
mm_weight,
|
|
moe_bias,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
rms_hidden_state_opt: Optional[torch.Tensor] = None,
|
|
mlp_hidden_state_opt: Optional[torch.Tensor] = None,
|
|
topk_ids_opt: Optional[torch.Tensor] = None,
|
|
topk_weight_opt: Optional[torch.Tensor] = None,
|
|
):
|
|
return _torch_vacc.fuse_moe_prefill_stage0(
|
|
hidden_states,
|
|
rms_residual,
|
|
rms_weight,
|
|
mlp_weight_13,
|
|
mlp_weight_2,
|
|
mlp_weight_scale_13,
|
|
mlp_weight_scale_2,
|
|
mm_weight,
|
|
moe_bias,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
rms_hidden_state_opt,
|
|
mlp_hidden_state_opt,
|
|
topk_ids_opt,
|
|
topk_weight_opt,
|
|
)
|
|
|
|
|
|
def fuse_mla_mlp_v2_allreduce_decode(
|
|
# input
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
# mla weight
|
|
hidden_states_norm_weight: torch.Tensor,
|
|
q_a_proj_weight: torch.Tensor,
|
|
q_a_proj_weight_scale_inv: torch.Tensor,
|
|
q_a_layernorm_weight: torch.Tensor,
|
|
w_q: torch.Tensor,
|
|
w_q_scale: torch.Tensor,
|
|
w_uk: torch.Tensor,
|
|
w_uk_scale: torch.Tensor,
|
|
w_qr: torch.Tensor,
|
|
w_qr_scale: torch.Tensor,
|
|
kv_a_layernorm_weight: torch.Tensor,
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: torch.Tensor,
|
|
block_tables: torch.Tensor,
|
|
block_group_size: int,
|
|
w_uv: torch.Tensor,
|
|
w_uv_scale: torch.Tensor,
|
|
o_proj_weight: torch.Tensor,
|
|
o_proj_weight_scale_inv: torch.Tensor,
|
|
# mla params
|
|
seq_lens: Tuple[int] | List[int],
|
|
sm_scale: float,
|
|
head_num: int,
|
|
flash_attention: bool,
|
|
# mlp weight
|
|
rms_weight: torch.Tensor,
|
|
mlp_weight_13: torch.Tensor,
|
|
mlp_weight_2: torch.Tensor,
|
|
mlp_weight_scale_13: torch.Tensor,
|
|
mlp_weight_scale_2: torch.Tensor,
|
|
# mlp params
|
|
mlp_block_size_w13: List[int] | Tuple[int],
|
|
mlp_block_size_w2: List[int] | Tuple[int],
|
|
# vccl info
|
|
world_size: int,
|
|
rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
# TODO: CHECK
|
|
assert (
|
|
hidden_states.device.type == "vacc"
|
|
), "fuse_mla_mlp_v2_allreduce_decode only support VACC"
|
|
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
out_single = False
|
|
if residual is None:
|
|
out_single = True
|
|
residual = torch.Tensor()
|
|
|
|
out = _torch_vacc.fuse_mla_mlp_v2_allreduce_decode(
|
|
hidden_states,
|
|
residual,
|
|
hidden_states_norm_weight,
|
|
q_a_proj_weight,
|
|
q_a_proj_weight_scale_inv,
|
|
q_a_layernorm_weight,
|
|
w_q,
|
|
w_q_scale,
|
|
w_uk,
|
|
w_uk_scale,
|
|
w_qr,
|
|
w_qr_scale,
|
|
kv_a_layernorm_weight,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
block_tables,
|
|
block_group_size,
|
|
w_uv,
|
|
w_uv_scale,
|
|
o_proj_weight,
|
|
o_proj_weight_scale_inv,
|
|
seq_lens,
|
|
sm_scale,
|
|
head_num,
|
|
flash_attention,
|
|
rms_weight,
|
|
mlp_weight_13,
|
|
mlp_weight_2,
|
|
mlp_weight_scale_13,
|
|
mlp_weight_scale_2,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
world_size,
|
|
rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
# if out_single:
|
|
# return out[0]
|
|
return out
|
|
|
|
|
|
def fuse_mla_moe_v2_allreduce_decode(
|
|
# input
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
# mla weight
|
|
hidden_states_norm_weight: torch.Tensor,
|
|
q_a_proj_weight: torch.Tensor,
|
|
q_a_proj_weight_scale_inv: torch.Tensor,
|
|
q_a_layernorm_weight: torch.Tensor,
|
|
w_q: torch.Tensor,
|
|
w_q_scale: torch.Tensor,
|
|
w_uk: torch.Tensor,
|
|
w_uk_scale: torch.Tensor,
|
|
w_qr: torch.Tensor,
|
|
w_qr_scale: torch.Tensor,
|
|
kv_a_layernorm_weight: torch.Tensor,
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: torch.Tensor,
|
|
block_tables: torch.Tensor,
|
|
block_group_size: int,
|
|
w_uv: torch.Tensor,
|
|
w_uv_scale: torch.Tensor,
|
|
o_proj_weight: torch.Tensor,
|
|
o_proj_weight_scale_inv: torch.Tensor,
|
|
# mla params
|
|
seq_lens: Tuple[int] | List[int],
|
|
sm_scale: float,
|
|
head_num: int,
|
|
flash_attention: bool,
|
|
# moe weight
|
|
rms_weight: torch.Tensor,
|
|
mlp_weight_13: torch.Tensor,
|
|
mlp_weight_2: torch.Tensor,
|
|
mlp_weight_scale_13: torch.Tensor,
|
|
mlp_weight_scale_2: torch.Tensor,
|
|
moe_weight_13: torch.Tensor,
|
|
moe_weight_2: torch.Tensor,
|
|
moe_weight_scale_13: torch.Tensor,
|
|
moe_weight_scale_2: torch.Tensor,
|
|
mm_weight: torch.Tensor,
|
|
moe_bias: torch.Tensor,
|
|
# moe params
|
|
mlp_block_size_w13: Tuple[int] | List[int],
|
|
mlp_block_size_w2: Tuple[int] | List[int],
|
|
moe_block_size_w13: Tuple[int] | List[int],
|
|
moe_block_size_w2: Tuple[int] | List[int],
|
|
# vccl info
|
|
world_size: int,
|
|
rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
# TODO: CHECK
|
|
assert (
|
|
hidden_states.device.type == "vacc"
|
|
), "fuse_mla_moe_v2_allreduce_decode only support VACC"
|
|
# assert red_op_type == 0, "all_reduce only support red_op_type=0"
|
|
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
out_single = False
|
|
if residual is None:
|
|
out_single = True
|
|
residual = torch.Tensor()
|
|
|
|
out = _torch_vacc.fuse_mla_moe_v2_allreduce_decode(
|
|
hidden_states,
|
|
residual,
|
|
hidden_states_norm_weight,
|
|
q_a_proj_weight,
|
|
q_a_proj_weight_scale_inv,
|
|
q_a_layernorm_weight,
|
|
w_q,
|
|
w_q_scale,
|
|
w_uk,
|
|
w_uk_scale,
|
|
w_qr,
|
|
w_qr_scale,
|
|
kv_a_layernorm_weight,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
block_tables,
|
|
block_group_size,
|
|
w_uv,
|
|
w_uv_scale,
|
|
o_proj_weight,
|
|
o_proj_weight_scale_inv,
|
|
seq_lens,
|
|
sm_scale,
|
|
head_num,
|
|
flash_attention,
|
|
rms_weight,
|
|
mlp_weight_13,
|
|
mlp_weight_2,
|
|
mlp_weight_scale_13,
|
|
mlp_weight_scale_2,
|
|
moe_weight_13,
|
|
moe_weight_2,
|
|
moe_weight_scale_13,
|
|
moe_weight_scale_2,
|
|
mm_weight,
|
|
moe_bias,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
moe_block_size_w13,
|
|
moe_block_size_w2,
|
|
world_size,
|
|
rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
# if out_single:
|
|
# return out[0]
|
|
return out
|
|
|
|
|
|
# ! register fake functions for the C++-implemented custom ops
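# These fake (meta) implementations only describe output shapes/dtypes; they let the custom
# vacc ops participate in FakeTensor tracing (e.g. under torch.compile) without touching the
# device. The bodies below intentionally allocate empty tensors of the expected shape.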
|
|
@torch.library.register_fake("vacc::RotaryPosEmbedding_func")
|
|
def _(
|
|
q: torch.Tensor,
|
|
k: torch.Tensor,
|
|
cos: Optional[torch.Tensor] = None,
|
|
sin: Optional[torch.Tensor] = None,
|
|
offset: int = 0,
|
|
mode: str = "neox",
|
|
):
|
|
return [torch.empty_like(q), torch.empty_like(k)]
|
|
|
|
|
|
@torch.library.register_fake("vacc::reshape_and_cache_attention")
|
|
def _(
|
|
src: torch.Tensor,
|
|
cached: torch.Tensor,
|
|
block_mapping: torch.Tensor,
|
|
):
|
|
pass
|
|
|
|
|
|
@torch.library.register_fake("vacc::scaled_dot_product_attention")
|
|
def _(
|
|
query,
|
|
key,
|
|
value,
|
|
attn_mask,
|
|
dropout_p,
|
|
is_train,
|
|
recompute,
|
|
is_causal,
|
|
flash_attention,
|
|
sm_scale
|
|
):
|
|
return [torch.empty(size=(query.size()[0], query.size()[1], key.size()[2]), device=query.device, dtype=query.dtype)]
|
|
|
|
|
|
@torch.library.register_fake("vacc::rms_norm_func")
|
|
def _(
|
|
input: torch.Tensor, weight: torch.Tensor, eps: float, output: Optional[torch.Tensor] = None
|
|
) -> torch.Tensor:
|
|
return torch.empty_like(input)
|
|
|
|
|
|
@torch.library.register_fake("vacc::fused_residual_rmsnorm")
|
|
def _(
|
|
input: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
residual: Optional[torch.Tensor] = None,
|
|
epsilon: float = 1e-6,
|
|
output: Optional[torch.Tensor] = None,
|
|
residual_out: Optional[torch.Tensor] = None,
|
|
):
|
|
return [torch.empty_like(input), torch.empty_like(input)]
|
|
|
|
|
|
@torch.library.register_fake("vacc::swiglu")
|
|
def _(
|
|
self: torch.Tensor
|
|
):
|
|
shape = list(self.shape)
|
|
shape[-1] = shape[-1] // 2
|
|
return torch.empty(size=shape, dtype=self.dtype, device=self.device)
|
|
|
|
|
|
def fuse_mla_mlp_v2_allreduce_decode_layers(
|
|
# input
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
# mla weight
|
|
hidden_states_norm_weight: List[torch.Tensor],
|
|
q_a_proj_weight: List[torch.Tensor],
|
|
q_a_proj_weight_scale_inv: List[torch.Tensor],
|
|
q_a_layernorm_weight: List[torch.Tensor],
|
|
w_q: List[torch.Tensor],
|
|
w_q_scale: List[torch.Tensor],
|
|
w_uk: List[torch.Tensor],
|
|
w_uk_scale: List[torch.Tensor],
|
|
w_qr: List[torch.Tensor],
|
|
w_qr_scale: List[torch.Tensor],
|
|
kv_a_layernorm_weight: List[torch.Tensor],
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: List[torch.Tensor],
|
|
block_tables: torch.Tensor,
|
|
block_group_size: int,
|
|
w_uv: List[torch.Tensor],
|
|
w_uv_scale: List[torch.Tensor],
|
|
o_proj_weight: List[torch.Tensor],
|
|
o_proj_weight_scale_inv: List[torch.Tensor],
|
|
# mla params
|
|
seq_lens: Tuple[int] | List[int],
|
|
sm_scale: float,
|
|
head_num: int,
|
|
flash_attention: bool,
|
|
# mlp weight
|
|
rms_weight: List[torch.Tensor],
|
|
mlp_weight_13: List[torch.Tensor],
|
|
mlp_weight_2: List[torch.Tensor],
|
|
mlp_weight_scale_13: List[torch.Tensor],
|
|
mlp_weight_scale_2: List[torch.Tensor],
|
|
# mlp params
|
|
mlp_block_size_w13: List[int] | Tuple[int],
|
|
mlp_block_size_w2: List[int] | Tuple[int],
|
|
# vccl info
|
|
world_size: int,
|
|
rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
# TODO: CHECK
|
|
assert (
|
|
hidden_states.device.type == "vacc"
|
|
), "fuse_mla_mlp_v2_allreduce_decode_layers only support VACC"
|
|
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
if residual is None:
|
|
residual = torch.Tensor()
|
|
|
|
out = _torch_vacc.fuse_mla_mlp_v2_allreduce_decode_layers(
|
|
hidden_states,
|
|
residual,
|
|
hidden_states_norm_weight,
|
|
q_a_proj_weight,
|
|
q_a_proj_weight_scale_inv,
|
|
q_a_layernorm_weight,
|
|
w_q,
|
|
w_q_scale,
|
|
w_uk,
|
|
w_uk_scale,
|
|
w_qr,
|
|
w_qr_scale,
|
|
kv_a_layernorm_weight,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
block_tables,
|
|
block_group_size,
|
|
w_uv,
|
|
w_uv_scale,
|
|
o_proj_weight,
|
|
o_proj_weight_scale_inv,
|
|
seq_lens,
|
|
sm_scale,
|
|
head_num,
|
|
flash_attention,
|
|
rms_weight,
|
|
mlp_weight_13,
|
|
mlp_weight_2,
|
|
mlp_weight_scale_13,
|
|
mlp_weight_scale_2,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
world_size,
|
|
rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
return out
|
|
|
|
|
|
def fuse_mla_mlp_v2_allreduce_decode_layers_v2(
|
|
# input
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
# mla weight
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: List[torch.Tensor],
|
|
block_tables: torch.Tensor,
|
|
block_group_size: int,
|
|
# mla params
|
|
seq_lens: Tuple[int] | List[int],
|
|
sm_scale: float,
|
|
head_num: int,
|
|
flash_attention: bool,
|
|
# mlp weight
|
|
# mlp params
|
|
mlp_block_size_w13: List[int] | Tuple[int],
|
|
mlp_block_size_w2: List[int] | Tuple[int],
|
|
# vccl info
|
|
world_size: int,
|
|
rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
# TODO: CHECK
|
|
assert (
|
|
hidden_states.device.type == "vacc"
|
|
), "fuse_mla_mlp_v2_allreduce_decode_layers_v2 only support VACC"
|
|
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
if residual is None:
|
|
residual = torch.Tensor()
|
|
|
|
out = _torch_vacc.fuse_mla_mlp_v2_allreduce_decode_layers(
|
|
hidden_states,
|
|
residual,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
block_tables,
|
|
block_group_size,
|
|
seq_lens,
|
|
sm_scale,
|
|
head_num,
|
|
flash_attention,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
world_size,
|
|
rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
return out
|
|
|
|
def fuse_mla_moe_v2_allreduce_decode_layers(
|
|
# input
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
# mla weight
|
|
hidden_states_norm_weight: List[torch.Tensor],
|
|
q_a_proj_weight: List[torch.Tensor],
|
|
q_a_proj_weight_scale_inv: List[torch.Tensor],
|
|
q_a_layernorm_weight: List[torch.Tensor],
|
|
w_q: List[torch.Tensor],
|
|
w_q_scale: List[torch.Tensor],
|
|
w_uk: List[torch.Tensor],
|
|
w_uk_scale: List[torch.Tensor],
|
|
w_qr: List[torch.Tensor],
|
|
w_qr_scale: List[torch.Tensor],
|
|
kv_a_layernorm_weight: List[torch.Tensor],
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: List[torch.Tensor],
|
|
block_tables: torch.Tensor,
|
|
block_group_size: int,
|
|
w_uv: List[torch.Tensor],
|
|
w_uv_scale: List[torch.Tensor],
|
|
o_proj_weight: List[torch.Tensor],
|
|
o_proj_weight_scale_inv: List[torch.Tensor],
|
|
# mla params
|
|
seq_lens: Tuple[int] | List[int],
|
|
sm_scale: float,
|
|
head_num: int,
|
|
flash_attention: bool,
|
|
# moe weight
|
|
rms_weight: List[torch.Tensor],
|
|
mlp_weight_13: List[torch.Tensor],
|
|
mlp_weight_2: List[torch.Tensor],
|
|
mlp_weight_scale_13: List[torch.Tensor],
|
|
mlp_weight_scale_2: List[torch.Tensor],
|
|
moe_weight_13: List[torch.Tensor],
|
|
moe_weight_2: List[torch.Tensor],
|
|
moe_weight_scale_13: List[torch.Tensor],
|
|
moe_weight_scale_2: List[torch.Tensor],
|
|
mm_weight: List[torch.Tensor],
|
|
moe_bias: List[torch.Tensor],
|
|
# moe params
|
|
mlp_block_size_w13: Tuple[int] | List[int],
|
|
mlp_block_size_w2: Tuple[int] | List[int],
|
|
moe_block_size_w13: Tuple[int] | List[int],
|
|
moe_block_size_w2: Tuple[int] | List[int],
|
|
# vccl info
|
|
world_size: int,
|
|
rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
# TODO: CHECK
|
|
assert (
|
|
hidden_states.device.type == "vacc"
|
|
), "fuse_mla_moe_v2_allreduce_decode_layers only support VACC"
|
|
# assert red_op_type == 0, "all_reduce only support red_op_type=0"
|
|
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
out_single = False
|
|
if residual is None:
|
|
out_single = True
|
|
residual = torch.Tensor()
|
|
|
|
out = _torch_vacc.fuse_mla_moe_v2_allreduce_decode_layers(
|
|
hidden_states,
|
|
residual,
|
|
hidden_states_norm_weight,
|
|
q_a_proj_weight,
|
|
q_a_proj_weight_scale_inv,
|
|
q_a_layernorm_weight,
|
|
w_q,
|
|
w_q_scale,
|
|
w_uk,
|
|
w_uk_scale,
|
|
w_qr,
|
|
w_qr_scale,
|
|
kv_a_layernorm_weight,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
block_tables,
|
|
block_group_size,
|
|
w_uv,
|
|
w_uv_scale,
|
|
o_proj_weight,
|
|
o_proj_weight_scale_inv,
|
|
seq_lens,
|
|
sm_scale,
|
|
head_num,
|
|
flash_attention,
|
|
rms_weight,
|
|
mlp_weight_13,
|
|
mlp_weight_2,
|
|
mlp_weight_scale_13,
|
|
mlp_weight_scale_2,
|
|
moe_weight_13,
|
|
moe_weight_2,
|
|
moe_weight_scale_13,
|
|
moe_weight_scale_2,
|
|
mm_weight,
|
|
moe_bias,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
moe_block_size_w13,
|
|
moe_block_size_w2,
|
|
world_size,
|
|
rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
if out_single:
|
|
return out[0]
|
|
return out
|
|
|
|
def fuse_mla_moe_v2_allreduce_decode_layers_v2(
|
|
# input
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
# mla weight
|
|
sin_cache: List[torch.Tensor],
|
|
cos_cache: List[torch.Tensor],
|
|
slot_mapping: torch.Tensor,
|
|
kv_cache: List[torch.Tensor],
|
|
block_tables: torch.Tensor,
|
|
block_group_size: int,
|
|
# mla params
|
|
seq_lens: Tuple[int] | List[int],
|
|
sm_scale: float,
|
|
head_num: int,
|
|
flash_attention: bool,
|
|
# moe params
|
|
mlp_block_size_w13: Tuple[int] | List[int],
|
|
mlp_block_size_w2: Tuple[int] | List[int],
|
|
moe_block_size_w13: Tuple[int] | List[int],
|
|
moe_block_size_w2: Tuple[int] | List[int],
|
|
# vccl info
|
|
world_size: int,
|
|
rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
):
|
|
# TODO: CHECK
|
|
assert (
|
|
hidden_states.device.type == "vacc"
|
|
), "fuse_mla_moe_v2_allreduce_decode_layers_v2 only support VACC"
|
|
# assert red_op_type == 0, "all_reduce only support red_op_type=0"
|
|
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
out_single = False
|
|
if residual is None:
|
|
out_single = True
|
|
residual = torch.Tensor()
|
|
# print("come to v2")
|
|
out = _torch_vacc.fuse_mla_moe_v2_allreduce_decode_layers_v2(
|
|
hidden_states,
|
|
residual,
|
|
sin_cache,
|
|
cos_cache,
|
|
slot_mapping,
|
|
kv_cache,
|
|
block_tables,
|
|
block_group_size,
|
|
seq_lens,
|
|
sm_scale,
|
|
head_num,
|
|
flash_attention,
|
|
mlp_block_size_w13,
|
|
mlp_block_size_w2,
|
|
moe_block_size_w13,
|
|
moe_block_size_w2,
|
|
world_size,
|
|
rank,
|
|
group_id,
|
|
dev_info,
|
|
)
|
|
if out_single:
|
|
return out[0]
|
|
return out
|
|
|
|
def fuse_mlp_qwen_int4(
|
|
hidden_states: torch.Tensor,
|
|
weight13: torch.Tensor,
|
|
weight2: torch.Tensor,
|
|
scale13: torch.Tensor,
|
|
scale2: torch.Tensor,
|
|
zero13: torch.Tensor,
|
|
zero2: torch.Tensor,
|
|
block13: List[int] | Tuple[int],
|
|
block2: List[int] | Tuple[int],
|
|
    engine_mode: int = 0,  # 0: auto, 1: dlc, 2: dsp
|
|
output_opt: Optional[torch.Tensor] = None
|
|
):
|
|
assert engine_mode in [0, 1, 2]
|
|
return _torch_vacc.fuse_mlp_qwen_int4(
|
|
hidden_states,
|
|
weight13,
|
|
weight2,
|
|
scale13,
|
|
scale2,
|
|
zero13,
|
|
zero2,
|
|
block13,
|
|
block2,
|
|
engine_mode,
|
|
output_opt
|
|
)
|
|
|
|
def fuse_mlp_qwen_int4_reduce(
|
|
hidden_states: torch.Tensor,
|
|
weight13: torch.Tensor,
|
|
weight2: torch.Tensor,
|
|
scale13: torch.Tensor,
|
|
scale2: torch.Tensor,
|
|
zero13: torch.Tensor,
|
|
zero2: torch.Tensor,
|
|
block13: List[int] | Tuple[int],
|
|
block2: List[int] | Tuple[int],
|
|
world_size: int,
|
|
rank: int,
|
|
group_id: int,
|
|
dev_info: List[int] | Tuple[int],
|
|
output_opt: Optional[torch.Tensor] = None
|
|
):
|
|
if 0 == len(dev_info):
|
|
dev_info = [i | (i << 16) for i in range(world_size)]
|
|
|
|
return _torch_vacc.fuse_mlp_qwen_int4_reduce(
|
|
hidden_states,
|
|
weight13,
|
|
weight2,
|
|
scale13,
|
|
scale2,
|
|
zero13,
|
|
zero2,
|
|
block13,
|
|
block2,
|
|
world_size,
|
|
rank,
|
|
group_id,
|
|
dev_info,
|
|
output_opt
|
|
)
|
|
|
|
def fuse_mlp_qwen_fp8(
|
|
hidden_states: torch.Tensor,
|
|
weight13: torch.Tensor,
|
|
weight2: torch.Tensor,
|
|
scale13: torch.Tensor,
|
|
scale2: torch.Tensor,
|
|
zero13: torch.Tensor,
|
|
zero2: torch.Tensor,
|
|
block13: List[int] | Tuple[int],
|
|
block2: List[int] | Tuple[int],
|
|
    engine_mode: int = 0,  # 0: auto, 1: dlc, 2: dsp
|
|
output_opt: Optional[torch.Tensor] = None
|
|
):
|
|
assert engine_mode in [0, 1, 2]
|
|
return _torch_vacc.fuse_mlp_qwen_int4(
|
|
hidden_states,
|
|
weight13,
|
|
weight2,
|
|
scale13,
|
|
scale2,
|
|
zero13,
|
|
zero2,
|
|
block13,
|
|
block2,
|
|
engine_mode,
|
|
output_opt
|
|
)
|
|
|
|
def fuse_mlp_qwen_fp8_reduce(
    hidden_states: torch.Tensor,
    weight13: torch.Tensor,
    weight2: torch.Tensor,
    scale13: torch.Tensor,
    scale2: torch.Tensor,
    zero13: torch.Tensor,
    zero2: torch.Tensor,
    block13: List[int] | Tuple[int],
    block2: List[int] | Tuple[int],
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    output_opt: Optional[torch.Tensor] = None,
):
    if 0 == len(dev_info):
        dev_info = [i | (i << 16) for i in range(world_size)]

    # Dispatches to the shared fuse_mlp_qwen_int4_reduce binding.
    return _torch_vacc.fuse_mlp_qwen_int4_reduce(
        hidden_states,
        weight13,
        weight2,
        scale13,
        scale2,
        zero13,
        zero2,
        block13,
        block2,
        world_size,
        rank,
        group_id,
        dev_info,
        output_opt,
    )


def fuse_mlp_qwen_fp16_bf16(
    hidden_states: torch.Tensor,
    weight13: torch.Tensor,
    weight2: torch.Tensor,
    output_opt: Optional[torch.Tensor] = None,
):
    # Reuses the int4 binding; empty tensors, (0, 0) block shapes, and
    # engine_mode=0 stand in for the quantization parameters that
    # fp16/bf16 weights do not need.
    return _torch_vacc.fuse_mlp_qwen_int4(
        hidden_states,
        weight13,
        weight2,
        torch.Tensor(),
        torch.Tensor(),
        torch.Tensor(),
        torch.Tensor(),
        (0, 0),
        (0, 0),
        0,
        output_opt,
    )


def fuse_mlp_qwen_fp16_bf16_reduce(
    hidden_states: torch.Tensor,
    weight13: torch.Tensor,
    weight2: torch.Tensor,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    output_opt: Optional[torch.Tensor] = None,
):
    if 0 == len(dev_info):
        dev_info = [i | (i << 16) for i in range(world_size)]

    # Reuses the int4 reduce binding; empty tensors and (0, 0) block shapes
    # stand in for the quantization parameters that fp16/bf16 weights do not need.
    return _torch_vacc.fuse_mlp_qwen_int4_reduce(
        hidden_states,
        weight13,
        weight2,
        torch.Tensor(),
        torch.Tensor(),
        torch.Tensor(),
        torch.Tensor(),
        (0, 0),
        (0, 0),
        world_size,
        rank,
        group_id,
        dev_info,
        output_opt,
    )


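# Illustrative call sketch for the fp16/bf16 MLP wrapper above. The tensor
# shapes, the gate+up packing of `weight13`, and the dtype are assumptions
# made for this example only; they are not asserted anywhere in this file.
def _example_fuse_mlp_qwen_fp16_bf16() -> torch.Tensor:
    hidden, inter = 4096, 11008  # placeholder sizes, not values from this file
    x = torch.randn(8, hidden, dtype=torch.bfloat16, device="vacc")
    w13 = torch.randn(2 * inter, hidden, dtype=torch.bfloat16, device="vacc")
    w2 = torch.randn(hidden, inter, dtype=torch.bfloat16, device="vacc")
    return fuse_mlp_qwen_fp16_bf16(x, w13, w2)

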
def w4a8_block_int4_matmul(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    dequant_block: List[int] | Tuple[int],
    output_opt: Optional[torch.Tensor] = None,
):
    # TODO: CHECK
    out = _torch_vacc.w4a8_block_int4_matmul(
        input,
        weight,
        weight_scale,
        dequant_block,
        output_opt,
    )

    return out


def fuse_atten_qwen3(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    hidden_states_norm_weight: torch.Tensor,
    qkv_proj_weight: torch.Tensor,
    qkv_proj_weight_scale: torch.Tensor,
    qkv_proj_bias: torch.Tensor,
    qkv_proj_qzeros: torch.Tensor,
    q_layernorm_weight: torch.Tensor,
    k_layernorm_weight: torch.Tensor,
    sin_cache: List[torch.Tensor],
    cos_cache: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    block_tables: torch.Tensor,
    block_group_size: int,
    o_proj_weight: torch.Tensor,
    o_proj_weight_scale: torch.Tensor,
    o_proj_bias: torch.Tensor,
    o_proj_qzeros: torch.Tensor,
    seq_lens: Tuple[int] | List[int],
    sm_scale: float,
    num_attention_heads: int,
    num_key_value_heads: int,
    flash_attention: bool,
    is_decode: bool,
    reduce_result: bool,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    output_opt: Optional[torch.Tensor] = None,
    res_opt: Optional[torch.Tensor] = None,
):
    if 0 == len(dev_info):
        dev_info = [i | (i << 16) for i in range(world_size)]

    none2empty = lambda tensor: torch.Tensor() if tensor is None else tensor
    out_single = False
    if residual is None:
        # with no residual supplied, only the first (hidden-states) output is returned
        out_single = True
        residual = torch.Tensor()

    out = _torch_vacc.fuse_atten_qwen3(
        hidden_states,
        residual,
        hidden_states_norm_weight,
        qkv_proj_weight,
        qkv_proj_weight_scale,
        none2empty(qkv_proj_bias),
        none2empty(qkv_proj_qzeros),
        q_layernorm_weight,
        k_layernorm_weight,
        sin_cache,
        cos_cache,
        slot_mapping,
        kv_cache,
        block_tables,
        block_group_size,
        o_proj_weight,
        o_proj_weight_scale,
        none2empty(o_proj_bias),
        none2empty(o_proj_qzeros),
        seq_lens,
        sm_scale,
        num_attention_heads,
        num_key_value_heads,
        flash_attention,
        is_decode,
        reduce_result,
        world_size,
        rank,
        group_id,
        dev_info,
        output_opt,
        res_opt,
    )
    if out_single:
        return out[0]
    return out


def fuse_atten_vit(
    hidden_states: torch.Tensor,
    hidden_states_norm_weight: torch.Tensor,
    hidden_states_norm_bias: torch.Tensor,
    qkv_proj_weight: torch.Tensor,
    qkv_proj_bias: torch.Tensor,
    sin_cache: torch.Tensor,
    cos_cache: torch.Tensor,
    o_proj_weight: torch.Tensor,
    o_proj_bias: torch.Tensor,
    seq_lens: Tuple[int] | List[int],
    sm_scale: float,
    num_attention_heads: int,
    flash_attention: bool,
    reduce_result: bool,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    output_opt: Optional[torch.Tensor] = None,
):
    if 0 == len(dev_info):
        dev_info = [i | (i << 16) for i in range(world_size)]

    none2empty = lambda tensor: torch.Tensor() if tensor is None else tensor

    out = _torch_vacc.fuse_atten_vit(
        hidden_states,
        hidden_states_norm_weight,
        hidden_states_norm_bias,
        qkv_proj_weight,
        none2empty(qkv_proj_bias),
        sin_cache,
        cos_cache,
        o_proj_weight,
        none2empty(o_proj_bias),
        seq_lens,
        sm_scale,
        num_attention_heads,
        flash_attention,
        reduce_result,
        world_size,
        rank,
        group_id,
        dev_info,
        output_opt,
    )

    return out


def mrope_get_sin_cos(
    cos_cache: torch.Tensor,
    sin_cache: torch.Tensor,
    positions: torch.Tensor,
    mrope_section: List[int] | Tuple[int],
    mrope_interleaved: bool,
):
    return _torch_vacc.mrope_get_sin_cos(
        cos_cache,
        sin_cache,
        positions,
        mrope_section,
        mrope_interleaved,
    )


# NOTE: DSP version for m <= 8;
# if m > 8, call w4a8_block_int4_matmul instead.
def w4a8_block_int4_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    weight_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    pack_factor: int = 8,
    output_opt: Optional[torch.Tensor] = None,
):
    # TODO: CHECK
    out = _torch_vacc.w4a8_block_int4_linear(
        input,
        weight,
        input_scale,
        weight_scale,
        bias,
        pack_factor,
        output_opt,
    )

    return out


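# Illustrative dispatch sketch based on the NOTE above: route small-m inputs
# (m = input.shape[0] <= 8) to the DSP linear kernel and larger batches to the
# block matmul. The helper name, reading `m` from dim 0, and the argument
# mapping are assumptions for the sketch only; they are not part of the
# original API.
def _w4a8_block_int4_dispatch(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    dequant_block: List[int] | Tuple[int],
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    pack_factor: int = 8,
    output_opt: Optional[torch.Tensor] = None,
):
    if input.shape[0] <= 8:
        return w4a8_block_int4_linear(
            input, weight, input_scale, weight_scale, bias, pack_factor, output_opt
        )
    return w4a8_block_int4_matmul(input, weight, weight_scale, dequant_block, output_opt)

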
def fuse_atten_qwen2(
    history_states: torch.Tensor,
    residual: torch.Tensor,
    hidden_states_norm_weight: torch.Tensor,
    qkv_proj_weight: torch.Tensor,
    qkv_proj_weight_scale_inv: torch.Tensor,
    qkv_proj_bias: torch.Tensor,
    qkv_proj_qzeros: torch.Tensor,
    sin_cache: List[torch.Tensor],
    cos_cache: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    block_tables: torch.Tensor,
    block_group_size: int,
    o_proj_weight: torch.Tensor,
    o_proj_weight_scale_inv: torch.Tensor,
    o_proj_bias: torch.Tensor,
    o_proj_qzeros: torch.Tensor,
    seq_lens_num: Tuple[int] | List[int],
    sm_scale: float,
    num_attention_heads: int,
    num_key_value_heads: int,
    flash_attention: bool,
    is_decode: bool,
    reduce_result: bool,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
):
    if 0 == len(dev_info):
        dev_info = [i | (i << 16) for i in range(world_size)]

    out_single = False
    if residual is None:
        out_single = True
        residual = torch.Tensor()

    out = _torch_vacc.fuse_atten_qwen2(
        history_states,
        residual,
        hidden_states_norm_weight,
        qkv_proj_weight,
        qkv_proj_weight_scale_inv,
        qkv_proj_bias,
        qkv_proj_qzeros,
        sin_cache,
        cos_cache,
        slot_mapping,
        kv_cache,
        block_tables,
        block_group_size,
        o_proj_weight,
        o_proj_weight_scale_inv,
        o_proj_bias,
        o_proj_qzeros,
        seq_lens_num,
        sm_scale,
        num_attention_heads,
        num_key_value_heads,
        flash_attention,
        is_decode,
        reduce_result,
        world_size,
        rank,
        group_id,
        dev_info,
    )
    if out_single:
        return out[0]
    return out


def qwen3_fuse_attention_moe_decode(
    # attention
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    hidden_states_norm_weight: torch.Tensor,
    qkv_proj_weight: torch.Tensor,
    qkv_proj_weight_scale_inv: torch.Tensor,
    qkv_proj_bias: torch.Tensor,
    qkv_proj_qzeros: torch.Tensor,
    q_layernorm_weight: torch.Tensor,
    k_layernorm_weight: torch.Tensor,
    sin_cache: List[torch.Tensor],
    cos_cache: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    block_tables: torch.Tensor,
    block_group_size: int,
    o_proj_weight: torch.Tensor,
    o_proj_weight_scale_inv: torch.Tensor,
    o_proj_bias: torch.Tensor,
    o_proj_qzeros: torch.Tensor,
    seq_lens_num: Tuple[int] | List[int],
    sm_scale: float,
    num_attention_heads: int,
    num_key_value_heads: int,
    flash_attention: bool,
    is_decode: bool,
    reduce_result: bool,
    # moe
    rms_weight: torch.Tensor,
    moe_weight_13: torch.Tensor,
    moe_weight_2: torch.Tensor,
    moe_weight_13_dequant: torch.Tensor,
    moe_weight_2_dequant: torch.Tensor,
    gate_weight: torch.Tensor,
    block_size_13: Tuple[int] | List[int],
    block_size_2: Tuple[int] | List[int],
    # dist
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
):
    if 0 == len(dev_info):
        dev_info = [i | (i << 16) for i in range(world_size)]

    none2empty = lambda tensor: torch.Tensor() if tensor is None else tensor

    return _torch_vacc.qwen3_fuse_attention_moe_decode(
        hidden_states,
        none2empty(residual),
        hidden_states_norm_weight,
        qkv_proj_weight,
        qkv_proj_weight_scale_inv,
        none2empty(qkv_proj_bias),
        none2empty(qkv_proj_qzeros),
        q_layernorm_weight,
        k_layernorm_weight,
        sin_cache,
        cos_cache,
        slot_mapping,
        kv_cache,
        block_tables,
        block_group_size,
        o_proj_weight,
        o_proj_weight_scale_inv,
        none2empty(o_proj_bias),
        none2empty(o_proj_qzeros),
        seq_lens_num,
        sm_scale,
        num_attention_heads,
        num_key_value_heads,
        flash_attention,
        is_decode,
        reduce_result,
        rms_weight,
        moe_weight_13,
        moe_weight_2,
        moe_weight_13_dequant,
        moe_weight_2_dequant,
        gate_weight,
        block_size_13,
        block_size_2,
        world_size,
        rank,
        group_id,
        dev_info,
    )


def fuse_mtp_stage0(
    inputs_embeds: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    positions: torch.Tensor,
    enorm_weight: torch.Tensor,
    hnorm_weight: torch.Tensor,
    epsilon: float,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    output: Optional[torch.Tensor] = None,
):
    if 0 == len(dev_info):
        dev_info = [i | (i << 16) for i in range(world_size)]

    return _torch_vacc.fuse_mtp_stage0(
        inputs_embeds,
        previous_hidden_states,
        positions,
        enorm_weight,
        hnorm_weight,
        epsilon,
        world_size,
        rank,
        group_id,
        dev_info,
        output,
    )


def fuse_mtp_allreduce(
    inputs_embeds: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    positions: torch.Tensor,
    enorm_weight: torch.Tensor,
    hnorm_weight: torch.Tensor,
    linear_weight: torch.Tensor,
    epsilon: float,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    output: Optional[torch.Tensor] = None,
):
    if 0 == len(dev_info):
        dev_info = [i | (i << 16) for i in range(world_size)]

    return _torch_vacc.fuse_mtp_allreduce(
        inputs_embeds,
        previous_hidden_states,
        positions,
        enorm_weight,
        hnorm_weight,
        linear_weight,
        epsilon,
        world_size,
        rank,
        group_id,
        dev_info,
        output,
    )


def roll_out(
    self: torch.Tensor,
    shifts: List[int] | Tuple[int] | int,
    dims: List[int] | Tuple[int] | int = [],
    output: Optional[torch.Tensor] = None,
):
    if isinstance(dims, int):
        dims = [dims]
    if isinstance(shifts, int):
        shifts = [shifts]
    return _torch_vacc.roll_out(
        self,
        shifts,
        dims,
        output,
    )


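# Usage note for roll_out above: scalar `shifts`/`dims` are wrapped into
# single-element lists before dispatch, so `roll_out(x, 1, 0)` and
# `roll_out(x, [1], [0])` are equivalent. The comparison against torch.roll
# below assumes the vacc kernel follows torch.roll semantics, which this file
# does not state explicitly.
def _check_roll_out_against_torch_roll(x: torch.Tensor) -> bool:
    expected = torch.roll(x, shifts=[1], dims=[0])
    actual = roll_out(x, 1, 0)
    return torch.equal(expected.cpu(), actual.cpu())

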
def fused_experts_int4_prefill(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    w13_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a13_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    w13_block_shape: Optional[List[int]] = None,
    w2_block_shape: Optional[List[int]] = None,
    output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    warning_message = (
        "[fused_experts_int4_prefill]: vacc only supports block-quantized "
        "weights (weight scales and a block shape; no activation scales)"
    )
    assert a13_scale is None and a2_scale is None, warning_message
    assert (
        w13_scale is not None and w2_scale is not None
    ), f"{warning_message}, but w13_weight_scale is {w13_scale}, w2_weight_scale is {w2_scale}"
    assert (
        w13_block_shape is not None
    ), f"{warning_message}, but block_shape is {w13_block_shape}"
    # assert (
    #     not decode_with_batch or hidden_states.size(0) <= 4
    # ), "[fused_experts]:vacc only support batch <= 4 when decode"

    # topk_weights dtype must match hidden_states
    topk_weights = topk_weights.to(hidden_states.dtype)
    # vacc devices use int32 expert ids
    topk_ids = topk_ids.to(torch.int32)

    # assert (
    #     block_size0 == block_size1
    # ), "quant block shape now support size0 == size1"
    return _torch_vacc.fused_experts_int4_prefill(
        hidden_states,
        w13_weight,
        w2_weight,
        topk_weights,
        topk_ids,
        w13_scale,
        w2_scale,
        a13_scale,
        a2_scale,
        w13_block_shape,
        w2_block_shape,
        output_opt,
    )


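# Illustrative sketch of how the `topk_weights` / `topk_ids` inputs of
# fused_experts_int4_prefill are typically produced from router logits
# (softmax-then-top-k routing). This is a generic MoE routing recipe, not a
# statement about how callers of this file compute them; `router_logits`
# with shape [num_tokens, num_experts] and `top_k` are placeholders.
def _example_topk_routing(router_logits: torch.Tensor, top_k: int):
    routing_probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(routing_probs, top_k, dim=-1)
    # the wrapper above casts these to hidden_states.dtype / torch.int32
    return topk_weights, topk_ids

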
def fuse_bge_embedding_stage1(
    input_embeds: torch.Tensor,
    positions_ids: torch.Tensor,
    positions_embeddings_weight: torch.Tensor,
    token_type_ids: torch.Tensor,
    token_type_embeddings_weight: torch.Tensor,
    layernorm_weight: torch.Tensor,
    layernorm_bias: torch.Tensor,
    epsilon: float,
    output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return _torch_vacc.fuse_bge_embedding_stage1(
        input_embeds,
        positions_ids,
        positions_embeddings_weight,
        token_type_ids,
        token_type_embeddings_weight,
        layernorm_weight,
        layernorm_bias,
        epsilon,
        output_opt,
    )


def l2_norm(
    input: torch.Tensor,
    epsilon: float,
    output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return _torch_vacc.l2_norm(
        input,
        epsilon,
        output_opt,
    )


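# Reference sketch for l2_norm above, for documentation purposes only. It
# assumes the fused kernel normalizes over the last dimension and uses
# `epsilon` for numerical stability; the exact placement of epsilon inside the
# vacc kernel is not specified in this file.
def _l2_norm_reference(input: torch.Tensor, epsilon: float) -> torch.Tensor:
    norm = input.norm(p=2, dim=-1, keepdim=True)
    return input / (norm + epsilon)

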
class BERT_ATTN_STAGE(Enum):
    # fused all-reduce included; intended for seq_len <= 2k
    FullStage = 0
    # no fused all-reduce; the caller performs the reduce afterwards
    AttnOutStage = 1
    InterOutStage = 2


def fused_attn_bert_allreduce(
    hidden_states: torch.Tensor,
    qkv_weight: Optional[torch.Tensor] = None,
    qkv_bias: Optional[torch.Tensor] = None,
    self_weight: Optional[torch.Tensor] = None,
    self_bias: Optional[torch.Tensor] = None,
    self_norm_weight: Optional[torch.Tensor] = None,
    self_norm_bias: Optional[torch.Tensor] = None,
    intermediate_weight: Optional[torch.Tensor] = None,
    intermediate_bias: Optional[torch.Tensor] = None,
    output_weight: Optional[torch.Tensor] = None,
    output_bias: Optional[torch.Tensor] = None,
    output_norm_weight: Optional[torch.Tensor] = None,
    output_norm_bias: Optional[torch.Tensor] = None,
    dense_out: Optional[torch.Tensor] = None,
    seqs: List[int] | Tuple[int] = [],
    vnnlBertKind: BERT_ATTN_STAGE = BERT_ATTN_STAGE.FullStage,
    sm_scale: float = 1.0,
    num_q_heads: int = 1,
    num_kv_heads: int = 1,
    flash_attention: bool = False,
    reduce_result: bool = False,
    world_size: int = 1,
    rank: int = 0,
    group_id: int = 0,
    dev_info: List[int] | Tuple[int] = [],
):
    if 0 == len(dev_info):
        dev_info = [i | (i << 16) for i in range(world_size)]

    return _torch_vacc.fused_attn_bert_allreduce(
        hidden_states,
        qkv_weight,
        qkv_bias,
        self_weight,
        self_bias,
        self_norm_weight,
        self_norm_bias,
        intermediate_weight,
        intermediate_bias,
        output_weight,
        output_bias,
        output_norm_weight,
        output_norm_bias,
        dense_out,
        seqs,
        vnnlBertKind.value,
        sm_scale,
        num_q_heads,
        num_kv_heads,
        flash_attention,
        reduce_result,
        world_size,
        rank,
        group_id,
        dev_info,
    )


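# Usage note for BERT_ATTN_STAGE above: the wrapper passes
# `vnnlBertKind.value` to the C++ binding, so selecting a stage is just a
# matter of choosing an enum member. The selection helper below is an
# illustrative sketch; the 2k threshold comes from the comment on FullStage,
# while the choice of AttnOutStage for longer sequences is a placeholder, not
# a rule stated in this file.
def _select_bert_attn_stage(seq_len: int) -> BERT_ATTN_STAGE:
    if seq_len <= 2048:
        return BERT_ATTN_STAGE.FullStage  # fused all-reduce inside the kernel
    return BERT_ATTN_STAGE.AttnOutStage  # caller reduces afterwards

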
def fuse_mlp_vision(
    src: torch.Tensor,
    weights_13: torch.Tensor,
    weights_2: torch.Tensor,
    weights_13_bias: Optional[torch.Tensor] = None,
    weights_2_bias: Optional[torch.Tensor] = None,
    act_type: int = 0,
    output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if weights_13_bias is None:
        weights_13_bias = torch.Tensor()
    if weights_2_bias is None:
        weights_2_bias = torch.Tensor()
    return _torch_vacc.fuse_mlp_vision(
        src,
        weights_13,
        weights_2,
        weights_13_bias,
        weights_2_bias,
        act_type,
        output_opt,
    )


def patch_merger_vision(
    src: torch.Tensor,
    weights_13: torch.Tensor,
    weights_2: torch.Tensor,
    weights_13_bias: Optional[torch.Tensor] = None,
    weights_2_bias: Optional[torch.Tensor] = None,
    act_type: int = 0,
    output_opt: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if weights_13_bias is None:
        weights_13_bias = torch.Tensor()
    if weights_2_bias is None:
        weights_2_bias = torch.Tensor()
    return _torch_vacc.patch_merger_vision(
        src,
        weights_13,
        weights_2,
        weights_13_bias,
        weights_2_bias,
        act_type,
        output_opt,
    )