[7/N] MoE Refactor: implementation of the new framework (#9269)

Author: Cheng Wan
Date: 2025-09-05 21:09:09 -07:00
Committed by: GitHub
Parent: dbb1235d58
Commit: 3fa62da78c
34 changed files with 1727 additions and 432 deletions


@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda, is_hip
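
For orientation, here is a minimal sketch of the dispatcher-facing types referenced above. The real definitions live in sglang.srt.layers.moe.token_dispatcher (and sglang.srt.layers.moe.topk); only the attributes this diff actually touches (hidden_states, topk_output, and the topk_weights/topk_ids/router_logits unpacking) are assumed here, so the actual classes may carry more state.

```python
# Minimal sketch, not the actual sglang definitions: only the attributes used
# in the diff below are modeled.
from dataclasses import dataclass
from typing import NamedTuple

import torch


class StandardTopKOutput(NamedTuple):
    topk_weights: torch.Tensor   # routing weights per token, shape [num_tokens, top_k]
    topk_ids: torch.Tensor       # selected expert indices, shape [num_tokens, top_k]
    router_logits: torch.Tensor  # raw router scores


@dataclass
class StandardDispatchOutput:
    hidden_states: torch.Tensor      # tokens handed to the quantized experts
    topk_output: StandardTopKOutput  # routing decision produced by the TopK layer


@dataclass
class StandardCombineInput:
    hidden_states: torch.Tensor      # expert outputs to be combined by the dispatcher
```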
@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         orig_dtype = x.dtype
         x = x.half()
 
         topk_weights, topk_ids, router_logits = topk_output
 
-        return fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
             w2_zeros=layer.w2_qzeros,
             num_bits=self.quant_config.weight_bits,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
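
Putting the hunks together, the refactored contract for an MoE quantization method looks roughly like the sketch below: create_moe_runner stores the MoeRunnerConfig once, and apply consumes a StandardDispatchOutput and returns a StandardCombineInput instead of a bare tensor. MyQuantMoEMethod and my_quantized_moe_kernel are hypothetical placeholders standing in for AWQMoEMethod and fused_marlin_moe; the control flow mirrors the diff above.

```python
# Schematic of the new contract, assuming only what the diff shows.
# MyQuantMoEMethod / my_quantized_moe_kernel are hypothetical stand-ins for
# AWQMoEMethod / fused_marlin_moe; the dispatcher types come from sglang.
import torch


def my_quantized_moe_kernel(x, layer, topk_weights, topk_ids, router_logits):
    # Placeholder for a backend-specific fused kernel such as fused_marlin_moe.
    return x


class MyQuantMoEMethod:
    def create_moe_runner(self, layer: torch.nn.Module, moe_runner_config):
        # Runner configuration (e.g. the activation) is captured once here
        # instead of being passed into every apply() call.
        self.moe_runner_config = moe_runner_config

    def apply(self, layer: torch.nn.Module, dispatch_output):
        # Imported inside apply(), mirroring the diff (likely to avoid an import cycle).
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert self.moe_runner_config.activation == "silu"

        # apply() now unpacks the dispatcher output rather than receiving
        # x and topk_output as separate arguments.
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, router_logits = dispatch_output.topk_output

        output = my_quantized_moe_kernel(x, layer, topk_weights, topk_ids, router_logits)

        # The result is wrapped so the token dispatcher can run its combine step.
        return StandardCombineInput(hidden_states=output)
```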