[1/2] Add Kernel support for Cutlass based Fused FP4 MoE (#6093)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -138,10 +140,12 @@ def prepare_moe_input(
|
||||
num_experts,
|
||||
n,
|
||||
k,
|
||||
blockscale_offsets: Optional[torch.Tensor] = None,
|
||||
):
|
||||
torch.ops.sgl_kernel.prepare_moe_input.default(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
input_permutation,
|
||||
@@ -150,3 +154,54 @@ def prepare_moe_input(
|
||||
n,
|
||||
k,
|
||||
)
|
||||
|
||||
|
||||
def cutlass_fp4_group_mm(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_blockscale,
|
||||
b_blockscale,
|
||||
alphas,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
out_dtype,
|
||||
device,
|
||||
):
|
||||
"""
|
||||
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
|
||||
the gemms for each combination based on the specified problem sizes.
|
||||
|
||||
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
|
||||
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
|
||||
input and expert weights.
|
||||
- a_/b_scales: The blockscales in FP8-E4M3 precision
|
||||
- ab_strides/c_strides: Strides for the a/b tensors between rows.
|
||||
- expert_offsets/sf_offsets: Indices that mark at which token index
|
||||
each expert begins its computation. The number of tokens
|
||||
computed with expert E is expert_offsets[E + 1] -
|
||||
expert_offsets[E] And the sf_size per expert is
|
||||
sf_offset[E+1] - sf_offset[E]
|
||||
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
|
||||
MMs used in the fused MoE operation.
|
||||
"""
|
||||
m_topk = a_fp4.shape[0]
|
||||
n = b_fp4.shape[1]
|
||||
c_shape = (m_topk, n)
|
||||
c = torch.empty(c_shape, device=device, dtype=out_dtype)
|
||||
torch.ops.sgl_kernel.cutlass_fp4_group_mm.default(
|
||||
c,
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_blockscale,
|
||||
b_blockscale,
|
||||
alphas,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
)
|
||||
return c.to(dtype=out_dtype)
|
||||
|
||||
Reference in New Issue
Block a user