[CUTLASS-FP4-MOE] Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (#6887)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -6,7 +6,7 @@
 import triton  # Added import
 import triton.testing  # Added import
 from transformers import AutoConfig

-from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
+from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -125,7 +125,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
     problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")

     # --- Lambdas for Benchmarking ---
-    cutlass_lambda = lambda: cutlass_fused_experts(
+    cutlass_lambda = lambda: cutlass_fused_experts_fp8(
         x,
         w1.transpose(1, 2),  # Transposed
         w2.transpose(1, 2),  # Transposed
@@ -193,7 +193,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
         print("Running correctness check...")
         with torch.no_grad():
             # Run CUTLASS version (requires transposed weights)
-            y_cutlass = cutlass_fused_experts(
+            y_cutlass = cutlass_fused_experts_fp8(
                 x,
                 w1.transpose(1, 2),  # Transposed
                 w2.transpose(1, 2),  # Transposed
||||
@@ -5,6 +5,7 @@ from sgl_kernel import scaled_fp4_quant

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
+from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.moe.topk import select_experts

 if torch.cuda.get_device_capability() < (10, 0):
@@ -179,6 +180,13 @@ def test_cutlass_fp4_moe_no_graph(
         (e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
     )
     c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
+    params = CutlassMoEParams(
+        CutlassMoEType.BlockscaledFP4,
+        device=a.device,
+        num_experts=e,
+        intermediate_size_per_partition=n,  # n
+        hidden_size=k,
+    )  # k
     cutlass_output = cutlass_moe_fp4(
         a=a,
         a1_gscale=a1_gs,
@@ -189,17 +197,10 @@ def test_cutlass_fp4_moe_no_graph(
         w2_fp4=w2_q,
         w2_blockscale=w2_blockscale,
         w2_alphas=(1 / w2_gs),
-        ab_strides_13=ab_strides_13,
-        ab_strides_2=ab_strides_2,
-        c_strides_13=c_strides_13,
-        c_strides_2=c_strides_2,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        m=m,
-        n=n,
-        k=k,
-        e=e,
-        device=a.device,
+        params=params,
+        apply_router_weight_on_input=False,
     )

     # Reference check:
|
||||
# Reference check:
|
||||
|
||||
Reference in New Issue
Block a user