[1/n]: add cutlass W4A8 moe kernel for hopper architecture (#7772)
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Co-authored-by: yicwang <yichen.wang@bytedance.com>
@@ -19,6 +19,7 @@ from sgl_kernel.attention import (
     merge_state,
     merge_state_v2,
 )
+from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
 from sgl_kernel.elementwise import (
     apply_rope_with_cos_sin_cache_inplace,
     fused_add_rmsnorm,
sgl-kernel/python/sgl_kernel/cutlass_moe.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import torch


def get_cutlass_w4a8_moe_mm_data(
    topk_ids: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    input_permutation: torch.Tensor,
    output_permutation: torch.Tensor,
    num_experts: int,
    n: int,
    k: int,
):
    """
    Prepare the data necessary to perform the CUTLASS grouped matrix multiplications
    used in CUTLASS-based fused MoE.

    The function takes in topk_ids (the token-to-expert mapping) and uses it to compute:
    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation after the input is sorted with
                      input_permutation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E].
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's multiplication
                      in the two grouped MMs used in the fused MoE operation.
    - input_permutation: Permutation that must be used to shuffle the input
                      before executing the MMs.
    - output_permutation: Permutation that must be used to shuffle the output
                      after executing the MMs.
    """
    torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
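For orientation, here is a minimal caller sketch for get_cutlass_w4a8_moe_mm_data; it is not part of this commit. The buffer dtypes (int32), the CUDA placement, and the example dimensions are assumptions inferred from the docstring, and the call only runs with sgl_kernel built with this kernel on a Hopper GPU.

import torch

from sgl_kernel.cutlass_moe import get_cutlass_w4a8_moe_mm_data

# Hypothetical routing problem: 4 tokens, top-8 routing over 256 experts (assumed sizes).
num_experts, topk, m, n, k = 256, 8, 4, 512, 7168
topk_ids = torch.randint(0, num_experts, (m, topk), dtype=torch.int32, device="cuda")

# Preallocated buffers that the kernel fills in; the int32 dtype is an assumption.
expert_offsets = torch.empty(num_experts + 1, dtype=torch.int32, device="cuda")
problem_sizes1 = torch.empty(num_experts, 3, dtype=torch.int32, device="cuda")
problem_sizes2 = torch.empty(num_experts, 3, dtype=torch.int32, device="cuda")
input_permutation = torch.empty(m * topk, dtype=torch.int32, device="cuda")
output_permutation = torch.empty(m * topk, dtype=torch.int32, device="cuda")

get_cutlass_w4a8_moe_mm_data(
    topk_ids,
    expert_offsets,
    problem_sizes1,
    problem_sizes2,
    input_permutation,
    output_permutation,
    num_experts,
    n,
    k,
)
# expert_offsets[E + 1] - expert_offsets[E] is the number of tokens routed to expert E.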
def cutlass_w4a8_moe_mm(
    d: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    experts_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    d_strides: torch.Tensor,
    s_strides: torch.Tensor,
    chunk_size: int = 128,
    topk: int = 8,
):
    """
    Perform grouped matrix multiplication between int4 weights and fp8 activations.

    This function executes multiple GEMM operations in parallel, which is useful for
    scenarios like Mixture of Experts (MoE), where different inputs go through
    different experts. The implementation leverages NVIDIA Hopper architecture
    features for optimal performance with quantized weights.

    Args:
        d: Output matrices of shape [total_m, total_n]
        a: Activation matrices in FP8 (float_e4m3_t) format.
           Each tensor should be of shape [total_m, K] in row-major layout.
        b: Weight matrices in packed int4 format.
           Each tensor should be of shape [E, N, K//2] in column-major layout,
           where each byte contains two 4-bit integers.
        a_scales: Scale factors for the inputs.
        b_scales: Scale factors for the quantized weights.
           Each tensor should be of shape [E, K//512, N*8].
        experts_offsets: Tensor containing expert offsets for determining group boundaries.
        problem_sizes: Problem sizes of shape [num_experts, 3] (M, N, K for each group, int32).
        a_strides: Stride information for the A matrices.
        b_strides: Stride information for the B matrices.
        d_strides: Stride information for the D matrices.
        s_strides: Stride information for the b_scales matrices.
        chunk_size: Number of elements each scale value applies to (K//512), defaults to 128.
        topk: Number of experts selected per token, defaults to 8.

    Requirements:
        - All tensors must be on a CUDA device.
        - Requires an NVIDIA Hopper GPU (e.g. H100).
        - A tensors must be in float8_e4m3fn format.
        - B tensors must contain packed int4 values (stored as int8).

    Note:
        The function computes D = A * (B * scales) for each group of tensors in parallel.
    """
    torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
        d,
        a,
        b,
        a_scales,
        b_scales,
        experts_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        d_strides,
        s_strides,
        chunk_size,
        topk,
    )
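Continuing the sketch above, the prepared metadata would feed the grouped GEMM roughly as follows. Everything the docstring does not pin down (the a_scales layout, the stride-tensor shapes and dtypes, the output dtype) is assumed here for illustration rather than taken from this PR, and expert_offsets / problem_sizes1 come from the earlier get_cutlass_w4a8_moe_mm_data sketch.

import torch

from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm

e, total_m, n, k = 256, 32, 512, 7168  # total_m = m * topk from the sketch above

# Activations in FP8 and weights packed as two int4 values per int8 byte (per the docstring).
a = torch.randn(total_m, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randint(-8, 8, (e, n, k // 2), dtype=torch.int8, device="cuda")

# b_scales follows the documented [E, K//512, N*8] layout; the a_scales shape and the
# float32 dtype used for both are assumptions.
a_scales = torch.ones(1, dtype=torch.float32, device="cuda")
b_scales = torch.ones(e, k // 512, n * 8, dtype=torch.float32, device="cuda")

# Output buffer and per-expert stride tensors; bfloat16 output and int64 strides are assumptions.
d = torch.empty(total_m, n, dtype=torch.bfloat16, device="cuda")
a_strides = torch.full((e,), k, dtype=torch.int64, device="cuda")
b_strides = torch.full((e,), k, dtype=torch.int64, device="cuda")
d_strides = torch.full((e,), n, dtype=torch.int64, device="cuda")
s_strides = torch.full((e,), n * 8, dtype=torch.int64, device="cuda")

cutlass_w4a8_moe_mm(
    d,
    a,
    b,
    a_scales,
    b_scales,
    expert_offsets,   # produced by get_cutlass_w4a8_moe_mm_data above
    problem_sizes1,   # produced by get_cutlass_w4a8_moe_mm_data above
    a_strides,
    b_strides,
    d_strides,
    s_strides,
    chunk_size=128,
    topk=8,
)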