[7/N] MoE Refactor: implementation of the new framework (#9269)
@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda, is_hip
 
@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         orig_dtype = x.dtype
         x = x.half()
 
         topk_weights, topk_ids, router_logits = topk_output
 
-        return fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
             w2_zeros=layer.w2_qzeros,
             num_bits=self.quant_config.weight_bits,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
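For readers following the refactor series, below is a minimal sketch of the quant-method contract implied by this diff: configuration is bound once through create_moe_runner, and apply now maps a dispatcher output to a combiner input instead of taking raw tensors and returning one. The stand-in NamedTuples and ToyMoEMethod are hypothetical; the field names are inferred from the attribute accesses in the diff (hidden_states, topk_output, and the three-way unpack into topk_weights, topk_ids, router_logits), and the real classes in sglang.srt.layers.moe.token_dispatcher may differ in detail.

# Hypothetical stand-ins sketching the contract; not the real sglang classes.
from typing import NamedTuple

import torch


class StandardTopKOutput(NamedTuple):
    topk_weights: torch.Tensor   # (num_tokens, top_k) routing weights
    topk_ids: torch.Tensor       # (num_tokens, top_k) expert indices
    router_logits: torch.Tensor  # raw router scores


class StandardDispatchOutput(NamedTuple):
    hidden_states: torch.Tensor
    topk_output: StandardTopKOutput


class StandardCombineInput(NamedTuple):
    hidden_states: torch.Tensor


class ToyMoEMethod:
    """Hypothetical quant method following the refactored interface."""

    def create_moe_runner(self, layer, moe_runner_config):
        # Config is bound once here instead of being threaded through apply().
        self.moe_runner_config = moe_runner_config

    def apply(self, layer, dispatch_output: StandardDispatchOutput) -> StandardCombineInput:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, router_logits = dispatch_output.topk_output
        orig_dtype = x.dtype
        # A real method would run its fused expert kernel (e.g. fused_marlin_moe)
        # on x.half() here; the identity op keeps this sketch self-contained.
        output = x.half()
        return StandardCombineInput(hidden_states=output.to(orig_dtype))


if __name__ == "__main__":
    method = ToyMoEMethod()
    method.create_moe_runner(layer=None, moe_runner_config={"activation": "silu"})
    tokens = torch.randn(4, 16)
    topk = StandardTopKOutput(
        topk_weights=torch.rand(4, 2),
        topk_ids=torch.randint(0, 8, (4, 2)),
        router_logits=torch.rand(4, 8),
    )
    combine_input = method.apply(None, StandardDispatchOutput(tokens, topk))
    print(combine_input.hidden_states.shape)  # torch.Size([4, 16])

The design point this illustrates: once dispatch and combine are first-class types, the same quant method can sit behind different token dispatchers, and per-forward arguments shrink to a single structured object.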