FlashInfer NVFP4 MoE with EP & 2-stream shared expert (#7327)
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
Co-authored-by: alcanderian <alcanderian@gmail.com>
This commit is contained in:
@@ -29,11 +29,17 @@ from sglang.srt.layers.quantization.utils import (
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import is_cuda
|
||||
from sglang.srt.utils import is_cuda, next_power_of_2
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
try:
|
||||
from flashinfer import fp4_quantize as fp4_quantize
|
||||
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
|
||||
except ImportError:
|
||||
flashinfer_cutlass_fused_moe = None
|
||||
|
||||
# Initialize logger for the module
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -429,6 +435,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
|
||||
layer.alpha = Parameter(
|
||||
layer.input_scale * layer.weight_scale_2, requires_grad=False
|
||||
)
|
||||
layer.input_scale_inv = Parameter(
|
||||
(1 / input_scale_2).to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scales = layer.weight_scale
|
||||
@@ -467,7 +476,7 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
|
||||
output_shape = [x_m, w_n]
|
||||
|
||||
# Quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||
x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
|
||||
x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)
|
||||
|
||||
assert x_fp4.dtype == torch.uint8
|
||||
assert x_scale_interleaved.dtype == torch.float8_e4m3fn
|
||||
@@ -521,6 +530,7 @@ class ModelOptNvFp4FusedMoEMethod:
|
||||
" quantization. Please use Blackwell and"
|
||||
" above."
|
||||
)
|
||||
self.enable_flashinfer_moe = False
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -674,7 +684,10 @@ class ModelOptNvFp4FusedMoEMethod:
|
||||
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
|
||||
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
|
||||
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||
if self.enable_flashinfer_moe:
|
||||
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
|
||||
else:
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||
layer.g1_alphas = Parameter(
|
||||
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False,
|
||||
@@ -700,14 +713,19 @@ class ModelOptNvFp4FusedMoEMethod:
|
||||
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
|
||||
|
||||
# GEMM 2
|
||||
if self.enable_flashinfer_moe:
|
||||
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
|
||||
else:
|
||||
w2_input_scale = layer.w2_input_scale
|
||||
|
||||
layer.g2_alphas = Parameter(
|
||||
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
|
||||
(w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# This is for quantization, so we need to invert it.
|
||||
layer.w2_input_scale_quant = Parameter(
|
||||
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
|
||||
(1 / w2_input_scale).to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
assert (
|
||||
@@ -727,11 +745,16 @@ class ModelOptNvFp4FusedMoEMethod:
|
||||
layer.cutlass_moe_params = CutlassMoEParams(
|
||||
CutlassMoEType.BlockscaledFP4,
|
||||
device,
|
||||
num_experts=layer.num_experts,
|
||||
num_experts=layer.num_experts, # global num experts
|
||||
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n
|
||||
hidden_size=layer.w13_weight.shape[2] * 2,
|
||||
) # k
|
||||
|
||||
@property
|
||||
def load_up_proj_weight_first(self) -> bool:
|
||||
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
|
||||
return self.enable_flashinfer_moe
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -750,11 +773,13 @@ class ModelOptNvFp4FusedMoEMethod:
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
ep_rank: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
@@ -771,6 +796,35 @@ class ModelOptNvFp4FusedMoEMethod:
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if self.enable_flashinfer_moe:
|
||||
assert (
|
||||
not apply_router_weight_on_input
|
||||
), "apply_router_weight_on_input is not supported for Flashinfer"
|
||||
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
output = flashinfer_cutlass_fused_moe(
|
||||
x,
|
||||
topk_ids.to(torch.int),
|
||||
topk_weights,
|
||||
layer.w13_weight.view(torch.long),
|
||||
layer.w2_weight.view(torch.long),
|
||||
x.dtype,
|
||||
quant_scales=[
|
||||
layer.w13_input_scale_quant,
|
||||
layer.w13_blockscale_swizzled.view(torch.int32),
|
||||
layer.g1_alphas,
|
||||
layer.w2_input_scale_quant,
|
||||
layer.w2_blockscale_swizzled.view(torch.int32),
|
||||
layer.g2_alphas,
|
||||
],
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
tune_max_num_tokens=next_power_of_2(x.shape[0]),
|
||||
)
|
||||
return output[0]
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
|
||||
return cutlass_moe_fp4(
|
||||
|
||||
Reference in New Issue
Block a user