Update CUTLASS to 4.2 & Enable K-Major Scale Factor for SM90 FP8 Blockwise Group GEMM (#9559)

Author: Qi Yuhang
Date: 2025-08-25 14:24:43 +08:00
Committed by: GitHub
Commit: fda4792620 (parent a0b22f2f17)
5 changed files with 104 additions and 134 deletions
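CUTLASS 4.2 lets the SM90 FP8 blockwise group GEMM consume scale factors in K-major layout, which is the layout sglang_per_token_group_quant_fp8 produces. The `if not is_sm100_supported()` branches below, which (judging by the name per_group_transpose) re-laid out the activation scales per expert group and forced the weight scales contiguous before calling fp8_blockwise_scaled_grouped_mm, are therefore removed. A minimal sketch of the kind of conversion that is no longer needed (shapes taken from the diff; the plain transpose here is an illustration, not SGLang's per_group_transpose implementation):

import torch

# Shapes from the diff: activation scales arrive as one FP8 scale per
# 128-wide k-block, laid out K-major (contiguous along k for each token row).
tokens, k = 16, 512
scales = torch.rand(tokens, k // 128)  # e.g. rep_a1_scales: (m * topk, k // 128)

# Illustrative stand-in for the removed conversion: transposing each group's
# rows yields the layout the pre-4.2 SM90 kernel expected. With CUTLASS 4.2
# the kernel takes `scales` as-is.
transposed = scales.t().contiguous()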

File: python/sglang/srt/layers/moe/cutlass_moe.py (path inferred from the import in the second file)

@@ -157,10 +157,6 @@ def cutlass_fused_experts_fp8(
     rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
     rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
-    if not is_sm100_supported():
-        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
-        w1_scale = w1_scale.contiguous()

     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
@@ -192,9 +188,6 @@ def cutlass_fused_experts_fp8(
     silu_and_mul(c1, intermediate)
     intermediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    if not is_sm100_supported():
-        a2_scale = per_group_transpose(a2_scale, expert_offsets)
-        w2_scale = w2_scale.contiguous()

     fp8_blockwise_scaled_grouped_mm(
         c2,

File: benchmark script for cutlass_fused_experts_fp8 (path not shown)

@@ -8,6 +8,15 @@ from transformers import AutoConfig
 from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+
+# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
+def calc_diff(x, y):
+    x, y = x.double(), y.double()
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return 1 - sim
+
 def get_model_config(tp_size: int):
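The calc_diff helper added above (borrowed from DeepGEMM) computes 1 - 2<x, y> / (||x||^2 + ||y||^2): 0 for identical tensors, near 1 for uncorrelated ones, and invariant when both tensors are scaled by the same factor. A quick sanity check of its behavior (standalone sketch, redefining the helper so it runs anywhere):

import torch

def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim

x = torch.randn(64, 64)
print(calc_diff(x, x).item())       # 0.0: identical tensors
print(calc_diff(x, 2 * x).item())   # 0.2: same direction, mismatched scale
print(calc_diff(x, torch.randn(64, 64)).item())  # ~1.0: uncorrelated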
@@ -69,16 +78,11 @@ def run_test(tp_size, batch_size, model_config, check=False):
     # --- Input Data ---
     # Use bf16/fp16 for input activation based on model config
-    x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001
+    x = torch.randn((batch_size, H), device="cuda", dtype=dtype)

     # --- Weights (Generate in higher precision, then convert to FP8) ---
     # Generate weights suitable for FP8 conversion (e.g., scaled appropriately)
-    w1_hp = (
-        torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001
-    )
-    w2_hp = (
-        torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001
-        + 0.00001
-    )
+    w1_hp = torch.randn((E, I, H), device="cuda", dtype=torch.float32)
+    w2_hp = torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32)
     w1 = to_fp8(w1_hp)
     w2 = to_fp8(w2_hp)
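With the check now based on the scale-invariant calc_diff rather than a max-relative-error bound, the inputs and weights presumably no longer need the tiny * 0.00001-style magnitudes that kept per-element relative errors small. to_fp8 itself is not shown in this diff; a typical implementation (an assumption here, not the file's actual helper) clamps to the float8_e4m3fn representable range before casting:

import torch

def to_fp8(t: torch.Tensor) -> torch.Tensor:
    # Clamp to the e4m3 representable range (+/-448) so the cast saturates
    # instead of overflowing, then cast down to FP8.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return t.clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)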
@@ -149,13 +153,13 @@ def run_test(tp_size, batch_size, model_config, check=False):
     )

     # Note: Triton expects non-transposed weights
+    moe_config = MoeRunnerConfig(inplace=False)
     triton_lambda = lambda: fused_experts(
         x,
         w1,
         w2,
         (topk_weights, topk_ids, "dummy"),
-        inplace=False,
-        activation="silu",  # Assuming SiLU activation common in MoEs
+        moe_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
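Both benchmark call sites now hand fused_experts a MoeRunnerConfig instead of separate inplace/activation keyword arguments; the previous "silu" default presumably moves into the config's defaults. A sketch of the assumed shape of the config, for orientation only (the real class lives in sglang.srt.layers.moe.moe_runner.base and may carry more fields):

from dataclasses import dataclass

# Assumption: MoeRunnerConfig bundles per-run options such as inplace and the
# activation that were previously passed as individual kwargs.
@dataclass
class MoeRunnerConfig:
    inplace: bool = False
    activation: str = "silu"

moe_config = MoeRunnerConfig(inplace=False)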
@@ -221,32 +225,19 @@ def run_test(tp_size, batch_size, model_config, check=False):
         w1,  # Original shape
         w2,  # Original shape
         (topk_weights, topk_ids, "dummy"),
-        inplace=False,  # Important: Use False to get output tensor
-        activation="silu",
+        moe_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         block_shape=block_shape,
     )

     # Ensure outputs are same dtype for comparison
     y_cutlass = y_cutlass.to(dtype)
     y_triton = y_triton.to(dtype)

-    abs_error = torch.abs(y_cutlass - y_triton)
-    rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2)
-    max_abs_err = abs_error.max().item()
-    max_rel_err = rel_error.max().item()
-    print("y_cutlass:", y_cutlass[:, :10])
-    print("y_triton:", y_triton[:, :10])
-    print(f"Max absolute error: {max_abs_err:.6f}")
-    print(f"Max relative error: {max_rel_err:.6f}")
+    diff = calc_diff(y_cutlass, y_triton)
+    print(f"Diff: {diff:.6f}")

-    # Tolerance might need adjustment based on FP8 specifics and kernel differences
-    # FP8 comparisons often require higher tolerance than FP16/BF16
-    assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}"
+    assert diff < 1e-4, f"Diff too high! {diff}"
     print("Correctness check passed.")
@@ -264,7 +255,21 @@ if __name__ == "__main__":
"--batch-sizes",
type=int,
nargs="+",
default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024], # Adjusted default
default=[
1,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
4096,
8192,
], # Adjusted default
help="List of batch sizes to test",
)
parser.add_argument("--check", action="store_true", help="Enable check mode")