[NVIDIA] Add Low Latency NVFP4 decode kernels from Flashinfer (#8552)

Co-authored-by: Cheng Wan <cwan@x.ai>
azhurkevich
2025-08-04 03:10:02 -07:00
committed by GitHub
parent 36fc9260a2
commit 915140fd18
8 changed files with 504 additions and 117 deletions


@@ -1,13 +1,15 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
from __future__ import annotations
import importlib.util
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
@@ -29,6 +31,7 @@ from sglang.srt.layers.quantization.utils import (
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import is_cuda, next_power_of_2
if TYPE_CHECKING:
@@ -39,6 +42,11 @@ if is_cuda():
try:
from flashinfer import mm_fp4 as fp4_gemm
from flashinfer import (
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
enable_flashinfer_fp4_gemm = True
except ImportError:
@@ -47,6 +55,9 @@ except ImportError:
else:
fp4_gemm = None
enable_flashinfer_fp4_gemm = False
reorder_rows_for_gated_act_gemm = None
shuffle_matrix_a = None
shuffle_matrix_sf_a = None
try:
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
@@ -527,6 +538,7 @@ class ModelOptFp4Config(QuantizationConfig):
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
@@ -536,6 +548,9 @@ class ModelOptFp4Config(QuantizationConfig):
return ModelOptFp4LinearMethod(self)
if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FlashInferFP4MoE):
# FlashInferFP4MoE uses the same quantization method but needs compatible attribute handling
return ModelOptNvFp4FusedMoEMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoEMethod(self)
return None
@@ -726,7 +741,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
" quantization. Please use Blackwell and"
" above."
)
self.enable_flashinfer_cutlass_moe = False
self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
@property
def enable_flashinfer_cutlass_moe(self) -> bool:
"""Access the global enable_flashinfer_cutlass_moe setting."""
return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
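The cached boolean from __init__ is replaced by a property that reads global_server_args_dict on every access, so the flag tracks the server arguments even when they are resolved after this method object is constructed. A minimal sketch of the pattern with a hypothetical stand-in dict (the real one lives in sglang.srt.managers.schedule_batch):

server_args = {}  # hypothetical stand-in for global_server_args_dict

class MoEMethodSketch:
    @property
    def enable_flashinfer_cutlass_moe(self) -> bool:
        # Looked up at access time rather than cached in __init__,
        # so later changes to the dict are observed.
        return server_args.get("enable_flashinfer_cutlass_moe", False)

m = MoEMethodSketch()
assert m.enable_flashinfer_cutlass_moe is False
server_args["enable_flashinfer_cutlass_moe"] = True
assert m.enable_flashinfer_cutlass_moe is True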
def create_weights(
self,
@@ -743,16 +763,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
" dynamic quantization is not supported."
)
# TODO(ch-wan): check if this is needed
layer.num_experts = num_experts
layer.num_local_experts = num_experts
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.params_dtype = params_dtype
layer.quant_config = self.quant_config
weight_dtype = torch.uint8
weight_scale_dtype = torch.float8_e4m3fn
weight_loader = extra_weight_attrs.get("weight_loader")
# GEMM 1
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
layer.local_num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
@@ -767,7 +791,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
# GEMM 2
w2_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
layer.num_local_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
@@ -781,7 +805,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w13_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
layer.num_local_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.quant_config.group_size,
@@ -795,7 +819,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w2_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
layer.num_local_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // self.quant_config.group_size,
@@ -814,13 +838,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
)
w13_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(num_experts, 2, dtype=torch.float32),
data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
w2_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(num_experts, dtype=torch.float32),
data=torch.empty(layer.num_local_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
@@ -830,18 +854,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
)
w13_input_scale = PerTensorScaleParameter(
data=torch.empty(num_experts, 2, dtype=torch.float32),
data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = PerTensorScaleParameter(
data=torch.empty(num_experts, dtype=torch.float32),
data=torch.empty(layer.num_local_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
def swizzle_blockscale(self, scale: torch.tensor):
def swizzle_blockscale(self, scale: torch.Tensor):
assert scale.dtype == torch.float8_e4m3fn
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
@@ -866,9 +890,125 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
else swizzled_scale.reshape(B, M, K)
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def prepare_static_weights_for_kernel(
self,
# args_dequant,
# args,
gemm1_weights,
gemm2_weights,
gemm1_scales_linear_fp4_bytes,
gemm2_scales_linear_fp4_bytes,
hidden_size,
intermediate_size,
num_experts,
):
from flashinfer import (
RoutingMethodType,
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
next_positive_power_of_2,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
"""Prepare quantized weights for kernel (done offline with weights)."""
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
# Convert quantized weights to proper formats
gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2
) # packed fp4
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(
num_experts, 2 * intermediate_size, hidden_size // 16
) # fp8 scaling factors
gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 2
) # packed fp4
gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(
num_experts, hidden_size, intermediate_size // 16
) # fp8 scaling factors
# Reorder rows of W1 and scales for fused gated activation
gemm1_weights_fp4_interleaved = []
gemm1_scales_fp4_interleaved = []
for i in range(num_experts):
gemm1_weights_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
)
gemm1_scales_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
)
# Stack weights and scales for all experts
gemm1_weights_fp4_interleaved = torch.stack(
gemm1_weights_fp4_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
gemm1_scales_fp4_interleaved = torch.stack(
gemm1_scales_fp4_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
for i in range(num_experts):
gemm1_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
)
)
gemm1_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
)
)
gemm2_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
)
)
gemm2_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
)
)
# Stack weights for all experts
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
gemm1_scales_fp4_shuffled = (
torch.stack(gemm1_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
)
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
gemm2_scales_fp4_shuffled = (
torch.stack(gemm2_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(num_experts, hidden_size, intermediate_size // 16)
)
return (
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
)
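In process_weights_after_loading below, the logical sizes passed to this helper are recovered from the packed parameter shapes. A small sanity-check sketch of that arithmetic, using the same assumed toy sizes as above:

import torch

num_experts, hidden_size, intermediate_size = 4, 256, 512
w13_weight = torch.empty(num_experts, 2 * intermediate_size, hidden_size // 2, dtype=torch.uint8)
w2_weight = torch.empty(num_experts, hidden_size, intermediate_size // 2, dtype=torch.uint8)

# Mirrors the arguments built at the call site in the diff:
assert w2_weight.size(-2) == hidden_size              # hidden_size
assert w13_weight.size(-2) // 2 == intermediate_size  # intermediate_size
assert w13_weight.size(0) == num_experts              # num_experts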
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Process FP4 MoE weights after loading from serialized checkpoint.
Only supports pre-quantized checkpoints with packed FP4 weights and FP8 scales.
"""
# GEMM 1 scale processing
if not torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
):
@@ -880,73 +1020,123 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
if self.enable_flashinfer_cutlass_moe:
# Calculate input scales based on strategy
if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
w2_input_scale = layer.w2_input_scale
# Create shared parameters
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False,
)
assert (
layer.w13_weight_scale.shape[2] % 16 == 0
), "Expected weight_scale.dim(1) to be divisible by 16"
assert (
layer.w13_weight_scale.dtype == torch.float8_e4m3fn
), "Weight Blockscale must be represented as FP8-E4M3"
w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
layer.w13_blockscale_swizzled = Parameter(
w13_blockscale_swizzled, requires_grad=False
)
del layer.w13_weight_scale
# This is for quantization, so we need to invert it.
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False
)
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
# GEMM 2
if self.enable_flashinfer_cutlass_moe:
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
else:
w2_input_scale = layer.w2_input_scale
layer.g2_alphas = Parameter(
(w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)
# This is for quantization, so we need to invert it.
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False
)
layer.w2_input_scale_quant = Parameter(
(1 / w2_input_scale).to(torch.float32), requires_grad=False
)
assert (
layer.w2_weight_scale.shape[2] % 16 == 0
), "Expected weight_scale.dim(1) to be divisible by 16"
assert (
layer.w2_weight_scale.dtype == torch.float8_e4m3fn
), "Weight Blockscale must be represented as FP8-E4M3"
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
# Validate weight scales
for name, weight_scale in [
("w13", layer.w13_weight_scale),
("w2", layer.w2_weight_scale),
]:
assert (
weight_scale.shape[2] % 16 == 0
), f"Expected {name}_weight_scale.dim(2) to be divisible by 16"
assert (
weight_scale.dtype == torch.float8_e4m3fn
), f"{name} Weight Blockscale must be represented as FP8-E4M3"
layer.w2_blockscale_swizzled = Parameter(
w2_blockscale_swizzled, requires_grad=False
)
del layer.w2_weight_scale
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
# Weight processing based on strategy
if (
self.enable_flashinfer_trtllm_moe
and reorder_rows_for_gated_act_gemm is not None
and shuffle_matrix_sf_a is not None
):
# FlashInfer TRTLLM processing - handles both w13 and w2
(
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
) = self.prepare_static_weights_for_kernel(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w2_weight.size(-2), # hidden_size
layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts
)
device = layer.w13_weight.device
layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
device,
num_experts=layer.num_experts, # global num experts
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n
hidden_size=layer.w13_weight.shape[2] * 2,
) # k
# Set flashinfer parameters
layer.gemm1_weights_fp4_shuffled = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.gemm2_weights_fp4_shuffled = Parameter(
gemm2_weights_fp4_shuffled, requires_grad=False
)
layer.gemm1_scales_fp4_shuffled = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.gemm2_scales_fp4_shuffled = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)
# Additional parameter needed for TRT-LLM
layer.g1_scale_c = Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
# Clean up weights that won't be used by TRT-LLM
del (
layer.w2_weight,
layer.w2_weight_scale,
layer.w13_weight,
layer.w13_weight_scale,
)
print("Applied flashinfer weight processing for both w13 and w2")
else:
# CUTLASS processing - handle w13 and w2 separately
# Process w13 weights
w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
layer.w13_blockscale_swizzled = Parameter(
w13_blockscale_swizzled, requires_grad=False
)
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
# Process w2 weights
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
layer.w2_blockscale_swizzled = Parameter(
w2_blockscale_swizzled, requires_grad=False
)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
# Both flashinfer cutlass and regular cutlass use the same processing for w2
print("Applied weight processing for both w13 and w2")
# Set up CUTLASS MoE parameters
device = layer.w13_weight.device
layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
device,
num_experts=layer.num_experts, # global num experts
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n
hidden_size=layer.w13_weight.shape[2] * 2,
) # k
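The scale constants computed in this method follow the usual NVFP4 calibration convention (stated here as an assumption; the diff itself only shows the arithmetic): the alphas dequantize each GEMM's output, the inverted input scales quantize activations, and g1_scale_c folds the GEMM1 dequant together with the GEMM2 input requant for the TRT-LLM kernel. A toy numeric sketch:

import torch

# Assumed toy per-tensor calibration scales; names mirror the diff.
w13_input_scale = torch.tensor(0.02)     # activation scale into GEMM1
w13_weight_scale_2 = torch.tensor(0.05)  # per-tensor weight scale, GEMM1
w2_input_scale = torch.tensor(0.03)      # activation scale into GEMM2
w2_weight_scale_2 = torch.tensor(0.04)   # per-tensor weight scale, GEMM2

# Quantization factors for activations (inverse of the calibration scale).
w13_input_scale_quant = 1 / w13_input_scale
w2_input_scale_quant = 1 / w2_input_scale

# Per-GEMM dequantization multipliers.
g1_alphas = w13_input_scale * w13_weight_scale_2
g2_alphas = w2_input_scale * w2_weight_scale_2

# Extra TRT-LLM constant: GEMM1 dequant fused with GEMM2 input requant.
g1_scale_c = w2_input_scale_quant * g1_alphas
print(g1_alphas, g2_alphas, g1_scale_c)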
@property
def load_up_proj_weight_first(self) -> bool:
@@ -971,13 +1161,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
# This layer was processed with flashinfer TRTLLM - delegate to its own forward
return layer.forward(x, topk_output)
if self.enable_flashinfer_cutlass_moe:
assert (
not apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer"
# TRT-LLM CUTLASS MoE takes activations in BF16/Half/NVFP4 precision
# and FP4-quantized weights loaded from the checkpoint.
topk_weights, topk_ids, _ = topk_output
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
output = flashinfer_cutlass_fused_moe(
x,
topk_ids.to(torch.int),
@@ -1005,7 +1202,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
topk_weights, topk_ids, _ = topk_output
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
output = cutlass_moe_fp4(
a=x,
a1_gscale=layer.w13_input_scale_quant,
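
Both MoE paths now read topk_output by attribute instead of positional unpacking. A minimal sketch of the assumed interface (the real top-k output type is defined elsewhere in sglang and not shown in this diff; the third field name here is hypothetical):

from typing import NamedTuple
import torch

class TopKOutputSketch(NamedTuple):
    # Field names topk_weights/topk_ids match the attribute access in the diff;
    # the third field is a stand-in for whatever the old `_` was.
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor

out = TopKOutputSketch(
    topk_weights=torch.rand(8, 2),
    topk_ids=torch.randint(0, 4, (8, 2)),
    router_logits=torch.rand(8, 4),
)

# Old: topk_weights, topk_ids, _ = topk_output   (positional unpack)
# New: attribute access, robust to fields being added or reordered.
topk_weights, topk_ids = out.topk_weights, out.topk_ids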