[7/N] MoE Refactor: implementation of the new framework (#9269)

Author: Cheng Wan
Date: 2025-09-05 21:09:09 -07:00
Committed by: GitHub
Parent: dbb1235d58
Commit: 3fa62da78c

34 changed files with 1727 additions and 432 deletions


@@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle
 from sglang.srt.layers.moe import MoeRunnerConfig
+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
 logger = logging.getLogger(__name__)
 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
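
Note on the new imports: everything added under `if TYPE_CHECKING:` is imported only for static analysis, never at runtime, which is the standard way to reference names needed purely as annotations (here `CombineInput`, `StandardDispatchOutput`, and `QuarkConfig`) without risking circular imports. A minimal self-contained sketch of the idiom; the module and class names below are hypothetical, not SGLang's:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so it cannot
    # introduce an import cycle.
    from heavy_module import HeavyConfig  # hypothetical module and name


class Method:
    def __init__(self, config: HeavyConfig) -> None:
        # With `from __future__ import annotations` the annotation stays a
        # string at runtime, so the guarded import is never needed live.
        self.config = config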
@@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 
 OCP_MX_BLOCK_SIZE = 32
 
-if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
-    from sglang.srt.layers.quantization import QuarkConfig
-
-
-class QuarkMoEMethod:
-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
+
+class QuarkMoEMethod(FusedMoEMethodBase):
 
     def __init__(self, quant_config: QuarkConfig):
         self.quant_config = quant_config
 
     @staticmethod
     def get_moe_method(
-        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
+        quant_config: QuarkConfig,  # type: ignore # noqa E501 # noqa F821
         module: torch.nn.Module,
         layer_name: str,
     ) -> "QuarkMoEMethod":
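
The block deleted above was a bootstrap hack: because `FusedMoEMethodBase` could not be imported at module scope, `__new__` rebuilt the class on the fly with `type(...)`, grafting the base class in at instantiation time. With the base now imported at the top of the file, plain inheritance expresses the same relationship. A simplified, self-contained sketch of the before/after shapes (classes abbreviated, not the exact SGLang code; the `__weakref__` filter is added here so the standalone example runs cleanly):

class FusedMoEMethodBase:
    """Stand-in for the real abstract base class."""


class OldMethod:
    # Before: dynamically rebuild the class with the base grafted in.
    def __new__(cls, *args, **kwargs):
        members = {
            k: v
            for k, v in cls.__dict__.items()
            if k not in ("__dict__", "__weakref__")
        }
        new_cls = type(cls.__name__, (FusedMoEMethodBase,), members)
        obj = super(new_cls, new_cls).__new__(new_cls)
        # __init__ is not invoked automatically because the returned object
        # is not an instance of OldMethod, so call it by hand.
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, quant_config):
        self.quant_config = quant_config


class NewMethod(FusedMoEMethodBase):
    # After: ordinary inheritance, now that the base is importable.
    def __init__(self, quant_config):
        self.quant_config = quant_config


# Both end up as FusedMoEMethodBase instances:
assert isinstance(OldMethod("cfg"), FusedMoEMethodBase)
assert isinstance(NewMethod("cfg"), FusedMoEMethodBase)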
@@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
         layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
         topk_weights, topk_ids, _ = topk_output
 
-        return fused_moe(
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
),
doweight_stage1=False,
)
return StandardCombineInput(hidden_states=output)
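
Taken together, the last two hunks change the per-layer MoE contract: the `MoeRunnerConfig` is bound once through `create_moe_runner`, and `apply` now consumes the dispatcher's output bundle and returns a combine input rather than a bare tensor. A minimal sketch of the call sequence this implies, using stand-in dataclasses rather than the real `StandardDispatchOutput`/`StandardCombineInput` types (their constructors are not shown in this diff):

from dataclasses import dataclass
from typing import Any


@dataclass
class DispatchOutput:  # stand-in for StandardDispatchOutput
    hidden_states: Any
    topk_output: Any


@dataclass
class CombineInput:  # stand-in for StandardCombineInput
    hidden_states: Any


class MoEMethod:
    def create_moe_runner(self, layer: Any, moe_runner_config: Any) -> None:
        # One-time setup: the config is cached on the method instead of
        # being threaded through every apply() call.
        self.moe_runner_config = moe_runner_config

    def apply(self, layer: Any, dispatch_output: DispatchOutput) -> CombineInput:
        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        # ... expert kernels (e.g. fused_moe) would run here ...
        return CombineInput(hidden_states=x)


# Call sequence implied by the diff:
method = MoEMethod()
method.create_moe_runner(layer=None, moe_runner_config=None)
out = method.apply(None, DispatchOutput(hidden_states="h", topk_output="t"))
assert out.hidden_states == "h"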