[NVIDIA] [1/N] Nvfp4 Masked Gemm: Add quant op for the flashinfer grouped gemm (#9200)

Kaixi Hou
2025-08-22 12:19:45 -07:00
committed by GitHub
parent f556ac8bd8
commit e5638573c1
7 changed files with 420 additions and 13 deletions


@@ -52,12 +52,14 @@ from sgl_kernel.gemm import (
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
scaled_fp4_experts_quant,
scaled_fp4_grouped_quant,
scaled_fp4_quant,
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_group_quant_int8,
sgl_per_token_quant_fp8,
shuffle_rows,
silu_and_mul_scaled_fp4_grouped_quant,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.kvcacheio import (


@@ -295,6 +295,142 @@ def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
return output_tensor
def scaled_fp4_grouped_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
):
"""
Quantize the input tensor to FP4 and return the quantized tensor and block scales,
laid out for grouped gemm inputs (e.g., grouped_gemm_nt_masked in flashinfer).
Args:
input_tensor: The input tensor to be quantized to FP4, with shape (l, m, k), where
l is the number of groups, m the number of tokens per group, and k the number of features.
input_global_scale: One global scaling factor per group, with shape (l,).
Outputs:
output: The quantized tensor in FP4, with shape (m, k // 2, l) but physical
layout (l, m, k // 2). The `// 2` is because two fp4 values are packed into
a uint8.
output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
but physical layout (l, rm, rk, 32, 4, 4).
Note:
In the shape of output_scales, `32 * 4 * rm` is m padded to the nearest multiple of 128,
and `4 * rk` is `k // 16` padded to the nearest multiple of 4. This layout is
required by the NVIDIA Blackwell MMA operations.
"""
device = input_tensor.device
l, m, k = input_tensor.shape
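# Each FP8-E4M3 block scale covers sf_vec_size = 16 consecutive FP4 values along k.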
sf_vec_size = 16
assert k % sf_vec_size == 0, f"k must be a multiple of 16, but got {k}."
scale_k = k // sf_vec_size
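# Pad the scale dims to the swizzled layout expected by the Blackwell MMA:
# scale_k up to a multiple of 4 (four FP8 scales are packed per int32) and
# m up to a multiple of 128.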
padded_k = (scale_k + (4 - 1)) // 4 * 4
padded_k_int32 = padded_k // 4
padded_m = (m + (128 - 1)) // 128 * 128
output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
output_scales = torch.empty(
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
)
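# The offsets mark each group's start row in the flattened (l * m, k) input view
# and in the flattened (l * padded_m, padded_k_int32) scale output.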
input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
output_offsets = torch.arange(
0,
(l + 1) * padded_m,
step=padded_m,
dtype=torch.int,
device=device,
)
torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
output.view(l * m, k // 2),
output_scales.view(l * padded_m, padded_k_int32),
input_tensor.view(l * m, k),
input_global_scale,
input_offsets,
output_offsets,
)
# The physical layout of the output is (l, m, k // 2), but we want to return a
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
output = output.permute(1, 2, 0)
# The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), as
# required by the flashinfer masked group gemm, where rm = padded_m / 128 and
# rk = padded_k / 4. The logical layout is (32, 4, rm, 4, rk, l).
output_scales = output_scales.view(torch.float8_e4m3fn).view(
l, padded_m // 128, padded_k // 4, 32, 4, 4
)
output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
return output, output_scales
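

# Illustrative usage sketch (not part of this diff): how the returned logical vs.
# physical layouts line up for a grouped bf16 activation. The shapes, dtypes, and
# global-scale values below are assumptions chosen only for the example; it relies
# on the module-level torch import.
def _example_scaled_fp4_grouped_quant():
    l, m, k = 8, 256, 512
    x = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)
    # One global scale per group (shape (l,)), e.g. derived from each group's amax.
    global_scale = torch.ones(l, device="cuda", dtype=torch.float32)
    x_fp4, x_scales = scaled_fp4_grouped_quant(x, global_scale)
    # Logical shape (m, k // 2, l); physical layout (l, m, k // 2), i.e. group-major.
    assert x_fp4.shape == (m, k // 2, l)
    assert x_fp4.stride() == (k // 2, 1, m * (k // 2))
    # Block scales: logical (32, 4, rm, 4, rk, l), physical (l, rm, rk, 32, 4, 4),
    # with rm = padded_m / 128 and rk = padded_k / 4.
    rm, rk = (m + 127) // 128, ((k // 16) + 3) // 4
    assert x_scales.shape == (32, 4, rm, 4, rk, l)
    return x_fp4, x_scales
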
def silu_and_mul_scaled_fp4_grouped_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
mask: torch.Tensor,
):
"""
Apply a fused SiLU-and-mul, quantize the result to FP4, and return the quantized
tensor and block scales, laid out for grouped gemm inputs (e.g.,
grouped_gemm_nt_masked in flashinfer).
Args:
input_tensor: The input tensor to be quantized to FP4, with shape (l, m, k * 2), where
l is the number of groups, m the number of tokens per group, and k the number of
output features; the last dim holds the two k-sized halves consumed by SiLU-and-mul.
input_global_scale: One global scaling factor per group, with shape (l,).
mask: The mask tensor, with shape (l,).
Outputs:
output: The quantized tensor in FP4, with shape (m, k // 2, l) but physical
layout (l, m, k // 2). The `// 2` is because two fp4 values are packed into
a uint8.
output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
but physical layout (l, rm, rk, 32, 4, 4).
Note:
In the shape of output_scales, `32 * 4 * rm` is m padded to the nearest multiple of 128,
and `4 * rk` is `k // 16` padded to the nearest multiple of 4. This layout is
required by the NVIDIA Blackwell MMA operations.
"""
device = input_tensor.device
l, m, k_by_2 = input_tensor.shape
k = k_by_2 // 2
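# The last input dim packs 2 * k features; the kernel applies SiLU-and-mul across
# the two k-sized halves before quantizing, yielding k FP4 values per token.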
sf_vec_size = 16
assert k % sf_vec_size == 0, f"k must be a multiple of 16, but got {k}."
scale_k = k // sf_vec_size
padded_k = (scale_k + (4 - 1)) // 4 * 4
padded_k_int32 = padded_k // 4
padded_m = (m + (128 - 1)) // 128 * 128
output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
output_scales = torch.empty(
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
)
input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
output_offsets = torch.arange(
0,
(l + 1) * padded_m,
step=padded_m,
dtype=torch.int,
device=device,
)
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
output.view(l * m, k // 2),
output_scales.view(l * padded_m, padded_k_int32),
input_tensor.view(l * m, k_by_2),
input_global_scale,
input_offsets,
output_offsets,
mask,
)
# The physical layout of the output is (l, m, k // 2), but we want to return a
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
output = output.permute(1, 2, 0)
# The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), as
# required by the flashinfer masked group gemm, where rm = padded_m / 128 and
# rk = padded_k / 4. The logical layout is (32, 4, rm, 4, rk, l).
output_scales = output_scales.view(torch.float8_e4m3fn).view(
l, padded_m // 128, padded_k // 4, 32, 4, 4
)
output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
return output, output_scales
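

# Illustrative usage sketch (not part of this diff): fused SiLU-and-mul plus FP4
# quantization for a masked grouped-GEMM input. The shapes and dtypes are assumptions
# for the example only, and `mask` is assumed to give the number of valid tokens per
# group (rows past mask[i] in group i are treated as padding).
def _example_silu_and_mul_scaled_fp4_grouped_quant():
    l, m, k = 8, 256, 512
    # Last dim holds 2 * k features (the two halves consumed by SiLU-and-mul).
    x = torch.randn(l, m, 2 * k, device="cuda", dtype=torch.bfloat16)
    global_scale = torch.ones(l, device="cuda", dtype=torch.float32)
    # Assumed per-group valid-token counts (dtype int32 assumed for the example).
    mask = torch.randint(1, m + 1, (l,), device="cuda", dtype=torch.int32)
    y_fp4, y_scales = silu_and_mul_scaled_fp4_grouped_quant(x, global_scale, mask)
    # Same layouts as scaled_fp4_grouped_quant: logical (m, k // 2, l) for the data
    # and (32, 4, rm, 4, rk, l) for the block scales; all m rows are allocated even
    # when mask[i] < m.
    assert y_fp4.shape == (m, k // 2, l)
    return y_fp4, y_scales
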
def scaled_fp4_experts_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,