[7/N] MoE Refactor: implementation of the new framework (#9269)

Author: Cheng Wan
Date: 2025-09-05 21:09:09 -07:00
Committed by: GitHub
Parent: dbb1235d58
Commit: 3fa62da78c
34 changed files with 1727 additions and 432 deletions


@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
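
For orientation outside the diff: the refactor replaces raw (x, topk_output) arguments with dispatch/combine wrapper types. Below is a minimal sketch of what those wrappers carry, inferred only from how this diff uses them (dispatch_output.hidden_states, dispatch_output.topk_output, StandardCombineInput(hidden_states=...)); the real classes in sglang.srt.layers.moe.token_dispatcher may hold more fields:

from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class DispatchOutputSketch:  # hypothetical stand-in for StandardDispatchOutput
    hidden_states: torch.Tensor  # tokens routed to this rank's experts
    topk_output: Any  # per-token expert ids and routing weights


@dataclass
class CombineInputSketch:  # hypothetical stand-in for StandardCombineInput
    hidden_states: torch.Tensor  # expert outputs handed back to the combiner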
@@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
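
The parameter rename makes explicit that create_weights receives the tensor-parallel shard of the FFN width, not the model's full intermediate size. A worked sizing example with illustrative numbers (not taken from this diff), assuming the usual even split across TP ranks:

# Hypothetical sizing under tensor parallelism.
num_experts = 8
hidden_size = 4096
intermediate_size = 14336  # full FFN width of the model
tp_size = 4
intermediate_size_per_partition = intermediate_size // tp_size  # 3584 on each rank

# Shapes allocated below in create_weights:
w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)  # (8, 7168, 4096)
w2_shape = (num_experts, hidden_size, intermediate_size_per_partition)  # (8, 4096, 3584)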
@@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values

         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
-        # where the first intermediate_size rows are w1, the next are w3
-        intermediate_size = layer.w13_weight.shape[1] // 2
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                start += intermediate_size
+                start += intermediate_size_per_partition

         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
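
The loop above folds the two per-shard (w1/w3) scales into a single per-expert scale: each shard is dequantized with its original scale and requantized with the shared maximum. A self-contained sketch of that round trip in plain torch, with assumed semantics for per_tensor_dequantize (multiply by scale) and scaled_fp8_quant (divide by scale, clamp, cast):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3


def dequant(w_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # assumed semantics of per_tensor_dequantize
    return w_fp8.to(torch.float32) * scale


def requant(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # assumed semantics of scaled_fp8_quant
    return (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)


# Two shards quantized with their own per-shard scales...
w1, w3 = torch.randn(16, 32), torch.randn(16, 32)
s1, s3 = w1.abs().max() / FP8_MAX, w3.abs().max() / FP8_MAX
q1, q3 = requant(w1, s1), requant(w3, s3)

# ...get rescaled onto the shared per-expert maximum, as in the loop above.
s_max = torch.maximum(s1, s3)
q1, q3 = requant(dequant(q1, s1), s_max), requant(dequant(q3, s3), s_max)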
@@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
-            w1_scale=layer.w13_weight_scale,
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)


 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
@@ -1278,21 +1292,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
+        moe_runner_config = self.moe_runner_config

         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))

         if self.enable_flashinfer_cutlass_moe:
             assert (
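
The hasattr check above duck-types TRTLLM-prepared layers: when shuffled FP4 gemm weights are present (set during weight processing, per the comment), the layer's own forward implements the whole fused MoE, so apply only wraps its result. A condensed, hypothetical sketch of the guard; cutlass_path stands in for the CUTLASS branches that follow in this method:

def apply_sketch(self, layer, dispatch_output, cutlass_path):
    # cutlass_path: hypothetical callable covering the CUTLASS branches below.
    from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

    x = dispatch_output.hidden_states
    topk_output = dispatch_output.topk_output
    if hasattr(layer, "gemm1_weights_fp4_shuffled"):
        # TRTLLM path: the layer's forward already fuses routing + experts.
        return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
    # Otherwise fall through to the flashinfer CUTLASS / cutlass_moe_fp4 paths.
    return StandardCombineInput(hidden_states=cutlass_path(x, topk_output))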
@@ -1345,13 +1370,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
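
In the all-gather variant above, each rank computed expert outputs for every DP rank's tokens, so reduce_scatterv sums partial results and hands each rank back only its own slice. A pure-torch, single-process sketch of the assumed collective semantics (reduce_scatterv_sketch and its arguments are illustrative, not the sglang API):

import torch


def reduce_scatterv_sketch(
    per_rank_partials: list,  # one partial result tensor per rank, each (sum(sizes), hidden)
    sizes: list,  # tokens owned by each DP rank, cf. get_dp_global_num_tokens()
    rank: int,
) -> torch.Tensor:
    summed = torch.stack(per_rank_partials).sum(dim=0)  # the "reduce"
    start = sum(sizes[:rank])
    return summed[start : start + sizes[rank]]  # the "scatterv": keep this rank's tokens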
@@ -1372,4 +1396,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)

-        return output
+        # Scale by routed_scaling_factor is fused into select_experts.
+        return StandardCombineInput(hidden_states=output)