[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant perf (#9556)

This commit is contained in:
Kaixi Hou
2025-08-29 17:17:03 -07:00
committed by GitHub
parent ff9b561817
commit 5c34b4f1c7
7 changed files with 297 additions and 61 deletions

View File

@@ -174,17 +174,22 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
@pytest.mark.skipif(
skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
def test_quantize_to_fp4_grouped():
@pytest.mark.parametrize("shape", [(2, 512, 2048), (2, 100, 128), (2, 128, 96)])
def test_quantize_to_fp4_grouped(shape):
torch.manual_seed(42)
torch.set_default_device("cuda:0")
l, m, k = 2, 512, 2048
l, m, k = shape
x = torch.randn((l, m, k), dtype=torch.bfloat16)
max_m = m // 2
assert max_m <= m
mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
output, output_scales = scaled_fp4_grouped_quant(
x,
x_sf_global,
mask,
)
# output in logical (m, k, l), but its physical layout is (l, m, k).
# So permute first to (l, m, k).
@@ -195,23 +200,25 @@ def test_quantize_to_fp4_grouped():
output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
for i in range(l):
a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i])
torch.testing.assert_close(a_fp4, output[i])
torch.testing.assert_close(
a_scale_interleaved.to(torch.float), output_scales[i].to(torch.float)
)
torch.testing.assert_close(a_fp4[: mask[i]], output[i][: mask[i]])
# Recover swizzled scales to linear layout and drop padded values, so
# no extra checks on padding are needed.
scale_ref = recover_swizzled_scales(a_scale_interleaved, m, k)
scale_ans = recover_swizzled_scales(output_scales[i], m, k)
torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])
@pytest.mark.skipif(
skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048)])
def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None:
@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048), (6, 6144, 2048)])
def test_silu_and_mul_quantize_to_fp4_grouped(shape):
torch.manual_seed(42)
torch.set_default_device("cuda:0")
l, m, k = shape
x = torch.randn((l, m, k * 2), dtype=torch.bfloat16)
max_m = 8
max_m = m // 2
assert max_m <= m
mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
@@ -221,6 +228,7 @@ def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None:
ref_output, ref_output_scales = scaled_fp4_grouped_quant(
ref_y,
y_sf_global,
mask,
)
output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(
x,