[7/N] MoE Refactor: implementation of the new framework (#9269)
@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -22,7 +23,10 @@ from sglang.srt.utils import set_weight_attrs
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 ACTIVATION_SCHEMES = ["static", "dynamic"]
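The import hunks set up the new boundary: the top-level cutlass_w4a8_moe import moves into apply() (see the TODO in the later hunk), and the StandardTopKOutput typing import gives way to the dispatcher types CombineInput and StandardDispatchOutput. As a reading aid, here is a minimal stand-in sketch of those two types, inferred only from how this diff uses them; the real definitions in sglang.srt.layers.moe.token_dispatcher may carry more fields.

# Stand-in sketch only: field names are inferred from their uses below
# (dispatch_output.hidden_states, dispatch_output.topk_output,
# StandardCombineInput(hidden_states=output)), not copied from sglang.
from typing import NamedTuple, Tuple

import torch

class StandardDispatchOutput(NamedTuple):
    hidden_states: torch.Tensor
    topk_output: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

class StandardCombineInput(NamedTuple):
    hidden_states: torch.Tensor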
@@ -133,7 +137,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer: EPMoE,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -145,7 +149,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size * 2,
+                intermediate_size_per_partition * 2,
                 hidden_size // 2,
                 dtype=torch.int8,
             ),
@@ -159,7 +163,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.empty(
                 num_experts,
                 hidden_size,
-                intermediate_size // 2,
+                intermediate_size_per_partition // 2,
                 dtype=torch.int8,
             ),
             requires_grad=False,
@@ -173,7 +177,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition,
                 hidden_size // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
@@ -186,7 +190,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.zeros(
                 num_experts,
                 hidden_size,
-                intermediate_size // self.quant_config.group_size,
+                intermediate_size_per_partition // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
             requires_grad=False,
@@ -220,13 +224,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         self.c_strides1 = torch.full(
             (num_experts, 3),
-            2 * intermediate_size,
+            2 * intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
         self.a_strides2 = torch.full(
             (num_experts, 3),
-            intermediate_size,
+            intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
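Every create_weights hunk above is the same mechanical rename, intermediate_size to intermediate_size_per_partition, making explicit that this method sizes weights for one parallel slice of the expert FFN (the newly added get_moe_expert_parallel_world_size import points at expert parallelism as the driver). Below is a small runnable sanity check of the resulting layouts with illustrative sizes; the // 2 factors come from int4 weights packed two per int8 byte, and the // group_size factors from group-wise scales.

# Mirrors the torch.empty/torch.zeros calls above; sizes are illustrative,
# not taken from any real model config.
import torch

num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 1024  # one partition's slice of the FFN
group_size = 128                        # stands in for quant_config.group_size

w13_weight = torch.empty(
    num_experts, intermediate_size_per_partition * 2, hidden_size // 2,
    dtype=torch.int8,  # two int4 nibbles per int8 byte
)
w2_weight = torch.empty(
    num_experts, hidden_size, intermediate_size_per_partition // 2,
    dtype=torch.int8,
)
w13_weight_scale = torch.zeros(
    num_experts, 2 * intermediate_size_per_partition, hidden_size // group_size,
    dtype=torch.float32,  # one scale per group_size input channels
)
w2_weight_scale = torch.zeros(
    num_experts, hidden_size, intermediate_size_per_partition // group_size,
    dtype=torch.float32,
)

# The packed dims unpack back to the logical sizes.
assert w13_weight.shape[2] * 2 == hidden_size
assert w2_weight.shape[2] * 2 == intermediate_size_per_partition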
@@ -282,16 +286,21 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: EPMoE,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:

+        # TODO(ch-wan): move it out of this class
+        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
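This hunk is the core of the refactor: apply() stops taking x, topk_output, and moe_runner_config directly, consumes a StandardDispatchOutput, and returns a CombineInput, while create_moe_runner() stashes the runner config once at setup. Here is a toy sketch of that contract, reusing the stand-in types above; the kernel call is faked, and only the create_moe_runner/apply shapes come from this diff.

# Toy model of the dispatch -> apply -> combine contract. Everything except
# the create_moe_runner/apply signatures is a stand-in, not sglang's API.
# Assumes the StandardDispatchOutput/StandardCombineInput stand-ins above.
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class ToyMoeRunnerConfig:  # hypothetical stand-in for MoeRunnerConfig
    routed_scaling_factor: Optional[float] = None

class ToyW4AFp8Method:
    def create_moe_runner(self, layer, moe_runner_config: ToyMoeRunnerConfig):
        # the config is stashed once instead of threaded through every apply()
        self.moe_runner_config = moe_runner_config

    def apply(
        self, layer, dispatch_output: StandardDispatchOutput
    ) -> StandardCombineInput:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output
        output = x  # fake: stands in for the cutlass_w4a8_moe(...) call
        if self.moe_runner_config.routed_scaling_factor is not None:
            output = output * self.moe_runner_config.routed_scaling_factor
        return StandardCombineInput(hidden_states=output)

# Usage: the token dispatcher (not shown) produces the dispatch output and
# consumes the combine input on the other side.
method = ToyW4AFp8Method()
method.create_moe_runner(layer=None, moe_runner_config=ToyMoeRunnerConfig(2.5))
topk = (torch.rand(4, 2), torch.zeros(4, 2, dtype=torch.long), torch.empty(0))
out = method.apply(None, StandardDispatchOutput(torch.randn(4, 16), topk))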
@@ -328,6 +337,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
-        if moe_runner_config.routed_scaling_factor is not None:
-            output *= moe_runner_config.routed_scaling_factor
-        return output
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
+        return StandardCombineInput(hidden_states=output)