[7/N] MoE Refactor: implementation of the new framework (#9269)
@@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle
 
+from sglang.srt.layers.moe import MoeRunnerConfig
+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
 logger = logging.getLogger(__name__)
 
 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
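The imports added above sit under `if TYPE_CHECKING:` so they are visible to type checkers but never executed at runtime, which keeps this module from importing the token dispatcher eagerly and avoids a circular import. A minimal sketch of the pattern, where `cyclic_dep` is a hypothetical stand-in module:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright), never at runtime,
    # so this import cannot create a runtime import cycle.
    from cyclic_dep import SomeType  # hypothetical module and type

def process(value: "SomeType") -> None:
    # The string annotation is resolved lazily; SomeType need not exist at runtime.
    print(value)
```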
@@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 
 OCP_MX_BLOCK_SIZE = 32
 
-if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
-    from sglang.srt.layers.quantization import QuarkConfig
-
-
-class QuarkMoEMethod:
-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
+class QuarkMoEMethod(FusedMoEMethodBase):
 
     def __init__(self, quant_config: QuarkConfig):
         self.quant_config = quant_config
 
     @staticmethod
     def get_moe_method(
-        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
+        quant_config: QuarkConfig,  # type: ignore # noqa E501 # noqa F821
        module: torch.nn.Module,
        layer_name: str,
    ) -> "QuarkMoEMethod":
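For context on the block removed above: the old `__new__` hook rebuilt the class at instantiation time with `type()`, splicing `FusedMoEMethodBase` in as a base so this module never had to import it at top level. With the base class now importable at module scope (see the import hunk above), plain inheritance does the same job. A self-contained toy sketch of both styles; `Base`, `OldStyle`, and `NewStyle` are illustrative names, not sglang's classes:

```python
class Base:
    def run(self):
        return "base behavior"

class OldStyle:
    # The removed pattern: re-parent the class onto Base the first time it is
    # instantiated, by building a fresh class with type().
    def __new__(cls, *args, **kwargs):
        new_cls = type(
            cls.__name__,
            (Base,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        # super() skips new_cls.__new__ (copied from this class), avoiding recursion.
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)  # not called automatically: obj is not an OldStyle
        return obj

    def __init__(self, cfg):
        self.cfg = cfg

class NewStyle(Base):
    # The replacement: direct inheritance, now that Base is importable.
    def __init__(self, cfg):
        self.cfg = cfg

old, new = OldStyle("cfg"), NewStyle("cfg")
assert isinstance(old, Base) and isinstance(new, Base)
assert old.run() == new.run() == "base behavior"
```

The refactor also drops the `_initialized` guard, since ordinary inheritance has no first-instantiation special case to track.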
@@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
         layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
         topk_weights, topk_ids, _ = topk_output
 
-        return fused_moe(
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
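Taken together, the hunks above change the method's call contract: `create_moe_runner` stores the `MoeRunnerConfig` once at setup, and `apply` now consumes a `StandardDispatchOutput` bundle and returns a `CombineInput` instead of taking loose tensors and returning one. A hedged, self-contained sketch of the new flow; the `NamedTuple` stand-ins and the toy kernel are illustrative, not sglang's real implementations:

```python
from typing import NamedTuple
import torch

class StandardDispatchOutput(NamedTuple):  # stand-in for sglang's dispatcher type
    hidden_states: torch.Tensor
    topk_output: tuple  # (topk_weights, topk_ids, router_logits)

class StandardCombineInput(NamedTuple):  # stand-in for sglang's combine type
    hidden_states: torch.Tensor

class ToyMoEMethod:
    def create_moe_runner(self, layer, moe_runner_config):
        # Called once at setup; apply() no longer receives the config per call.
        self.moe_runner_config = moe_runner_config

    def apply(self, layer, dispatch_output: StandardDispatchOutput) -> StandardCombineInput:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output
        # Toy stand-in for the quantized fused_moe kernel.
        output = x * topk_weights.sum(dim=-1, keepdim=True)
        return StandardCombineInput(hidden_states=output)

method = ToyMoEMethod()
method.create_moe_runner(layer=None, moe_runner_config={"inplace": False})
dispatch = StandardDispatchOutput(
    hidden_states=torch.randn(4, 8),
    topk_output=(torch.rand(4, 2), torch.zeros(4, 2, dtype=torch.long), None),
)
combine = method.apply(layer=None, dispatch_output=dispatch)
assert combine.hidden_states.shape == (4, 8)
```

This mirrors the dispatch → apply → combine pipeline that the new framework standardizes across quantization backends.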