[CUTLASS-FP4-MOE] Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (#6887)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety
2025-06-05 13:13:14 -07:00
committed by GitHub
parent 35b65cf0ca
commit 0df6765c83
6 changed files with 230 additions and 79 deletions

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Dict, Optional
import torch
@@ -184,13 +184,9 @@ def cutlass_fp4_group_mm(
a_blockscale,
b_blockscale,
alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
blockscale_offsets,
out_dtype,
device,
params: Dict[str, Any],
):
"""
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
@@ -220,10 +216,10 @@ def cutlass_fp4_group_mm(
a_blockscale,
b_blockscale,
alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
blockscale_offsets,
params["ab_strides"],
params["c_strides"],
params["problem_sizes"],
params["expert_offsets"],
params["blockscale_offsets"],
)
return c.to(dtype=out_dtype)