[1/2] Add Kernel support for Cutlass based Fused FP4 MoE (#6093)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety
2025-06-02 13:48:03 -07:00
committed by GitHub
parent df7f61ee7d
commit eb38c7d1ca
12 changed files with 1677 additions and 22 deletions

View File

@@ -241,3 +241,80 @@ def qserve_w4a8_per_group_gemm(
in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
)
return out_feats
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
    """
    Run the sgl_kernel ``shuffle_rows`` op into a freshly allocated tensor.

    Args:
        input_tensor: Source tensor; its device and dtype are reused for the
            result.
        dst2src_map: Mapping tensor forwarded unchanged to the kernel.
            NOTE(review): presumably entry ``i`` names the source row copied
            into destination row ``i`` -- confirm against the kernel.
        output_tensor_shape: Shape of the tensor to allocate and return.
    Outputs:
        The newly allocated tensor after the kernel has written into it.
    """
    # The buffer is uninitialized; the kernel is responsible for filling it.
    shuffled = torch.empty(
        output_tensor_shape,
        dtype=input_tensor.dtype,
        device=input_tensor.device,
    )
    shuffle_op = torch.ops.sgl_kernel.shuffle_rows.default
    shuffle_op(input_tensor, dst2src_map, shuffled)
    return shuffled
def scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
    expert_map: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale, for
    packed MoE Inputs.
    Args:
        input_tensor: The 2-D input tensor to be quantized to FP4.
        input_global_scale: A scalar scaling factor for the entire tensor.
        expert_offsets: The expert offsets tensor
        blockscale_offsets: The blockscale offsets tensor
        topk: Multiplier applied to the row count when ``expert_map`` is
            given, and to the per-expert token limit.
        expert_map: Optional map tensor. When provided, the input rows are
            first rearranged with ``shuffle_rows`` into an
            ``(m * topk, k)`` tensor.
    Outputs:
        output: The quantized tensor in FP4, packed into uint8 with shape
            ``(m_numtopk, k // 2)``.
        output_scales: The blockscale tensor in FP8-E4M3. It is allocated
            with ``MAX_TOKENS_PER_EXPERT * topk`` rows regardless of the
            actual number of input rows.
    Raises:
        AssertionError: If the input is not 2-D or has more rows than
            ``MAX_TOKENS_PER_EXPERT * topk``.
        ValueError: If ``MODELOPT_MAX_TOKENS_PER_EXPERT`` is set to a value
            that is not an integer.
    """
    assert (
        input_tensor.ndim == 2
    ), f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
    if expert_map is not None:
        (m, k) = input_tensor.shape
        output_tensor_shape = (m * topk, k)
        input_tensor = shuffle_rows(input_tensor, expert_map, output_tensor_shape)
    m_numtopk, k = input_tensor.shape
    # Control the maximum number of tokens per expert supported by the
    # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
    # from running out of memory. This value can also be increased to support
    # larger models.
    import os

    # Environment values are always strings, so convert explicitly; without
    # this an override would turn the multiplications below into string
    # repetition and break both the bound check and the allocation size.
    max_tokens_per_expert = int(
        os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", "65536")
    )
    assert m_numtopk <= max_tokens_per_expert * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
        f"{max_tokens_per_expert})"
        f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
        f" MODELOPT_MAX_TOKENS_PER_EXPERT to set this value."
    )
    # One scale per 16 input columns; four one-byte scales are stored per
    # int32 element, so round the int32 column count up.
    scales_k = k // 16
    padded_k = (scales_k + (4 - 1)) // 4
    # output is uint8 and packed fp4 values
    output = torch.empty(
        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
    )
    output_scales = torch.empty(
        max_tokens_per_expert * topk,
        padded_k,
        dtype=torch.int32,
        device=input_tensor.device,
    )
    torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
        output,
        output_scales,
        input_tensor,
        input_global_scale,
        expert_offsets,
        blockscale_offsets,
    )
    # Reinterpret the int32 storage as one-byte FP8 values; this makes the
    # last dimension four times larger.
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales