Update CUTLASS 4.2 & Enable K-Major Scale Factor for SM90 FP8 Blockwise Group GEMM (#9559)

This commit is contained in:
Qi Yuhang
2025-08-25 14:24:43 +08:00
committed by GitHub
parent a0b22f2f17
commit fda4792620
5 changed files with 104 additions and 134 deletions

View File

@@ -5,10 +5,6 @@ import pytest
import torch
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8_hopper_moe_mn_major,
)
def cdiv(a: int, b: int) -> int:
return -(a // -b)
@@ -106,24 +102,19 @@ def is_sm90_supported(device=None) -> bool:
not (is_sm100_supported() or is_sm90_supported()),
reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
)
@pytest.mark.parametrize("num_experts", [8, 16])
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128])
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("use_custom_kernel", [True, False])
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kernel):
cc = torch.cuda.get_device_capability(None)[0]
if cc == 10 and use_custom_kernel:
return
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
device = "cuda"
alignment = 16
n_g = alignment * random.randint(1, 5) * 128
k_g = alignment * random.randint(1, 5) * 128
alignment = 128
n_g = random.randint(1, 64) * 128
k_g = random.randint(1, 64) * 128
expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
a_original_tensors = []
a_tensors = []
b_tensors = []
a_scales_tensors = []
@@ -131,7 +122,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
baseline_tensors = []
for g in range(num_experts):
m_g = alignment * random.randint(1, 64)
m_g = random.randint(1, 256)
expert_offsets[g + 1] = expert_offsets[g] + m_g
problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
@@ -144,7 +135,6 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
b_g, b_scale = per_block_cast_to_fp8(
b
) # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
a_original_tensors.append(a)
a_tensors.append(a_g)
b_tensors.append(b_g)
a_scales_tensors.append(a_scale)
@@ -152,9 +142,6 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
baseline = torch.mm(a, b)
baseline_tensors.append(baseline)
a_original_stack = torch.empty(
(expert_offsets[-1], k_g), device=device, dtype=out_dtype
)
a_stack = torch.empty(
(expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
)
@@ -162,52 +149,28 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
)
a_scale_stack = torch.empty(
(expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32
(expert_offsets[-1], (k_g // 128)), device=device, dtype=torch.float32
)
b_scale_stack = torch.empty(
(num_experts, k_g // 128, n_g // 128), device=device, dtype=torch.float32
(num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32
)
for g in range(num_experts):
# Matrix A is Row-Major.
a_original_stack[expert_offsets[g] : expert_offsets[g + 1]] = (
a_original_tensors[g]
)
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
a_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_tensors[
g
] # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
] # a_stack[expert_offsets[g] : expert_offsets[g + 1], :] -- (M, K):(K, 1)
b_stack[g] = b_tensors[g].t() # b_stack[g] -- (N, K):(K, 1)
if cc == 9:
# For SM90, we need MN-Major scale factor
# a_scales_tensors[g] -- (M, k):(k, 1)
# a_scales_tensors[g].t().contiguous() -- (k, M):(M, 1)
a_scale_stack[
expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
] = (a_scales_tensors[g].t().contiguous().view(-1))
b_scale_stack[g] = b_scales_tensors[g] # b_scale_stack[g] -- (k, n):(n, 1)
elif cc == 10:
# For SM100, we need K-Major scale factor
# a_scales_tensors[g] -- (M, k):(k, 1)
a_scale_stack[
expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
] = a_scales_tensors[g].view(-1)
b_scale_stack[g] = b_scales_tensors[
g
] # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128)
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
if cc == 10:
b_scale_stack = b_scale_stack.transpose(1, 2).contiguous()
if use_custom_kernel:
# Replace a_stack, a_scale_stack with custom kernel output
a_stack, a_scale_stack = per_token_group_quant_fp8_hopper_moe_mn_major(
a_original_stack,
expert_offsets[:-1],
problem_sizes,
128,
expert_tokens_alignment=alignment,
)
# We need K-Major scale factor
a_scale_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_scales_tensors[
g
]
b_scale_stack[g] = b_scales_tensors[
g
].t() # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
b_scale_stack = b_scale_stack.transpose(1, 2)
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
a_strides = torch.full(
@@ -250,7 +213,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
diff = calc_diff(actual, baseline)
assert diff < 0.001
print(
f"cc={cc}0 num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
)