[CUTLASS-FP4-MOE] Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (#6887)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety
2025-06-05 13:13:14 -07:00
committed by GitHub
parent 35b65cf0ca
commit 0df6765c83
6 changed files with 230 additions and 79 deletions

View File

@@ -6,7 +6,7 @@ import triton # Added import
import triton.testing # Added import
from transformers import AutoConfig
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -125,7 +125,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
# --- Lambdas for Benchmarking ---
cutlass_lambda = lambda: cutlass_fused_experts(
cutlass_lambda = lambda: cutlass_fused_experts_fp8(
x,
w1.transpose(1, 2), # Transposed
w2.transpose(1, 2), # Transposed
@@ -193,7 +193,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
print("Running correctness check...")
with torch.no_grad():
# Run CUTLASS version (requires transposed weights)
y_cutlass = cutlass_fused_experts(
y_cutlass = cutlass_fused_experts_fp8(
x,
w1.transpose(1, 2), # Transposed
w2.transpose(1, 2), # Transposed

View File

@@ -5,6 +5,7 @@ from sgl_kernel import scaled_fp4_quant
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.topk import select_experts
if torch.cuda.get_device_capability() < (10, 0):
@@ -179,6 +180,13 @@ def test_cutlass_fp4_moe_no_graph(
(e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
)
c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
device=a.device,
num_experts=e,
intermediate_size_per_partition=n, # n
hidden_size=k,
) # k
cutlass_output = cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
@@ -189,17 +197,10 @@ def test_cutlass_fp4_moe_no_graph(
w2_fp4=w2_q,
w2_blockscale=w2_blockscale,
w2_alphas=(1 / w2_gs),
ab_strides_13=ab_strides_13,
ab_strides_2=ab_strides_2,
c_strides_13=c_strides_13,
c_strides_2=c_strides_2,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=e,
device=a.device,
params=params,
apply_router_weight_on_input=False,
)
# Reference check: