[NVIDA] [1/N] Nvfp4 Masked Gemm: Add quant op for the flashinfer grouped gemm (#9200)
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import scaled_fp4_quant
|
||||
from sgl_kernel import (
|
||||
scaled_fp4_grouped_quant,
|
||||
scaled_fp4_quant,
|
||||
silu_and_mul,
|
||||
silu_and_mul_scaled_fp4_grouped_quant,
|
||||
)
|
||||
|
||||
skip_condition = torch.cuda.get_device_capability() < (10, 0)
|
||||
|
||||
@@ -166,5 +171,83 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
|
||||
torch.testing.assert_close(scale_ans, scale_ref)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
|
||||
)
|
||||
def test_quantize_to_fp4_grouped():
|
||||
torch.manual_seed(42)
|
||||
torch.set_default_device("cuda:0")
|
||||
|
||||
l, m, k = 2, 512, 2048
|
||||
x = torch.randn((l, m, k), dtype=torch.bfloat16)
|
||||
tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
|
||||
x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
|
||||
output, output_scales = scaled_fp4_grouped_quant(
|
||||
x,
|
||||
x_sf_global,
|
||||
)
|
||||
# output in logical (m, k, l), but its physical layout is (l, m, k).
|
||||
# So permute first to (l, m, k).
|
||||
output = output.permute(2, 0, 1)
|
||||
# output_scale in logical (32, 4, rm, 4, rk, l), but its physical layout is (l, rm, rk, 32, 4, 4).
|
||||
# So permute first to (l, rm, rk, 32, 4, 4).
|
||||
padded_m = ((m + 128 - 1) // 128) * 128
|
||||
output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
|
||||
for i in range(l):
|
||||
a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i])
|
||||
torch.testing.assert_close(a_fp4, output[i])
|
||||
torch.testing.assert_close(
|
||||
a_scale_interleaved.to(torch.float), output_scales[i].to(torch.float)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
|
||||
)
|
||||
@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048)])
|
||||
def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None:
|
||||
torch.manual_seed(42)
|
||||
torch.set_default_device("cuda:0")
|
||||
|
||||
l, m, k = shape
|
||||
x = torch.randn((l, m, k * 2), dtype=torch.bfloat16)
|
||||
max_m = 8
|
||||
assert max_m <= m
|
||||
mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
|
||||
|
||||
ref_y = silu_and_mul(x)
|
||||
tensor_amax = ref_y.abs().amax(dim=(1, 2)).to(torch.float32)
|
||||
y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
|
||||
ref_output, ref_output_scales = scaled_fp4_grouped_quant(
|
||||
ref_y,
|
||||
y_sf_global,
|
||||
)
|
||||
output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(
|
||||
x,
|
||||
y_sf_global,
|
||||
mask,
|
||||
)
|
||||
|
||||
# output in logical (m, k, l), but its physical layout is (l, m, k).
|
||||
# So permute first to (l, m, k).
|
||||
output = output.permute(2, 0, 1)
|
||||
ref_output = ref_output.permute(2, 0, 1)
|
||||
|
||||
# output_scale in logical (32, 4, rm, 4, rk, l), but its physical layout is (l, rm, rk, 32, 4, 4).
|
||||
# So permute first to (l, rm, rk, 32, 4, 4).
|
||||
padded_m = ((m + 128 - 1) // 128) * 128
|
||||
output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
|
||||
ref_output_scales = ref_output_scales.permute(5, 2, 4, 0, 1, 3).view(
|
||||
l, padded_m, -1
|
||||
)
|
||||
|
||||
for i in range(l):
|
||||
torch.testing.assert_close(ref_output[i, : mask[i]], output[i, : mask[i]])
|
||||
# We need to recover the swizzled scales to linear layout before applying mask slice.
|
||||
scale_ref = recover_swizzled_scales(ref_output_scales[i], m, k)
|
||||
scale_ans = recover_swizzled_scales(output_scales[i], m, k)
|
||||
torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user