[7/N] MoE Refactor: the implementation of the new framework (#9269)

Author: Cheng Wan, 2025-09-05 21:09:09 -07:00 (committed by GitHub)
parent dbb1235d58
commit 3fa62da78c
34 changed files with 1727 additions and 432 deletions


@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
@@ -22,7 +23,10 @@ from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
ACTIVATION_SCHEMES = ["static", "dynamic"]
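For reference, a minimal sketch of the dispatcher types as they are used later in this file. These are simplified stand-ins inferred from this diff only, not the actual definitions in sglang.srt.layers.moe.token_dispatcher; in particular, the third field of the top-k output is an assumption (it is unpacked as `_` below).

    from typing import NamedTuple
    import torch

    class StandardTopKOutput(NamedTuple):
        # unpacked below as `topk_weights, topk_ids, _ = topk_output`;
        # the third field is assumed to be the router logits
        topk_weights: torch.Tensor
        topk_ids: torch.Tensor
        router_logits: torch.Tensor

    class StandardDispatchOutput(NamedTuple):
        hidden_states: torch.Tensor      # read as dispatch_output.hidden_states
        topk_output: StandardTopKOutput  # read as dispatch_output.topk_output

    class StandardCombineInput(NamedTuple):
        hidden_states: torch.Tensor      # built as StandardCombineInput(hidden_states=output)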
@@ -133,7 +137,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer: EPMoE,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
@@ -145,7 +149,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size * 2,
intermediate_size_per_partition * 2,
hidden_size // 2,
dtype=torch.int8,
),
@@ -159,7 +163,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
torch.empty(
num_experts,
hidden_size,
intermediate_size // 2,
intermediate_size_per_partition // 2,
dtype=torch.int8,
),
requires_grad=False,
@@ -173,7 +177,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size,
2 * intermediate_size_per_partition,
hidden_size // self.quant_config.group_size,
dtype=torch.float32,
),
@@ -186,7 +190,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
torch.zeros(
num_experts,
hidden_size,
intermediate_size // self.quant_config.group_size,
intermediate_size_per_partition // self.quant_config.group_size,
dtype=torch.float32,
),
requires_grad=False,
@@ -220,13 +224,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
)
self.c_strides1 = torch.full(
(num_experts, 3),
2 * intermediate_size,
2 * intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
self.a_strides2 = torch.full(
(num_experts, 3),
intermediate_size,
intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
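The `intermediate_size` to `intermediate_size_per_partition` rename makes explicit that every weight, scale, and stride shape in these hunks is the per-rank shard of the expert's intermediate dimension, not the full model dimension. A hedged sketch of the usual derivation follows; the helper and the divisor name are illustrative assumptions, since the actual value comes from the MoE parallel configuration rather than from this diff.

    def get_intermediate_size_per_partition(intermediate_size: int, moe_tp_size: int) -> int:
        # illustrative helper, not sglang API: each tensor-parallel rank holds a
        # 1/moe_tp_size slice of the expert's intermediate dimension
        assert intermediate_size % moe_tp_size == 0, "intermediate_size must divide evenly"
        return intermediate_size // moe_tp_size

    # e.g. a 2048-wide expert sharded over 4 ranks stores 512 columns per rank
    assert get_intermediate_size_per_partition(2048, 4) == 512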
@@ -282,16 +286,21 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
)
layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: EPMoE,
x: torch.Tensor,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
# TODO(ch-wan): move it out of this class
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
local_topk_ids = topk_ids
@@ -328,6 +337,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale,
layer.w2_input_scale,
)
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output
if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.moe_runner_config.routed_scaling_factor
return StandardCombineInput(hidden_states=output)
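Taken together, the new contract for a quantized MoE method looks roughly like the sketch below: the runner config is captured once via create_moe_runner(), and apply() now consumes a dispatch output and returns a combine input instead of a bare tensor. The toy class and the kernel stand-in are illustrative only; the real apply() calls cutlass_w4a8_moe with the layer's quantized weights as shown in the hunks above.

    from typing import NamedTuple, Optional
    import torch

    class MoeRunnerConfig(NamedTuple):                 # reduced to the one field used here
        routed_scaling_factor: Optional[float] = None

    class StandardDispatchOutput(NamedTuple):
        hidden_states: torch.Tensor
        topk_output: tuple                             # (topk_weights, topk_ids, _)

    class StandardCombineInput(NamedTuple):
        hidden_states: torch.Tensor

    class ToyW4AFp8MoEMethod:                          # illustrative stand-in, not the real class
        def create_moe_runner(self, layer, moe_runner_config: MoeRunnerConfig):
            # called once at setup time; apply() reads the stored config later
            self.moe_runner_config = moe_runner_config

        def apply(self, layer, dispatch_output: StandardDispatchOutput) -> StandardCombineInput:
            x = dispatch_output.hidden_states
            topk_weights, topk_ids, _ = dispatch_output.topk_output
            output = x                                 # stand-in for the cutlass_w4a8_moe kernel call
            if self.moe_runner_config.routed_scaling_factor is not None:
                output = output * self.moe_runner_config.routed_scaling_factor
            return StandardCombineInput(hidden_states=output)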