[Feat] Flash comm allgher ep (#3334)
Support flash comm v1(Sequence Parallelism) for Allgather EP. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com> Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
This commit is contained in:
@@ -27,7 +27,7 @@ from vllm.distributed.parallel_state import (
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.utils import get_rm_router_logits_state
|
||||
from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
@@ -198,7 +198,7 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
|
||||
"""
|
||||
MoE communication strategy using MC2, which is based on All2All. Hence, it inherits
|
||||
All2All and share the same finalize method.
|
||||
All2All and share the same finalize method.
|
||||
Designed for Ascend or environments requiring explicit padding and slicing control.
|
||||
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
|
||||
"""
|
||||
@@ -277,9 +277,24 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-Gather + Reduce-Scatter.
|
||||
Designed for DP > 1: gather inputs across DP ranks before MoE, scatter outputs after.
|
||||
Uses `max_tokens_across_dp` from forward_context for padding alignment.
|
||||
MoE communication strategy using All-Gather + Reduce-Scatter on EP group.
|
||||
There are two sets of prepare and finalize:
|
||||
1. _prepare_with_dp_group/_finalize_with_dp_group: When sequence parallelism is not enabled,
|
||||
we gather inputs across DP ranks before MoE, scatter outputs after.
|
||||
The communication and calculation process is as follows (AG, AR and RS
|
||||
are abbreviations for All-Gather, All-Reduce and Reduce-Scatter, respectively):
|
||||
|
||||
Attn → TP AR → DP AG → MoE → DP RS → TP AR
|
||||
|
||||
2. _prepare_with_ep_group/_finalize_with_ep_group: When sequence parallelism is enabled,
|
||||
the above process becomes:
|
||||
|
||||
TP AG → Attn → TP RS → TP AG → DP AG → MoE → DP RS → TP RS
|
||||
|
||||
This strategy further combines TP AG + DP AG into EP All-Gather and TP RS + DP RS
|
||||
into EP Reduce-Scatter to improve communication performance. The optimized process is as follows:
|
||||
|
||||
TP AG → Attn → TP RS → EP AG → MoE → EP RS
|
||||
"""
|
||||
|
||||
def prepare(
|
||||
@@ -289,6 +304,42 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
Preparation steps:
|
||||
AllGather hidden_states and router_logits to form global tensors.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
if enable_sp():
|
||||
return self._prepare_with_ep_group(hidden_states, router_logits)
|
||||
|
||||
return self._prepare_with_dp_group(hidden_states, router_logits,
|
||||
enable_shared_expert_dp,
|
||||
replace_allreduce, gate)
|
||||
|
||||
def _prepare_with_ep_group(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True, True)
|
||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
router_logits, True, True)
|
||||
|
||||
return hidden_states, router_logits, None, None
|
||||
|
||||
def _prepare_with_dp_group(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -301,7 +352,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
Tuple of (global_hidden_states, global_router_logits, None, None)
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
if self.moe_config.dp_size > 1:
|
||||
forward_context = get_forward_context()
|
||||
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||
@@ -323,7 +373,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
else:
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
|
||||
return hidden_states, router_logits, None, None
|
||||
|
||||
def finalize(self,
|
||||
@@ -331,6 +380,36 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
Reduce Scatter hidden states.
|
||||
|
||||
Returns:
|
||||
Tensor with shape [local_num_tokens, hidden_size]
|
||||
"""
|
||||
if enable_sp():
|
||||
return self._finalize_with_ep_group(hidden_states)
|
||||
|
||||
return self._finalize_with_dp_group(hidden_states, reduce_results)
|
||||
|
||||
def _finalize_with_ep_group(self,
|
||||
hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled:
|
||||
1. Reduce_results is False usually happens when models have shared experts and need to
|
||||
allreduce hidden states after results of shared experts and routed experts are added in FusedMoe.
|
||||
We do reduce scatter for hidden states here, then skip allreudce in FusedMoe and add it to the
|
||||
result of shared experts.
|
||||
2 Reduce_results is True usually happens when model has no shared experts. We still do reduce scatter
|
||||
here, then skip allreudce in FusedMoe.
|
||||
"""
|
||||
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
hidden_states, True)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
|
||||
2. Slice to original local token count.
|
||||
|
||||
Reference in New Issue
Block a user