Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)

commit eff4eb3fdd (parent 87dab54824)
Trevor Morris, 2025-08-15 22:08:11 -07:00, committed by GitHub
16 changed files with 360 additions and 52 deletions
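Why quantizing before the collective helps: with DP attention, every rank's activations must be gathered across data-parallel ranks (via the TP group's all_gatherv) before the fused Cutlass MoE kernel runs. Gathering in bf16 moves 2 bytes per element; quantizing to nvfp4 first packs two values per byte plus one fp8 block scale per 16 elements, so the collective moves roughly a quarter of the data. A back-of-the-envelope sketch (the hidden size below is a hypothetical example, not taken from this commit):

    # Per-token bytes moved by the all-gather, bf16 vs. nvfp4.
    hidden = 7168  # hypothetical hidden size for illustration

    bf16_bytes = hidden * 2       # 2 bytes per bf16 element
    fp4_bytes = hidden // 2       # two fp4 values packed per uint8
    scale_bytes = hidden // 16    # one fp8 scale per 16-element block

    print(bf16_bytes / (fp4_bytes + scale_bytes))  # ~3.6x fewer bytes per token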


@@ -7,7 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
+from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
+from sglang.srt.layers.moe import (
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -1176,16 +1181,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
             topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+            output_dtype = x.dtype
+            x_sf = None
+            if should_use_flashinfer_cutlass_moe_fp4_allgather():
+                from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+
+                # Quantize before comm, swizzle after.
+                if x.shape[0] > 0:
+                    x, x_sf = fp4_quantize(
+                        x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
+                    )
+                else:
+                    x_col = x.shape[1]
+                    x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
+                    x_sf = torch.zeros(
+                        0, x_col // 16, dtype=torch.uint8, device=x.device
+                    )
+                topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
+                    [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
+                )
+                x_sf = nvfp4_block_scale_interleave(x_sf)
             output = flashinfer_cutlass_fused_moe(
-                x,
-                topk_ids.to(torch.int),
-                topk_weights,
-                layer.w13_weight.view(torch.long),
-                layer.w2_weight.view(torch.long),
-                x.dtype,
+                input=x,
+                token_selected_experts=topk_ids.to(torch.int),
+                token_final_scales=topk_weights,
+                fc1_expert_weights=layer.w13_weight.view(torch.long),
+                fc2_expert_weights=layer.w2_weight.view(torch.long),
+                output_dtype=output_dtype,
+                input_sf=x_sf,
                 quant_scales=[
                     layer.w13_input_scale_quant,
                     layer.w13_blockscale_swizzled.view(torch.int32),
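The all_gatherv above is a size-aware all-gather: each DP rank contributes a different number of token rows, and every rank receives all ranks' rows concatenated in rank order. A minimal single-process illustration of those semantics (an assumed model of the collective, not SGLang's implementation); it also shows why the zero-token branch in the diff still has to produce empty fp4 tensors:

    import torch

    # Pretend three DP ranks hold 2, 0, and 3 token rows respectively
    # (the analogue of get_dp_global_num_tokens()).
    sizes = [2, 0, 3]
    per_rank = [torch.randn(n, 8) for n in sizes]

    # all_gatherv concatenates every rank's rows in rank order; a rank with
    # zero tokens still participates, hence the empty torch.zeros tensors above.
    gathered = torch.cat(per_rank, dim=0)
    assert gathered.shape[0] == sum(sizes)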
@@ -1202,6 +1228,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             )[0]
             if moe_runner_config.routed_scaling_factor is not None:
                 output *= moe_runner_config.routed_scaling_factor
+            if should_use_flashinfer_cutlass_moe_fp4_allgather():
+                output, global_output = get_local_dp_buffer(), output
+                get_tp_group().reduce_scatterv(
+                    global_output, output=output, sizes=get_dp_global_num_tokens()
+                )
             return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
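The epilogue mirrors the prologue: after the kernel produces (partial) outputs for every gathered token, reduce_scatterv sums the partial results across the group and hands each rank back only the token rows it owns, written into the preallocated local DP buffer. A single-process sketch of the assumed semantics:

    import torch

    sizes = [2, 0, 3]  # tokens owned by each DP rank
    world = len(sizes)

    # Each rank holds a partial result for *all* gathered tokens.
    partials = [torch.randn(sum(sizes), 8) for _ in range(world)]

    # reduce_scatterv: elementwise-sum the partials, then give rank r only
    # its own sizes[r] rows (the slice get_local_dp_buffer() receives).
    reduced = torch.stack(partials).sum(dim=0)
    per_rank_out = list(torch.split(reduced, sizes, dim=0))
    assert per_rank_out[2].shape[0] == sizes[2]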