Update CUTLASS 4.2 & Enable K-Major Scale Factor for SM90 FP8 Blockwise Group GEMM (#9559)
This commit is contained in:
@@ -157,10 +157,6 @@ def cutlass_fused_experts_fp8(
|
||||
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
|
||||
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
|
||||
|
||||
if not is_sm100_supported():
|
||||
rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
|
||||
w1_scale = w1_scale.contiguous()
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
|
||||
|
||||
@@ -192,9 +188,6 @@ def cutlass_fused_experts_fp8(
|
||||
silu_and_mul(c1, intermediate)
|
||||
|
||||
intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
|
||||
if not is_sm100_supported():
|
||||
a2_scale = per_group_transpose(a2_scale, expert_offsets)
|
||||
w2_scale = w2_scale.contiguous()
|
||||
|
||||
fp8_blockwise_scaled_grouped_mm(
|
||||
c2,
|
||||
|
||||
@@ -8,6 +8,15 @@ from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
|
||||
|
||||
|
||||
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
|
||||
def calc_diff(x, y):
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def get_model_config(tp_size: int):
|
||||
@@ -69,16 +78,11 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
|
||||
# --- Input Data ---
|
||||
# Use bf16/fp16 for input activation based on model config
|
||||
x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001
|
||||
x = torch.randn((batch_size, H), device="cuda", dtype=dtype)
|
||||
# --- Weights (Generate in higher precision, then convert to FP8) ---
|
||||
# Generate weights suitable for FP8 conversion (e.g., scaled appropriately)
|
||||
w1_hp = (
|
||||
torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001
|
||||
)
|
||||
w2_hp = (
|
||||
torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001
|
||||
+ 0.00001
|
||||
)
|
||||
w1_hp = torch.randn((E, I, H), device="cuda", dtype=torch.float32)
|
||||
w2_hp = torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32)
|
||||
|
||||
w1 = to_fp8(w1_hp)
|
||||
w2 = to_fp8(w2_hp)
|
||||
@@ -149,13 +153,13 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
)
|
||||
|
||||
# Note: Triton expects non-transposed weights
|
||||
moe_config = MoeRunnerConfig(inplace=False)
|
||||
triton_lambda = lambda: fused_experts(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
(topk_weights, topk_ids, "dummy"),
|
||||
inplace=False,
|
||||
activation="silu", # Assuming SiLU activation common in MoEs
|
||||
moe_config,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
@@ -221,32 +225,19 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
w1, # Original shape
|
||||
w2, # Original shape
|
||||
(topk_weights, topk_ids, "dummy"),
|
||||
inplace=False, # Important: Use False to get output tensor
|
||||
activation="silu",
|
||||
moe_config,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# Ensure outputs are same dtype for comparison
|
||||
y_cutlass = y_cutlass.to(dtype)
|
||||
y_triton = y_triton.to(dtype)
|
||||
|
||||
abs_error = torch.abs(y_cutlass - y_triton)
|
||||
rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2)
|
||||
|
||||
max_abs_err = abs_error.max().item()
|
||||
max_rel_err = rel_error.max().item()
|
||||
|
||||
print("y_cutlass:", y_cutlass[:, :10])
|
||||
print("y_triton:", y_triton[:, :10])
|
||||
print(f"Max absolute error: {max_abs_err:.6f}")
|
||||
print(f"Max relative error: {max_rel_err:.6f}")
|
||||
diff = calc_diff(y_cutlass, y_triton)
|
||||
print(f"Diff: {diff:.6f}")
|
||||
|
||||
# Tolerance might need adjustment based on FP8 specifics and kernel differences
|
||||
# FP8 comparisons often require higher tolerance than FP16/BF16
|
||||
assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}"
|
||||
assert diff < 1e-4, f"Diff too high! {diff}"
|
||||
print("Correctness check passed.")
|
||||
|
||||
|
||||
@@ -264,7 +255,21 @@ if __name__ == "__main__":
|
||||
"--batch-sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024], # Adjusted default
|
||||
default=[
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
4096,
|
||||
8192,
|
||||
], # Adjusted default
|
||||
help="List of batch sizes to test",
|
||||
)
|
||||
parser.add_argument("--check", action="store_true", help="Enable check mode")
|
||||
|
||||
Reference in New Issue
Block a user