Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)
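In data-parallel (DP) attention mode, each rank holds only its own tokens, and the activations must be all-gathered across ranks before the fused MoE runs. This change quantizes the activations to nvfp4 *before* the all-gather instead of after, so the collective carries 4-bit packed values plus FP8 block scales rather than full BF16 activations. A rough back-of-the-envelope sketch of the savings (the hidden size below is illustrative, not taken from this change):

```python
# Rough communication-volume estimate for all-gathering one token's
# activations, BF16 vs. packed nvfp4 (hypothetical hidden size).
hidden = 7168  # assumed hidden size, for illustration only

bf16_bytes = hidden * 2        # 2 bytes per BF16 element
fp4_bytes = hidden // 2        # two 4-bit values packed per uint8
scale_bytes = hidden // 16     # one FP8 scale per 16-element block
quantized_bytes = fp4_bytes + scale_bytes

print(f"BF16:  {bf16_bytes} B/token")                  # 14336 B/token
print(f"nvfp4: {quantized_bytes} B/token")             # 4032 B/token
print(f"ratio: {bf16_bytes / quantized_bytes:.2f}x")   # ~3.56x less traffic
```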
@@ -7,7 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
+from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
+from sglang.srt.layers.moe import (
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -1176,16 +1181,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         ), "apply_router_weight_on_input is not supported for Flashinfer"
         # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
         # and fp4 quantized weights loaded from the checkpoint
 
         topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
 
+        output_dtype = x.dtype
+        x_sf = None
+        if should_use_flashinfer_cutlass_moe_fp4_allgather():
+            from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+
+            # Quantize before comm, swizzle after.
+            if x.shape[0] > 0:
+                x, x_sf = fp4_quantize(
+                    x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
+                )
+            else:
+                x_col = x.shape[1]
+                x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
+                x_sf = torch.zeros(
+                    0, x_col // 16, dtype=torch.uint8, device=x.device
+                )
+            topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
+                [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
+            )
+            x_sf = nvfp4_block_scale_interleave(x_sf)
+
         output = flashinfer_cutlass_fused_moe(
-            x,
-            topk_ids.to(torch.int),
-            topk_weights,
-            layer.w13_weight.view(torch.long),
-            layer.w2_weight.view(torch.long),
-            x.dtype,
+            input=x,
+            token_selected_experts=topk_ids.to(torch.int),
+            token_final_scales=topk_weights,
+            fc1_expert_weights=layer.w13_weight.view(torch.long),
+            fc2_expert_weights=layer.w2_weight.view(torch.long),
+            output_dtype=output_dtype,
+            input_sf=x_sf,
             quant_scales=[
                 layer.w13_input_scale_quant,
                 layer.w13_blockscale_swizzled.view(torch.int32),
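Two details in the hunk above are worth spelling out. First, `fp4_quantize` is called with `is_sf_swizzled_layout=False`: the block scales stay in plain row-major layout so the per-rank scale tensors can simply be concatenated by `all_gatherv`, and the hardware-friendly swizzle (`nvfp4_block_scale_interleave`) is applied once to the gathered tensor. Second, ranks with zero local tokens still construct empty `(0, hidden // 2)` and `(0, hidden // 16)` uint8 tensors (two fp4 values packed per byte, one scale per 16-element block) so every rank participates in the collective with consistent shapes. A single-process sketch of that bookkeeping; `fake_quantize` and `all_gatherv_sim` are stand-ins for the real APIs, kept only to show the shape logic:

```python
import torch

def fake_quantize(x: torch.Tensor):
    # Placeholder for fp4_quantize: pack two 4-bit values per byte,
    # one block scale per 16 elements; values are zeroed, shapes are real.
    n, h = x.shape
    return (
        torch.zeros(n, h // 2, dtype=torch.uint8),
        torch.zeros(n, h // 16, dtype=torch.uint8),
    )

def all_gatherv_sim(parts):
    # all_gatherv concatenates variable-length row blocks across ranks.
    return torch.cat(parts, dim=0)

hidden = 64
per_rank_tokens = [3, 0, 5]  # ranks may hold different (even zero) tokens
packed, scales = [], []
for n in per_rank_tokens:
    p, s = fake_quantize(torch.zeros(n, hidden, dtype=torch.bfloat16))
    packed.append(p)
    scales.append(s)

x = all_gatherv_sim(packed)    # (8, 32)
x_sf = all_gatherv_sim(scales) # (8, 4) -- still row-major, so concatenation
                               # is valid; swizzling happens after the gather.
assert x.shape == (8, hidden // 2) and x_sf.shape == (8, hidden // 16)
```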
@@ -1202,6 +1228,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )[0]
         if moe_runner_config.routed_scaling_factor is not None:
             output *= moe_runner_config.routed_scaling_factor
+        if should_use_flashinfer_cutlass_moe_fp4_allgather():
+            output, global_output = get_local_dp_buffer(), output
+            get_tp_group().reduce_scatterv(
+                global_output, output=output, sizes=get_dp_global_num_tokens()
+            )
         return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
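The reduce-scatter at the end mirrors the all-gather at the start: after the MoE has produced outputs for every gathered token, `reduce_scatterv` sums contributions across the group and hands each rank back only the rows for its own tokens, using the same `sizes` list (`get_dp_global_num_tokens()`) to locate each rank's slice. A single-process sketch of just the slicing arithmetic (the cross-rank reduction is omitted; `scatter_slice` is a hypothetical helper, not sglang's API):

```python
import torch

def scatter_slice(global_output: torch.Tensor, sizes: list, rank: int) -> torch.Tensor:
    # A rank's slice starts at the sum of the token counts of the
    # ranks before it; `sizes` plays the role of get_dp_global_num_tokens().
    offset = sum(sizes[:rank])
    return global_output[offset : offset + sizes[rank]]

sizes = [3, 0, 5]                                       # tokens per DP rank
global_output = torch.arange(8, dtype=torch.float32).unsqueeze(1)

for rank in range(len(sizes)):
    local = scatter_slice(global_output, sizes, rank)
    assert local.shape[0] == sizes[rank]                # rank 1 gets an empty slice
```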