[8/N] MoE Refactor: deprecate EPMoE (#11211)

Author: Cheng Wan
Date: 2025-10-07 21:51:41 -07:00
Committed by: GitHub
Parent: 7c3f07dbcb
Commit: 3c06b673af
19 changed files with 526 additions and 1808 deletions


@@ -31,8 +31,8 @@ except ImportError:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
from sglang.srt.layers.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
from sglang.srt.layers.moe.utils import (
get_moe_a2a_backend,
get_moe_runner_backend,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
moe_runner_backend = get_moe_runner_backend()
if moe_runner_backend.is_auto():
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and get_moe_a2a_backend().is_deepep()
):
moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
else:
moe_runner_backend = MoeRunnerBackend.TRITON
if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
else:
# TODO(cwan): refactor other backends
pass
def apply(
self,
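
For context, the hunk above changes create_moe_runner from always constructing a Triton MoeRunner to resolving the backend at layer-construction time: an explicit backend setting wins, and "auto" picks DeepGEMM only when JIT DeepGEMM is enabled and the all-to-all backend is DeepEP, otherwise Triton. A minimal standalone sketch of that decision rule follows; the toy Backend enum and resolve_backend helper are invented for illustration and only mirror the logic of the real MoeRunnerBackend helpers.

# Standalone sketch of the backend-selection rule shown above (illustrative only).
from enum import Enum

class Backend(Enum):
    AUTO = "auto"
    TRITON = "triton"
    DEEP_GEMM = "deep_gemm"

def resolve_backend(requested: Backend, jit_deepgemm_enabled: bool, a2a_is_deepep: bool) -> Backend:
    # An explicitly requested backend is used as-is.
    if requested is not Backend.AUTO:
        return requested
    # AUTO prefers DeepGEMM only when JIT DeepGEMM is available and the
    # all-to-all dispatcher is DeepEP; otherwise it falls back to Triton.
    if jit_deepgemm_enabled and a2a_is_deepep:
        return Backend.DEEP_GEMM
    return Backend.TRITON

assert resolve_backend(Backend.AUTO, True, True) is Backend.DEEP_GEMM
assert resolve_backend(Backend.AUTO, True, False) is Backend.TRITON
assert resolve_backend(Backend.TRITON, True, True) is Backend.TRITON
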
@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
return StandardCombineInput(hidden_states=output)
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w13_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
if self.runner.runner_backend.is_deep_gemm():
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
if self.block_quant:
block_shape = self.quant_config.weight_block_size
w13_scale = layer.w13_weight_scale_inv
w2_scale = layer.w2_weight_scale_inv
else:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
scale_block_size = 128
block_shape = [scale_block_size, scale_block_size]
w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
w13_scale = (
layer.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_scale_k, dim=2)
)
w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
w2_scale = (
layer.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_scale_k, dim=2)
)
quant_info = DeepGemmMoeQuantInfo(
w13_weight=w13_weight,
w2_weight=w2_weight,
use_fp8=True,
w13_scale=w13_scale,
w2_scale=w2_scale,
block_shape=block_shape,
)
elif self.runner.runner_backend.is_triton():
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w13_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
else:
raise NotImplementedError(
"Unsupported runner backend: %s" % self.runner.runner_backend
)
return self.runner.run(dispatch_output, quant_info)
def apply_with_router_logits(

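When the DeepGEMM path is taken but the checkpoint only carries per-tensor FP8 scales, the hunk above broadcasts each expert's single scale into a per-block grid of shape [num_experts, ceil(N/128), ceil(K/128)] so the DeepGEMM kernel sees block-quantized weights. A minimal sketch of that expansion with toy shapes follows; the sizes and scale values are invented for illustration.

import torch

# Toy illustration of the per-tensor -> per-block scale expansion above.
# Real weights are FP8 MoE parameters of shape [num_experts, N, K].
num_experts, n, k = 2, 256, 384
block = 128
w13_weight = torch.empty(num_experts, n, k)
per_tensor_scale = torch.tensor([0.5, 0.25])       # one scale per expert

n_blocks = (w13_weight.shape[1] - 1) // block + 1  # ceil(N / block) = 2
k_blocks = (w13_weight.shape[2] - 1) // block + 1  # ceil(K / block) = 3
per_block_scale = (
    per_tensor_scale.unsqueeze(1)
    .repeat_interleave(n_blocks, dim=1)
    .unsqueeze(2)
    .repeat_interleave(k_blocks, dim=2)
)
print(per_block_scale.shape)  # torch.Size([2, 2, 3]); every block repeats its expert's scale
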

@@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
@@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.managers.schedule_batch import global_server_args_dict
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
@@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def create_weights(
self,
layer: EPMoE,
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
@@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: EPMoE,
layer: Module,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
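
The next hunk drops the EPMoE-specific local_topk_ids remapping and the expert-range arguments from the cutlass_w4a8_moe call. For reference, the removed remapping replaced -1 entries in topk_ids (entries the old EP path had masked out, e.g. experts not held on the local rank) with num_experts as an out-of-range sentinel. A toy illustration follows; the values are invented.

import torch

# Toy illustration of the remapping that the hunk below removes.
# -1 marks a masked-out expert; the old code mapped it to num_experts,
# an out-of-range sentinel index.
num_experts = 8
topk_ids = torch.tensor([[0, 3, -1], [-1, 5, 7]])
local_topk_ids = torch.where(topk_ids == -1, num_experts, topk_ids)
print(local_topk_ids)  # tensor([[0, 3, 8], [8, 5, 7]])
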
@@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
local_topk_ids = topk_ids
if get_moe_expert_parallel_world_size() > 1:
local_topk_ids = torch.where(
topk_ids == -1,
layer.num_experts,
topk_ids,
)
output = cutlass_w4a8_moe(
layer.start_expert_id,
layer.end_expert_id,
layer.num_experts,
x,
layer.w13_weight,
layer.w2_weight,
@@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv,
topk_weights,
topk_ids,
local_topk_ids,
self.a_strides1,
self.b_strides1,
self.c_strides1,