Set num_fused_shared_experts as num_shared_experts when shared_experts fusion is not disabled (#6736)
This commit is contained in:
@@ -156,6 +156,7 @@ class EPMoE(torch.nn.Module):
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
topk_group: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
@@ -190,6 +191,7 @@ class EPMoE(torch.nn.Module):
|
||||
if self.use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
self.num_expert_group = num_expert_group
|
||||
self.num_fused_shared_experts = num_fused_shared_experts
|
||||
self.topk_group = topk_group
|
||||
self.correction_bias = correction_bias
|
||||
self.custom_routing_function = custom_routing_function
|
||||
@@ -250,6 +252,7 @@ class EPMoE(torch.nn.Module):
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
correction_bias=self.correction_bias,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
|
||||
@@ -21,6 +21,7 @@ def fused_moe_forward_native(
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
@@ -41,6 +42,7 @@ def fused_moe_forward_native(
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
@@ -71,6 +73,7 @@ def moe_forward_native(
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
@@ -84,6 +87,7 @@ def moe_forward_native(
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
torch_native=True,
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -1540,6 +1540,7 @@ def fused_moe(
|
||||
activation: str = "silu",
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
@@ -1609,6 +1610,7 @@ def fused_moe(
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
@@ -127,6 +127,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
@@ -144,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
activation=activation,
|
||||
@@ -163,6 +165,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
@@ -179,6 +182,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
@@ -232,6 +236,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
inplace: bool = True,
|
||||
@@ -245,6 +250,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
num_fused_shared_experts,
|
||||
custom_routing_function,
|
||||
correction_bias,
|
||||
)
|
||||
@@ -289,6 +295,7 @@ class FusedMoE(torch.nn.Module):
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
topk_group: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
@@ -321,6 +328,7 @@ class FusedMoE(torch.nn.Module):
|
||||
if self.use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
self.num_expert_group = num_expert_group
|
||||
self.num_fused_shared_experts = num_fused_shared_experts
|
||||
self.topk_group = topk_group
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.correction_bias = correction_bias
|
||||
@@ -651,6 +659,7 @@ class FusedMoE(torch.nn.Module):
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
activation=self.activation,
|
||||
|
||||
@@ -303,6 +303,7 @@ def select_experts(
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
torch_native: bool = False,
|
||||
@@ -310,7 +311,6 @@ def select_experts(
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
num_fused_shared_experts = global_server_args_dict["num_fused_shared_experts"]
|
||||
|
||||
router_logits, correction_bias = (
|
||||
expert_location_dispatch.transform_select_experts_inputs(
|
||||
|
||||
@@ -289,6 +289,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
|
||||
@@ -367,6 +367,7 @@ class BlockInt8MoEMethod:
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
@@ -387,6 +388,7 @@ class BlockInt8MoEMethod:
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
|
||||
@@ -272,6 +272,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
@@ -294,6 +295,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
@@ -627,6 +629,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
@@ -651,6 +654,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
correction_bias=correction_bias,
|
||||
|
||||
@@ -937,6 +937,7 @@ class Fp8MoEMethod:
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
@@ -957,6 +958,7 @@ class Fp8MoEMethod:
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
|
||||
@@ -341,6 +341,7 @@ class MoeWNA16Method:
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
@@ -362,6 +363,7 @@ class MoeWNA16Method:
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
|
||||
@@ -287,6 +287,7 @@ class W8A8FP8MoEMethod:
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
@@ -306,6 +307,7 @@ class W8A8FP8MoEMethod:
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
|
||||
@@ -225,6 +225,7 @@ class W8A8Int8MoEMethod:
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
@@ -245,6 +246,7 @@ class W8A8Int8MoEMethod:
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
|
||||
@@ -89,7 +89,7 @@ global_server_args_dict = {
|
||||
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
|
||||
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
|
||||
"ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
|
||||
"num_fused_shared_experts": ServerArgs.num_fused_shared_experts,
|
||||
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
|
||||
"sampling_backend": ServerArgs.sampling_backend,
|
||||
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
||||
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
||||
|
||||
@@ -204,7 +204,7 @@ class ModelRunner:
|
||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||
"moe_dense_tp_size": server_args.moe_dense_tp_size,
|
||||
"ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
|
||||
"num_fused_shared_experts": server_args.num_fused_shared_experts,
|
||||
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
|
||||
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
|
||||
"torchao_config": server_args.torchao_config,
|
||||
"sampling_backend": server_args.sampling_backend,
|
||||
|
||||
@@ -224,9 +224,11 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_shared_experts = config.n_shared_experts
|
||||
self.num_fused_shared_experts = global_server_args_dict[
|
||||
"num_fused_shared_experts"
|
||||
]
|
||||
self.num_fused_shared_experts = (
|
||||
0
|
||||
if global_server_args_dict["disable_shared_experts_fusion"]
|
||||
else config.n_shared_experts
|
||||
)
|
||||
self.config = config
|
||||
self.layer_id = layer_id
|
||||
|
||||
@@ -248,7 +250,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
num_experts=config.n_routed_experts
|
||||
+ self.num_fused_shared_experts
|
||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||
top_k=config.num_experts_per_tok + min(self.num_fused_shared_experts, 1),
|
||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
layer_id=self.layer_id,
|
||||
@@ -256,6 +258,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
@@ -363,6 +366,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_token_non_padded=forward_batch.num_token_non_padded,
|
||||
@@ -456,6 +460,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||
@@ -1679,9 +1684,11 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
def determine_num_fused_shared_experts(
|
||||
self, architecture: str = "DeepseekV3ForCausalLM"
|
||||
):
|
||||
self.num_fused_shared_experts = global_server_args_dict[
|
||||
"num_fused_shared_experts"
|
||||
]
|
||||
self.num_fused_shared_experts = (
|
||||
0
|
||||
if global_server_args_dict["disable_shared_experts_fusion"]
|
||||
else self.config.n_shared_experts
|
||||
)
|
||||
if self.num_fused_shared_experts > 0:
|
||||
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
||||
if (
|
||||
@@ -1690,15 +1697,11 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
or self.config.n_routed_experts != 256
|
||||
):
|
||||
self.num_fused_shared_experts = 0
|
||||
global_server_args_dict["num_fused_shared_experts"] = 0
|
||||
global_server_args_dict["disable_shared_experts_fusion"] = 1
|
||||
log_info_on_rank0(
|
||||
logger,
|
||||
"Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
self.num_fused_shared_experts == self.tp_size
|
||||
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
|
||||
elif self.num_fused_shared_experts == 0:
|
||||
if (
|
||||
_is_cuda
|
||||
@@ -1707,8 +1710,8 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
and self.config.n_routed_experts == 256
|
||||
and (not global_server_args_dict["enable_deepep_moe"])
|
||||
):
|
||||
self.num_fused_shared_experts = self.tp_size
|
||||
global_server_args_dict["num_fused_shared_experts"] = self.tp_size
|
||||
self.num_fused_shared_experts = self.config.n_shared_experts
|
||||
global_server_args_dict["disable_shared_experts_fusion"] = 0
|
||||
log_info_on_rank0(
|
||||
logger,
|
||||
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
|
||||
@@ -1910,6 +1913,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
if self.num_fused_shared_experts > 0:
|
||||
assert self.num_fused_shared_experts == 1
|
||||
weights_list = list(weights)
|
||||
weights_dict = dict(weights_list)
|
||||
if self.quant_config is not None:
|
||||
@@ -1971,22 +1975,21 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
for moe_layer in tqdm(
|
||||
moe_layers,
|
||||
desc=f"Cloning {self.num_fused_shared_experts} "
|
||||
"replicas of the shared expert into MoE",
|
||||
"shared expert into MoE",
|
||||
):
|
||||
for suffix in suffix_list:
|
||||
shared_expert_weight_name = (
|
||||
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
|
||||
)
|
||||
for num_repeat in range(self.num_fused_shared_experts):
|
||||
weights_list.append(
|
||||
(
|
||||
f"model.layers.{moe_layer}."
|
||||
f"mlp.experts."
|
||||
f"{self.config.n_routed_experts + num_repeat}"
|
||||
f".{suffix}",
|
||||
weights_dict[shared_expert_weight_name],
|
||||
)
|
||||
weights_list.append(
|
||||
(
|
||||
f"model.layers.{moe_layer}."
|
||||
f"mlp.experts."
|
||||
f"{self.config.n_routed_experts + 0}"
|
||||
f".{suffix}",
|
||||
weights_dict[shared_expert_weight_name],
|
||||
)
|
||||
)
|
||||
names_to_remove += [shared_expert_weight_name]
|
||||
weights = [w for w in weights_list if w[0] not in names_to_remove]
|
||||
|
||||
|
||||
@@ -207,7 +207,7 @@ class ServerArgs:
|
||||
flashinfer_mla_disable_ragged: bool = False
|
||||
warmups: Optional[str] = None
|
||||
moe_dense_tp_size: Optional[int] = None
|
||||
num_fused_shared_experts: int = 0
|
||||
disable_shared_experts_fusion: bool = False
|
||||
disable_chunked_prefix_cache: bool = False
|
||||
disable_fast_image_processor: bool = False
|
||||
mm_attention_backend: Optional[str] = None
|
||||
@@ -1384,13 +1384,10 @@ class ServerArgs:
|
||||
default=ServerArgs.deepep_config,
|
||||
help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-fused-shared-experts",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
|
||||
"set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with num_fused_shared_experts automatically set to the TP size.",
|
||||
"--disable-shared-experts-fusion",
|
||||
action="store_true",
|
||||
help="Disable shared experts fusion optimization for deepseek v3/r1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-chunked-prefix-cache",
|
||||
|
||||
Reference in New Issue
Block a user