[1/N] Introduce Mooncake Backend and Mooncake EP to Support Elastic EP (#10423)

Co-authored-by: Hank Han <hanhan7630@outlook.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
Xun Sun
2025-10-15 10:40:54 +08:00
committed by GitHub
parent 74737b2863
commit a40229f6f8
13 changed files with 798 additions and 32 deletions

View File

@@ -592,6 +592,7 @@ class DeepseekV2MoE(nn.Module):
**(
dict(tp_rank=0, tp_size=1)
if get_moe_a2a_backend().is_deepep()
+or get_moe_a2a_backend().is_mooncake()
or should_use_flashinfer_cutlass_moe_fp4_allgather()
else {}
),
@@ -622,7 +623,7 @@ class DeepseekV2MoE(nn.Module):
self.top_k = config.num_experts_per_tok
-if get_moe_a2a_backend().is_deepep():
+if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
@@ -651,7 +652,9 @@ class DeepseekV2MoE(nn.Module):
return_recv_hook=True,
)
-self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+self._enable_a2a_moe = (
+    get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
+)
def get_moe_weights(self):
return [
@@ -668,7 +671,7 @@ class DeepseekV2MoE(nn.Module):
use_reduce_scatter: bool = False,
gemm_output_zero_allocator: BumpAllocator = None,
) -> torch.Tensor:
-if not self._enable_deepep_moe:
+if not self._enable_a2a_moe:
DUAL_STREAM_TOKEN_THRESHOLD = 1024
if (
self.alt_stream is not None