[7/N] MoE Refactor: implementation of the new framework (#9269)

Author: Cheng Wan
Date: 2025-09-05 21:09:09 -07:00
Committed by: GitHub
Parent: dbb1235d58
Commit: 3fa62da78c
34 changed files with 1727 additions and 432 deletions
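
The hunks below move Mxfp4MoEMethod and Mxfp4DynamicQuantMoEMethod onto the new MoE runner framework: a quant method now exposes create_moe_runner() and an apply() that consumes a StandardDispatchOutput and returns a CombineInput, instead of taking (x, topk_output, moe_runner_config) and returning a tensor. Below is a minimal sketch of that contract using only names visible in this diff; the class name SketchMoEMethod and the placeholder expert computation are illustrative, not part of the change.

```python
import torch

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.token_dispatcher import (
    StandardCombineInput,
    StandardDispatchOutput,
)


class SketchMoEMethod:
    """Illustrative quant method following the refactored contract."""

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Called once; the config no longer travels through apply().
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(
        self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput
    ) -> StandardCombineInput:
        # Inputs that used to be separate arguments are unpacked here.
        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        output = x  # placeholder for the backend-specific expert computation
        return StandardCombineInput(hidden_states=output)
```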


@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional
import torch
from torch.nn.parameter import Parameter
+ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
@@ -59,8 +61,10 @@ if is_flashinfer_available():
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
- from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
- from sglang.srt.layers.moe.topk import TopKOutput
+ from sglang.srt.layers.moe.token_dispatcher import (
+ CombineInput,
+ StandardDispatchOutput,
+ )
_is_hip = is_hip()
@@ -283,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
- intermediate_size: int,
+ intermediate_size_per_partition: int,
params_dtype: torch.dtype,
with_bias: bool = False,
**extra_weight_attrs,
@@ -296,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# pad the intermediate size to be a multiple of 2 * mxfp4_block
# for to hold non-uniform sharded tensor as well as swizzling
- intermediate_size_per_partition_after_pad = intermediate_size
+ intermediate_size_per_partition_after_pad = intermediate_size_per_partition
if _is_sm100_supported:
if self.use_flashinfer:
intermediate_size_per_partition_after_pad = round_up(
- intermediate_size, 256
+ intermediate_size_per_partition, 256
)
hidden_size = round_up(hidden_size, 256)
else:
intermediate_size_per_partition_after_pad = round_up(
- intermediate_size, 64
+ intermediate_size_per_partition, 64
)
elif has_triton_kernels:
# TODO: this is a hack to make
# intermediate_size_per_partition_after_pad the same as the
# per_rank_intermediate_size during weight loading
intermediate_size_per_partition_after_pad = round_up(
- intermediate_size, mxfp4_block
+ intermediate_size_per_partition, mxfp4_block
)
- self.intermediate_size = intermediate_size_per_partition_after_pad
+ self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
# Fused gate_up_proj (column parallel)
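
For orientation, a small worked example of the padding above. This is a sketch: the per-rank size 1440 and mxfp4_block = 32 are illustrative assumptions, not values from this diff, and round_up is re-defined locally as a stand-in for the helper the code calls.

```python
def round_up(x: int, m: int) -> int:
    # Local stand-in for the round_up used above: smallest multiple of m >= x.
    return (x + m - 1) // m * m

mxfp4_block = 32  # assumed MXFP4 scale-block size
intermediate_size_per_partition = 1440  # example per-rank (sharded) size

print(round_up(intermediate_size_per_partition, 256))          # 1536: flashinfer on SM100
print(round_up(intermediate_size_per_partition, 64))           # 1472: SM100 without flashinfer
print(round_up(intermediate_size_per_partition, mxfp4_block))  # 1440: triton-kernels path
```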
@@ -410,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
assert (
layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
- and layer.w13_weight.shape[1] == self.intermediate_size * 2
+ and layer.w13_weight.shape[1]
+ == self.intermediate_size_per_partition * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2
)
assert (
layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
- and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
+ and layer.w13_weight_scale.shape[1]
+ == self.intermediate_size_per_partition * 2
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
)
assert (
layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size
- and layer.w2_weight.shape[2] == self.intermediate_size // 2
+ and layer.w2_weight.shape[2]
+ == self.intermediate_size_per_partition // 2
)
assert (
layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
- == self.intermediate_size // sf_block_size
+ == self.intermediate_size_per_partition // sf_block_size
)
assert (
layer.w13_weight_bias.dim() == 2
and layer.w13_weight_bias.shape[0] == self.num_experts
- and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
+ and layer.w13_weight_bias.shape[1]
+ == self.intermediate_size_per_partition * 2
)
assert (
layer.w2_weight_bias.dim() == 2
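
The assertions above encode the expected mxfp4 tensor layouts. A sketch of the same shapes with example sizes; the numeric values are illustrative, and sf_block_size = 32 is the usual MXFP4 scale-block size rather than a value shown in this hunk.

```python
num_experts = 8                         # example values only
hidden_size = 2880                      # already padded, per the code above
intermediate_size_per_partition = 2880  # padded per-rank size
sf_block_size = 32

# fp4 packs two values per byte, hence the trailing "// 2" on weight shapes;
# scales carry one entry per sf_block_size elements of the reduction dim.
w13_weight = (num_experts, 2 * intermediate_size_per_partition, hidden_size // 2)
w13_scale  = (num_experts, 2 * intermediate_size_per_partition, hidden_size // sf_block_size)
w2_weight  = (num_experts, hidden_size, intermediate_size_per_partition // 2)
w2_scale   = (num_experts, hidden_size, intermediate_size_per_partition // sf_block_size)
w13_bias   = (num_experts, 2 * intermediate_size_per_partition)
```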
@@ -511,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
torch.stack(gemm1_scales_mxfp4_shuffled)
.reshape(
self.num_experts,
- 2 * self.intermediate_size,
+ 2 * self.intermediate_size_per_partition,
self.hidden_size // sf_block_size,
)
.view(torch.float8_e4m3fn)
@@ -523,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
.reshape(
self.num_experts,
self.hidden_size,
- self.intermediate_size // sf_block_size,
+ self.intermediate_size_per_partition // sf_block_size,
)
.view(torch.float8_e4m3fn)
)
@@ -613,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return tile_tokens_dim
+ def create_moe_runner(
+ self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+ ):
+ self.moe_runner_config = moe_runner_config
+ self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
- x: torch.Tensor,
- topk_output: TopKOutput,
- moe_runner_config: MoeRunnerConfig,
- ) -> torch.Tensor:
+ dispatch_output: StandardDispatchOutput,
+ ) -> CombineInput:
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+ from sglang.srt.layers.moe.topk import TopKOutputChecker
+ x = dispatch_output.hidden_states
+ topk_output = dispatch_output.topk_output
+ moe_runner_config = self.moe_runner_config
if self.use_flashinfer:
# When bf16 mode is enabled, we don't need to quantize the input,
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
@@ -674,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
top_k,
None, # n_group # TODO: support n_group
None, # topk_group # TODO: support topk_group
- self.intermediate_size, # padded to multiple of 256
+ self.intermediate_size_per_partition, # padded to multiple of 256
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
layer.num_local_experts, # local num experts
None,
@@ -682,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
1, # routing_method_type, renormalize
True, # do finalize
)[0]
- return trtllm_gen_output
+ return StandardCombineInput(hidden_states=trtllm_gen_output)
if self.use_triton_kernels:
assert (
layer.moe_ep_size == 1
), "Expert parallel is not supported when using triton kernels"
if self.with_bias:
- return self.triton_kernel_moe_with_bias_forward(
+ output = self.triton_kernel_moe_with_bias_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w1_pcg=self.w13_precision_config,
@@ -701,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
moe_runner_config=moe_runner_config,
)
else:
- return self.triton_kernel_moe_forward(
+ output = self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
)
+ return StandardCombineInput(hidden_states=output)
else:
- from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
- return fused_experts(
- hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk_output=topk_output,
- moe_runner_config=moe_runner_config,
- b1=layer.w13_weight_bias,
- b2=layer.w2_weight_bias,
+ quant_info = TritonMoeQuantInfo(
+ w13_weight=layer.w13_weight,
+ w2_weight=layer.w2_weight,
+ w13_weight_bias=layer.w13_weight_bias,
+ w2_weight_bias=layer.w2_weight_bias,
)
+ return self.runner.run(dispatch_output, quant_info)
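
In the fallback path, the direct fused_experts call is replaced by packaging the weights into a TritonMoeQuantInfo and letting the shared runner execute them. A rough old-to-new mapping, as a sketch; it assumes the Triton MoeRunner backend ultimately reaches the same fused kernel.

```python
# Mapping from the removed fused_experts call to the new runner path (sketch):
#
#   fused_experts kwarg        ->  new location
#   -------------------------------------------------------------------------
#   hidden_states=x            ->  dispatch_output.hidden_states
#   topk_output=topk_output    ->  dispatch_output.topk_output
#   moe_runner_config=...      ->  stored by create_moe_runner(); held by MoeRunner
#   w1 / b1                    ->  TritonMoeQuantInfo.w13_weight / .w13_weight_bias
#   w2 / b2                    ->  TritonMoeQuantInfo.w2_weight / .w2_weight_bias
#
# apply() then reduces to: self.runner.run(dispatch_output, quant_info)
```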
class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
@@ -798,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
return w, mx_scales
- def process_weights_after_loading(self, layer: Module) -> None:
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
@@ -808,19 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
+ def create_moe_runner(
+ self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+ ):
+ self.moe_runner_config = moe_runner_config
def apply(
self,
layer: torch.nn.Module,
- x: torch.Tensor,
- topk_output: TopKOutput,
- moe_runner_config: MoeRunnerConfig,
- ) -> torch.Tensor:
+ dispatch_output: StandardDispatchOutput,
+ ) -> CombineInput:
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+ x = dispatch_output.hidden_states
+ topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
if _is_hip:
topk_weights = topk_weights.to(
torch.float32
) # aiter's moe_sorting requires topk_weights to be FP32
- return fused_moe(
+ output = fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
@@ -831,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
w2_scale=layer.w2_weight_scale,
activation=(
ActivationType.Silu
- if moe_runner_config.activation == "silu"
+ if self.moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
doweight_stage1=False,
)
+ return StandardCombineInput(hidden_states=output)
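
Taken together, a hedged sketch of how a MoE layer drives the refactored interface. The dispatcher object and its dispatch()/combine() method names are assumptions for illustration; create_moe_runner() and apply() come from this diff.

```python
from sglang.srt.layers.moe import MoeRunnerConfig


def setup_sketch(layer, quant_method, moe_runner_config: MoeRunnerConfig):
    # One-time setup: Mxfp4DynamicQuantMoEMethod only records the config here
    # (it calls aiter's fused_moe directly), whereas Mxfp4MoEMethod also
    # constructs a Triton MoeRunner.
    quant_method.create_moe_runner(layer, moe_runner_config)


def moe_forward_sketch(layer, quant_method, dispatcher, hidden_states, topk_output):
    # Per forward pass: dispatch produces a StandardDispatchOutput, the quant
    # method does the expert math, and the dispatcher combines the result.
    dispatch_output = dispatcher.dispatch(hidden_states, topk_output)
    combine_input = quant_method.apply(layer, dispatch_output)
    return dispatcher.combine(combine_input)
```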