Set num_fused_shared_experts as num_shared_experts when shared_experts fusion is not disabled (#6736)
This commit is contained in:
@@ -400,7 +400,7 @@ def main(args: argparse.Namespace):
|
|||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
||||||
E = (
|
E = (
|
||||||
config.n_routed_experts + 1
|
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
|
||||||
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
|
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
|
||||||
else config.n_routed_experts
|
else config.n_routed_experts
|
||||||
)
|
)
|
||||||
@@ -408,7 +408,9 @@ def main(args: argparse.Namespace):
|
|||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
elif config.architectures[0] == "Llama4ForConditionalGeneration":
|
elif config.architectures[0] == "Llama4ForConditionalGeneration":
|
||||||
E = config.text_config.num_local_experts + 1
|
E = config.text_config.num_local_experts + (
|
||||||
|
0 if args.disable_shared_experts_fusion else 1
|
||||||
|
)
|
||||||
topk = config.text_config.num_experts_per_tok
|
topk = config.text_config.num_experts_per_tok
|
||||||
intermediate_size = config.text_config.intermediate_size
|
intermediate_size = config.text_config.intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
@@ -558,7 +560,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
)
|
)
|
||||||
parser.add_argument("--tp-size", "-tp", type=int, default=2)
|
parser.add_argument("--tp-size", "--tp", type=int, default=2)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dtype",
|
"--dtype",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -568,6 +570,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--batch-size", type=int, required=False)
|
parser.add_argument("--batch-size", type=int, required=False)
|
||||||
parser.add_argument("--tune", action="store_true")
|
parser.add_argument("--tune", action="store_true")
|
||||||
|
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -156,6 +156,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
renormalize: bool = True,
|
renormalize: bool = True,
|
||||||
use_grouped_topk: bool = False,
|
use_grouped_topk: bool = False,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
@@ -190,6 +191,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
if self.use_grouped_topk:
|
if self.use_grouped_topk:
|
||||||
assert num_expert_group is not None and topk_group is not None
|
assert num_expert_group is not None and topk_group is not None
|
||||||
self.num_expert_group = num_expert_group
|
self.num_expert_group = num_expert_group
|
||||||
|
self.num_fused_shared_experts = num_fused_shared_experts
|
||||||
self.topk_group = topk_group
|
self.topk_group = topk_group
|
||||||
self.correction_bias = correction_bias
|
self.correction_bias = correction_bias
|
||||||
self.custom_routing_function = custom_routing_function
|
self.custom_routing_function = custom_routing_function
|
||||||
@@ -250,6 +252,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
renormalize=self.renormalize,
|
renormalize=self.renormalize,
|
||||||
topk_group=self.topk_group,
|
topk_group=self.topk_group,
|
||||||
num_expert_group=self.num_expert_group,
|
num_expert_group=self.num_expert_group,
|
||||||
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
correction_bias=self.correction_bias,
|
correction_bias=self.correction_bias,
|
||||||
custom_routing_function=self.custom_routing_function,
|
custom_routing_function=self.custom_routing_function,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ def fused_moe_forward_native(
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -41,6 +42,7 @@ def fused_moe_forward_native(
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
@@ -71,6 +73,7 @@ def moe_forward_native(
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -84,6 +87,7 @@ def moe_forward_native(
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
torch_native=True,
|
torch_native=True,
|
||||||
|
|||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1540,6 +1540,7 @@ def fused_moe(
|
|||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
use_grouped_topk: bool = False,
|
use_grouped_topk: bool = False,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
@@ -1609,6 +1610,7 @@ def fused_moe(
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
use_grouped_topk: bool,
|
use_grouped_topk: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -144,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
use_grouped_topk=use_grouped_topk,
|
use_grouped_topk=use_grouped_topk,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
@@ -163,6 +165,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -179,6 +182,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
@@ -232,6 +236,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
@@ -245,6 +250,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
renormalize,
|
renormalize,
|
||||||
topk_group,
|
topk_group,
|
||||||
num_expert_group,
|
num_expert_group,
|
||||||
|
num_fused_shared_experts,
|
||||||
custom_routing_function,
|
custom_routing_function,
|
||||||
correction_bias,
|
correction_bias,
|
||||||
)
|
)
|
||||||
@@ -289,6 +295,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
renormalize: bool = True,
|
renormalize: bool = True,
|
||||||
use_grouped_topk: bool = False,
|
use_grouped_topk: bool = False,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
@@ -321,6 +328,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
if self.use_grouped_topk:
|
if self.use_grouped_topk:
|
||||||
assert num_expert_group is not None and topk_group is not None
|
assert num_expert_group is not None and topk_group is not None
|
||||||
self.num_expert_group = num_expert_group
|
self.num_expert_group = num_expert_group
|
||||||
|
self.num_fused_shared_experts = num_fused_shared_experts
|
||||||
self.topk_group = topk_group
|
self.topk_group = topk_group
|
||||||
self.custom_routing_function = custom_routing_function
|
self.custom_routing_function = custom_routing_function
|
||||||
self.correction_bias = correction_bias
|
self.correction_bias = correction_bias
|
||||||
@@ -651,6 +659,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
use_grouped_topk=self.use_grouped_topk,
|
use_grouped_topk=self.use_grouped_topk,
|
||||||
topk_group=self.topk_group,
|
topk_group=self.topk_group,
|
||||||
num_expert_group=self.num_expert_group,
|
num_expert_group=self.num_expert_group,
|
||||||
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
custom_routing_function=self.custom_routing_function,
|
custom_routing_function=self.custom_routing_function,
|
||||||
correction_bias=self.correction_bias,
|
correction_bias=self.correction_bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
|
|||||||
@@ -303,6 +303,7 @@ def select_experts(
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
torch_native: bool = False,
|
torch_native: bool = False,
|
||||||
@@ -310,7 +311,6 @@ def select_experts(
|
|||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
):
|
):
|
||||||
num_fused_shared_experts = global_server_args_dict["num_fused_shared_experts"]
|
|
||||||
|
|
||||||
router_logits, correction_bias = (
|
router_logits, correction_bias = (
|
||||||
expert_location_dispatch.transform_select_experts_inputs(
|
expert_location_dispatch.transform_select_experts_inputs(
|
||||||
|
|||||||
@@ -289,6 +289,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
|||||||
use_grouped_topk: bool,
|
use_grouped_topk: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
|||||||
@@ -367,6 +367,7 @@ class BlockInt8MoEMethod:
|
|||||||
use_grouped_topk: bool,
|
use_grouped_topk: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -387,6 +388,7 @@ class BlockInt8MoEMethod:
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
|||||||
@@ -272,6 +272,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
use_grouped_topk: bool = False,
|
use_grouped_topk: bool = False,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
@@ -294,6 +295,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
@@ -627,6 +629,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
use_grouped_topk: bool = False,
|
use_grouped_topk: bool = False,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
@@ -651,6 +654,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
scoring_func=scoring_func,
|
scoring_func=scoring_func,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
|
|||||||
@@ -937,6 +937,7 @@ class Fp8MoEMethod:
|
|||||||
use_grouped_topk: bool,
|
use_grouped_topk: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -957,6 +958,7 @@ class Fp8MoEMethod:
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
|||||||
@@ -341,6 +341,7 @@ class MoeWNA16Method:
|
|||||||
use_grouped_topk: bool = False,
|
use_grouped_topk: bool = False,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -362,6 +363,7 @@ class MoeWNA16Method:
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
|||||||
@@ -287,6 +287,7 @@ class W8A8FP8MoEMethod:
|
|||||||
use_grouped_topk: bool,
|
use_grouped_topk: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -306,6 +307,7 @@ class W8A8FP8MoEMethod:
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
|||||||
@@ -225,6 +225,7 @@ class W8A8Int8MoEMethod:
|
|||||||
use_grouped_topk: bool,
|
use_grouped_topk: bool,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -245,6 +246,7 @@ class W8A8Int8MoEMethod:
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ global_server_args_dict = {
|
|||||||
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
|
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
|
||||||
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
|
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
|
||||||
"ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
|
"ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
|
||||||
"num_fused_shared_experts": ServerArgs.num_fused_shared_experts,
|
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
|
||||||
"sampling_backend": ServerArgs.sampling_backend,
|
"sampling_backend": ServerArgs.sampling_backend,
|
||||||
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
||||||
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ class ModelRunner:
|
|||||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||||
"moe_dense_tp_size": server_args.moe_dense_tp_size,
|
"moe_dense_tp_size": server_args.moe_dense_tp_size,
|
||||||
"ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
|
"ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
|
||||||
"num_fused_shared_experts": server_args.num_fused_shared_experts,
|
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
|
||||||
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
|
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
|
||||||
"torchao_config": server_args.torchao_config,
|
"torchao_config": server_args.torchao_config,
|
||||||
"sampling_backend": server_args.sampling_backend,
|
"sampling_backend": server_args.sampling_backend,
|
||||||
|
|||||||
@@ -224,9 +224,11 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.routed_scaling_factor = config.routed_scaling_factor
|
self.routed_scaling_factor = config.routed_scaling_factor
|
||||||
self.n_shared_experts = config.n_shared_experts
|
self.n_shared_experts = config.n_shared_experts
|
||||||
self.num_fused_shared_experts = global_server_args_dict[
|
self.num_fused_shared_experts = (
|
||||||
"num_fused_shared_experts"
|
0
|
||||||
]
|
if global_server_args_dict["disable_shared_experts_fusion"]
|
||||||
|
else config.n_shared_experts
|
||||||
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
|
|
||||||
@@ -248,7 +250,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
num_experts=config.n_routed_experts
|
num_experts=config.n_routed_experts
|
||||||
+ self.num_fused_shared_experts
|
+ self.num_fused_shared_experts
|
||||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||||
top_k=config.num_experts_per_tok + min(self.num_fused_shared_experts, 1),
|
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.moe_intermediate_size,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
layer_id=self.layer_id,
|
layer_id=self.layer_id,
|
||||||
@@ -256,6 +258,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
use_grouped_topk=True,
|
use_grouped_topk=True,
|
||||||
num_expert_group=config.n_group,
|
num_expert_group=config.n_group,
|
||||||
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
topk_group=config.topk_group,
|
topk_group=config.topk_group,
|
||||||
correction_bias=self.gate.e_score_correction_bias,
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
@@ -363,6 +366,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
renormalize=self.renormalize,
|
renormalize=self.renormalize,
|
||||||
topk_group=self.topk_group,
|
topk_group=self.topk_group,
|
||||||
num_expert_group=self.num_expert_group,
|
num_expert_group=self.num_expert_group,
|
||||||
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
correction_bias=self.correction_bias,
|
correction_bias=self.correction_bias,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
num_token_non_padded=forward_batch.num_token_non_padded,
|
num_token_non_padded=forward_batch.num_token_non_padded,
|
||||||
@@ -456,6 +460,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
renormalize=self.renormalize,
|
renormalize=self.renormalize,
|
||||||
topk_group=self.topk_group,
|
topk_group=self.topk_group,
|
||||||
num_expert_group=self.num_expert_group,
|
num_expert_group=self.num_expert_group,
|
||||||
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
correction_bias=self.correction_bias,
|
correction_bias=self.correction_bias,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||||
@@ -1679,9 +1684,11 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
def determine_num_fused_shared_experts(
|
def determine_num_fused_shared_experts(
|
||||||
self, architecture: str = "DeepseekV3ForCausalLM"
|
self, architecture: str = "DeepseekV3ForCausalLM"
|
||||||
):
|
):
|
||||||
self.num_fused_shared_experts = global_server_args_dict[
|
self.num_fused_shared_experts = (
|
||||||
"num_fused_shared_experts"
|
0
|
||||||
]
|
if global_server_args_dict["disable_shared_experts_fusion"]
|
||||||
|
else self.config.n_shared_experts
|
||||||
|
)
|
||||||
if self.num_fused_shared_experts > 0:
|
if self.num_fused_shared_experts > 0:
|
||||||
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
||||||
if (
|
if (
|
||||||
@@ -1690,15 +1697,11 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
or self.config.n_routed_experts != 256
|
or self.config.n_routed_experts != 256
|
||||||
):
|
):
|
||||||
self.num_fused_shared_experts = 0
|
self.num_fused_shared_experts = 0
|
||||||
global_server_args_dict["num_fused_shared_experts"] = 0
|
global_server_args_dict["disable_shared_experts_fusion"] = 1
|
||||||
log_info_on_rank0(
|
log_info_on_rank0(
|
||||||
logger,
|
logger,
|
||||||
"Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
|
"Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
self.num_fused_shared_experts == self.tp_size
|
|
||||||
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
|
|
||||||
elif self.num_fused_shared_experts == 0:
|
elif self.num_fused_shared_experts == 0:
|
||||||
if (
|
if (
|
||||||
_is_cuda
|
_is_cuda
|
||||||
@@ -1707,8 +1710,8 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
and self.config.n_routed_experts == 256
|
and self.config.n_routed_experts == 256
|
||||||
and (not global_server_args_dict["enable_deepep_moe"])
|
and (not global_server_args_dict["enable_deepep_moe"])
|
||||||
):
|
):
|
||||||
self.num_fused_shared_experts = self.tp_size
|
self.num_fused_shared_experts = self.config.n_shared_experts
|
||||||
global_server_args_dict["num_fused_shared_experts"] = self.tp_size
|
global_server_args_dict["disable_shared_experts_fusion"] = 0
|
||||||
log_info_on_rank0(
|
log_info_on_rank0(
|
||||||
logger,
|
logger,
|
||||||
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
|
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
|
||||||
@@ -1910,6 +1913,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
if self.num_fused_shared_experts > 0:
|
if self.num_fused_shared_experts > 0:
|
||||||
|
assert self.num_fused_shared_experts == 1
|
||||||
weights_list = list(weights)
|
weights_list = list(weights)
|
||||||
weights_dict = dict(weights_list)
|
weights_dict = dict(weights_list)
|
||||||
if self.quant_config is not None:
|
if self.quant_config is not None:
|
||||||
@@ -1971,22 +1975,21 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
for moe_layer in tqdm(
|
for moe_layer in tqdm(
|
||||||
moe_layers,
|
moe_layers,
|
||||||
desc=f"Cloning {self.num_fused_shared_experts} "
|
desc=f"Cloning {self.num_fused_shared_experts} "
|
||||||
"replicas of the shared expert into MoE",
|
"shared expert into MoE",
|
||||||
):
|
):
|
||||||
for suffix in suffix_list:
|
for suffix in suffix_list:
|
||||||
shared_expert_weight_name = (
|
shared_expert_weight_name = (
|
||||||
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
|
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
|
||||||
)
|
)
|
||||||
for num_repeat in range(self.num_fused_shared_experts):
|
weights_list.append(
|
||||||
weights_list.append(
|
(
|
||||||
(
|
f"model.layers.{moe_layer}."
|
||||||
f"model.layers.{moe_layer}."
|
f"mlp.experts."
|
||||||
f"mlp.experts."
|
f"{self.config.n_routed_experts + 0}"
|
||||||
f"{self.config.n_routed_experts + num_repeat}"
|
f".{suffix}",
|
||||||
f".{suffix}",
|
weights_dict[shared_expert_weight_name],
|
||||||
weights_dict[shared_expert_weight_name],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
names_to_remove += [shared_expert_weight_name]
|
names_to_remove += [shared_expert_weight_name]
|
||||||
weights = [w for w in weights_list if w[0] not in names_to_remove]
|
weights = [w for w in weights_list if w[0] not in names_to_remove]
|
||||||
|
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ class ServerArgs:
|
|||||||
flashinfer_mla_disable_ragged: bool = False
|
flashinfer_mla_disable_ragged: bool = False
|
||||||
warmups: Optional[str] = None
|
warmups: Optional[str] = None
|
||||||
moe_dense_tp_size: Optional[int] = None
|
moe_dense_tp_size: Optional[int] = None
|
||||||
num_fused_shared_experts: int = 0
|
disable_shared_experts_fusion: bool = False
|
||||||
disable_chunked_prefix_cache: bool = False
|
disable_chunked_prefix_cache: bool = False
|
||||||
disable_fast_image_processor: bool = False
|
disable_fast_image_processor: bool = False
|
||||||
mm_attention_backend: Optional[str] = None
|
mm_attention_backend: Optional[str] = None
|
||||||
@@ -1384,13 +1384,10 @@ class ServerArgs:
|
|||||||
default=ServerArgs.deepep_config,
|
default=ServerArgs.deepep_config,
|
||||||
help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
|
help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-fused-shared-experts",
|
"--disable-shared-experts-fusion",
|
||||||
type=int,
|
action="store_true",
|
||||||
default=0,
|
help="Disable shared experts fusion optimization for deepseek v3/r1.",
|
||||||
help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
|
|
||||||
"set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with num_fused_shared_experts automatically set to the TP size.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-chunked-prefix-cache",
|
"--disable-chunked-prefix-cache",
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ __device__ void moe_fused_gate_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Calculate topk_excluding_share_expert_fusion from topk
|
// Calculate topk_excluding_share_expert_fusion from topk
|
||||||
int64_t topk_excluding_share_expert_fusion = topk - (num_fused_shared_experts > 0 ? 1 : 0);
|
int64_t topk_excluding_share_expert_fusion = topk - num_fused_shared_experts;
|
||||||
|
|
||||||
// Cast pointers to type T:
|
// Cast pointers to type T:
|
||||||
auto* input_ptr = reinterpret_cast<T*>(input);
|
auto* input_ptr = reinterpret_cast<T*>(input);
|
||||||
@@ -224,13 +224,21 @@ __device__ void moe_fused_gate_impl(
|
|||||||
|
|
||||||
if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
|
if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
|
||||||
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
|
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
|
||||||
|
int64_t expert_offset = 0;
|
||||||
// Use round-robin to select expert
|
|
||||||
int64_t expert_offset = thread_row % num_fused_shared_experts;
|
|
||||||
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
|
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
|
||||||
|
|
||||||
// Set the weight to the sum of all weights divided by routed_scaling_factor
|
// Set the weight to the sum of all weights divided by routed_scaling_factor
|
||||||
output_ptr[last_idx] = output_sum / routed_scaling_factor;
|
output_ptr[last_idx] = output_sum / routed_scaling_factor;
|
||||||
|
|
||||||
|
if (num_fused_shared_experts > 1) {
|
||||||
|
for (int i = 1; i < num_fused_shared_experts; ++i) {
|
||||||
|
++last_idx;
|
||||||
|
++expert_offset;
|
||||||
|
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
|
||||||
|
// Set the weight to the sum of all weights divided by routed_scaling_factor
|
||||||
|
output_ptr[last_idx] = output_sum / routed_scaling_factor;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ def moe_fused_gate(
|
|||||||
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
|
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
|
||||||
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
|
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
|
||||||
# for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
|
# for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
|
||||||
# num_fused_shared_experts: if > 0, the last expert will be replaced with a round-robin shared expert
|
# num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts
|
||||||
# routed_scaling_factor: if > 0, the last expert will be scaled by this factor
|
# routed_scaling_factor: if > 0, the shared experts will be scaled by this factor
|
||||||
return torch.ops.sgl_kernel.moe_fused_gate.default(
|
return torch.ops.sgl_kernel.moe_fused_gate.default(
|
||||||
input_tensor,
|
input_tensor,
|
||||||
bias,
|
bias,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
|
|||||||
(512, 16, 8, 16),
|
(512, 16, 8, 16),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1])
|
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
|
||||||
def test_moe_fused_gate_combined(seq_length, dtype, params, num_fused_shared_experts):
|
def test_moe_fused_gate_combined(seq_length, dtype, params, num_fused_shared_experts):
|
||||||
num_experts, num_expert_group, topk_group, topk = params
|
num_experts, num_expert_group, topk_group, topk = params
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, num_fused_shared_exp
|
|||||||
tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
|
tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
|
||||||
scores = tensor.clone()
|
scores = tensor.clone()
|
||||||
bias = torch.rand(num_experts).to(dtype).cuda()
|
bias = torch.rand(num_experts).to(dtype).cuda()
|
||||||
topk = topk + min(1, num_fused_shared_experts)
|
topk = topk + num_fused_shared_experts
|
||||||
|
|
||||||
output, indices = moe_fused_gate(
|
output, indices = moe_fused_gate(
|
||||||
tensor,
|
tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user