Fix mooncake dispatcher (#11908)
This commit is contained in:
@@ -86,7 +86,7 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
|
|||||||
a2a_backend = get_moe_a2a_backend()
|
a2a_backend = get_moe_a2a_backend()
|
||||||
if a2a_backend.is_none():
|
if a2a_backend.is_none():
|
||||||
return StandardDispatcher(moe_runner_config)
|
return StandardDispatcher(moe_runner_config)
|
||||||
elif a2a_backend.is_deepep():
|
elif a2a_backend.is_deepep() or a2a_backend.is_mooncake():
|
||||||
return MaybeTboDeepEPDispatcher(
|
return MaybeTboDeepEPDispatcher(
|
||||||
group=get_tp_group().device_group,
|
group=get_tp_group().device_group,
|
||||||
router_topk=moe_runner_config.top_k,
|
router_topk=moe_runner_config.top_k,
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class MooncakeDispatchOutput(NamedTuple):
|
|||||||
"""Mooncake EP dispatch output."""
|
"""Mooncake EP dispatch output."""
|
||||||
|
|
||||||
hidden_states: torch.Tensor
|
hidden_states: torch.Tensor
|
||||||
hidden_states_scale: torch.Tensor
|
hidden_states_scale: Optional[torch.Tensor]
|
||||||
topk_ids: torch.Tensor
|
topk_ids: torch.Tensor
|
||||||
topk_weights: torch.Tensor
|
topk_weights: torch.Tensor
|
||||||
masked_m: torch.Tensor
|
masked_m: torch.Tensor
|
||||||
@@ -205,8 +205,14 @@ class _MooncakeEPDispatcherImpl:
|
|||||||
masked_m
|
masked_m
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(hidden_states, tuple):
|
||||||
|
hidden_states, hidden_states_scale = hidden_states
|
||||||
|
else:
|
||||||
|
hidden_states_scale = None
|
||||||
|
|
||||||
return MooncakeDispatchOutput(
|
return MooncakeDispatchOutput(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
hidden_states_scale,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
masked_m,
|
masked_m,
|
||||||
|
|||||||
Reference in New Issue
Block a user