[1/2] Add Kernel support for Cutlass based Fused FP4 MoE (#6093)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@@ -241,3 +241,80 @@ def qserve_w4a8_per_group_gemm(
        in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
    )
    return out_feats


def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
    output_tensor = torch.empty(
        output_tensor_shape,
        device=input_tensor.device,
        dtype=input_tensor.dtype,
    )
    torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor)
    return output_tensor


def scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
    expert_map: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to FP4 and return the quantized tensor and its
    block scales, for packed MoE inputs.

    Args:
        input_tensor: The input tensor to be quantized to FP4
        input_global_scale: A scalar scaling factor for the entire tensor
        expert_offsets: The expert offsets tensor
        blockscale_offsets: The blockscale offsets tensor
        topk: The number of experts selected per token
        expert_map: An optional tensor mapping output rows to input rows; when
            given, rows are expanded with shuffle_rows before quantization

    Returns:
        output: The quantized tensor in packed FP4
        output_scales: The blockscale tensor in FP8-E4M3
    """
    assert (
        input_tensor.ndim == 2
    ), f"input_tensor.ndim needs to be == 2, but got {input_tensor.ndim}."
    if expert_map is not None:
        (m, k) = input_tensor.shape
        output_tensor_shape = (m * topk, k)
        input_tensor = shuffle_rows(input_tensor, expert_map, output_tensor_shape)

    m_numtopk, k = input_tensor.shape
    # Control the maximum number of tokens per expert supported by the
    # NVFP4 MoE expert quantization. This is used to prevent the kernel
    # from running out of memory. This value can also be increased to support
    # larger models.
    import os

    MAX_TOKENS_PER_EXPERT = int(
        os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", 65536)
    )
    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
        f"m_numtopk must be at most MAX_TOKENS_PER_EXPERT * topk "
        f"({MAX_TOKENS_PER_EXPERT} * {topk}) for cutlass_moe_fp4, observed "
        f"m_numtopk = {m_numtopk}. Use MODELOPT_MAX_TOKENS_PER_EXPERT to "
        f"raise this limit."
    )
    scales_k = k // 16  # one FP8 block scale per 16 input elements
    padded_k = (scales_k + (4 - 1)) // 4  # four FP8 scales packed per int32

    # output is uint8, holding two packed FP4 values per byte
    output = torch.empty(
        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
    )
    output_scales = torch.empty(
        MAX_TOKENS_PER_EXPERT * topk,
        padded_k,
        dtype=torch.int32,
        device=input_tensor.device,
    )
    torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
        output,
        output_scales,
        input_tensor,
        input_global_scale,
        expert_offsets,
        blockscale_offsets,
    )
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales
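
For intuition, shuffle_rows appears to be a row gather: row i of the output is row dst2src_map[i] of the input, which is how tokens get replicated topk times and grouped by expert before quantization. A minimal pure-PyTorch sketch of that assumed semantics (the helper name shuffle_rows_reference is hypothetical, not part of this commit):

import torch

def shuffle_rows_reference(input_tensor, dst2src_map, output_tensor_shape):
    # Assumed semantics of the kernel: a gather along dim 0, so that
    # output[i] == input_tensor[dst2src_map[i]].
    output = input_tensor[dst2src_map.to(torch.long)]
    assert output.shape == tuple(output_tensor_shape)
    return output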
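
A hypothetical end-to-end sketch of calling scaled_fp4_experts_quant. The shapes, the offset contents, the import path, and the global-scale formula are illustrative assumptions, not taken from this commit; in practice the offsets come from the MoE routing step, and running this needs a CUDA device with the compiled sgl_kernel extension:

import torch
from sgl_kernel import scaled_fp4_experts_quant  # assumed import path

num_experts, topk, m, k = 8, 2, 16, 128  # k must be divisible by 16
hidden = torch.randn(m, k, device="cuda", dtype=torch.half)
# One common NVFP4 convention: global scale = (448 * 6) / amax, so that
# block scale * FP4 value stays representable. Treated as an assumption here.
input_global_scale = (448.0 * 6.0) / hidden.float().abs().max()
# expert_offsets[e] marks where expert e's rows start in the permuted input;
# blockscale_offsets are the analogous offsets into the block-scale buffer.
# Zeros are placeholders to show dtypes only, and the length is an assumption.
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device="cuda")
blockscale_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device="cuda")
# export MODELOPT_MAX_TOKENS_PER_EXPERT=... beforehand to raise the token cap.

fp4_packed, blockscales = scaled_fp4_experts_quant(
    hidden,
    input_global_scale,
    expert_offsets,
    blockscale_offsets,
    topk,
)
# fp4_packed: (m, k // 2) uint8 (two FP4 values per byte), since no expert_map
# was passed; blockscales is an FP8-E4M3 view of the packed scale buffer.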