[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant perf (#9556)

This commit is contained in:
Kaixi Hou
2025-08-29 17:17:03 -07:00
committed by GitHub
parent ff9b561817
commit 5c34b4f1c7
7 changed files with 297 additions and 61 deletions

View File

@@ -298,6 +298,7 @@ def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
def scaled_fp4_grouped_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
mask: torch.Tensor,
):
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
@@ -331,22 +332,14 @@ def scaled_fp4_grouped_quant(
output_scales = torch.empty(
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
)
input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
output_offsets = torch.arange(
0,
(l + 1) * padded_m,
step=padded_m,
dtype=torch.int,
device=device,
)
torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
output.view(l * m, k // 2),
output_scales.view(l * padded_m, padded_k_int32),
input_tensor.view(l * m, k),
input_global_scale,
input_offsets,
output_offsets,
mask,
use_silu_and_mul=False,
)
# The physical layout of the output is (l, m, k // 2), but we want to return a
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
@@ -400,23 +393,14 @@ def silu_and_mul_scaled_fp4_grouped_quant(
output_scales = torch.empty(
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
)
input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
output_offsets = torch.arange(
0,
(l + 1) * padded_m,
step=padded_m,
dtype=torch.int,
device=device,
)
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
output.view(l * m, k // 2),
output_scales.view(l * padded_m, padded_k_int32),
input_tensor.view(l * m, k_by_2),
input_global_scale,
input_offsets,
output_offsets,
mask,
use_silu_and_mul=True,
)
# The physical layout of the output is (l, m, k // 2), but we want to return a
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.