[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant perf (#9556)
This commit is contained in:
@@ -298,6 +298,7 @@ def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
|
||||
def scaled_fp4_grouped_quant(
|
||||
input_tensor: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Quantize input tensor to FP4 and return quantized tensor and scale, for
|
||||
@@ -331,22 +332,14 @@ def scaled_fp4_grouped_quant(
|
||||
output_scales = torch.empty(
|
||||
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
|
||||
)
|
||||
input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
|
||||
output_offsets = torch.arange(
|
||||
0,
|
||||
(l + 1) * padded_m,
|
||||
step=padded_m,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
|
||||
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
|
||||
output.view(l * m, k // 2),
|
||||
output_scales.view(l * padded_m, padded_k_int32),
|
||||
input_tensor.view(l * m, k),
|
||||
input_global_scale,
|
||||
input_offsets,
|
||||
output_offsets,
|
||||
mask,
|
||||
use_silu_and_mul=False,
|
||||
)
|
||||
# The physical layout of the output is (l, m, k // 2), but we want to return a
|
||||
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
|
||||
@@ -400,23 +393,14 @@ def silu_and_mul_scaled_fp4_grouped_quant(
|
||||
output_scales = torch.empty(
|
||||
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
|
||||
)
|
||||
input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
|
||||
output_offsets = torch.arange(
|
||||
0,
|
||||
(l + 1) * padded_m,
|
||||
step=padded_m,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
|
||||
output.view(l * m, k // 2),
|
||||
output_scales.view(l * padded_m, padded_k_int32),
|
||||
input_tensor.view(l * m, k_by_2),
|
||||
input_global_scale,
|
||||
input_offsets,
|
||||
output_offsets,
|
||||
mask,
|
||||
use_silu_and_mul=True,
|
||||
)
|
||||
# The physical layout of the output is (l, m, k // 2), but we want to return a
|
||||
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
|
||||
|
||||
Reference in New Issue
Block a user