[1/N][Feat] Support MoE models with ACL Graph and refactor MoE communication logic (#2125)
### What this PR does / why we need it?
This PR refactors the MoE (Mixture of Experts) communication logic by
introducing a strategy pattern. It defines an abstract base class,
`MoECommMethod`, which encapsulates different communication strategies
for MoE layers. By decoupling the MoE implementation from any single
communication method, this change makes it simpler to add, replace, or
optimize communication strategies in the future.
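To make the pattern concrete, here is a minimal sketch of the interface shape. The `prepare`/`finalize` method names and signatures are illustrative assumptions, not the actual API introduced by this PR:

```python
from abc import ABC, abstractmethod

import torch


class MoECommMethod(ABC):
    """One token-movement strategy wrapped around the fused-experts kernel."""

    @abstractmethod
    def prepare(
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Dispatch tokens across ranks before the experts run."""

    @abstractmethod
    def finalize(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Combine expert outputs back into each rank's local layout."""


class AllGatherImpl(MoECommMethod):
    """Baseline: gather the full batch on every rank, reduce afterwards."""

    def prepare(self, hidden_states, router_logits):
        # A real implementation would all-gather both tensors across the
        # expert-parallel group here.
        return hidden_states, router_logits

    def finalize(self, hidden_states):
        # ... and reduce-scatter back to the local token shard here.
        return hidden_states
```

Because MoE layers only see this interface, swapping strategies requires no change to the layer code itself.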
### Plan / Roadmap
1. Introduce `MoECommMethod`, implement `AllGatherImpl`, and adapt ACL
Graph handling to cover all scenarios (this PR).
2. Implement `MC2CommImpl` and `AllToAllCommImpl` to optimize
performance in specific scenarios (a possible selection shape is
sketched after this list).
3. Enable W8A8 / Int8 models to use `unified_fused_experts`.
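Entirely hypothetical and not part of this PR: once the roadmap's second item lands, per-step strategy selection could look something like the following, continuing the sketch above. `MC2CommImpl` and `AllToAllCommImpl` are stubbed because they do not exist yet:

```python
class MC2CommImpl(AllGatherImpl):
    """Stub standing in for the planned MC2 strategy (roadmap item 2)."""


class AllToAllCommImpl(AllGatherImpl):
    """Stub standing in for the planned all-to-all strategy (roadmap item 2)."""


def select_moe_comm_method(ep_size: int, with_prefill: bool) -> MoECommMethod:
    # Hypothetical policy: without expert parallelism the all-gather
    # baseline is the only option; otherwise prefill and decode phases
    # may favor different strategies.
    if ep_size == 1:
        return AllGatherImpl()
    return AllToAllCommImpl() if with_prefill else MC2CommImpl()
```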
### Other notes
* Data-parallel (DP) communication currently does not work with vLLM's
dispatch/combine mechanisms; an alternative approach is required to
resolve this incompatibility.
- vLLM version: v0.10.0
- vLLM main: f7ad6a1eb3
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
```diff
@@ -5,11 +5,12 @@ from typing import Any, Optional
 
 import torch
 from vllm.config import VllmConfig
-from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
+from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.forward_context import get_forward_context, set_forward_context
 
 import vllm_ascend.envs as envs
-from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.distributed.moe_comm_method import MoECommMethod
 
 
 class FusedMoEState(Enum):
@@ -54,6 +55,8 @@ def set_ascend_forward_context(
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         with_prefill: bool = True,
         in_profile_run: bool = False,
+        reserved_mc2_mask: Optional[torch.Tensor] = None,
+        moe_comm_method: Optional[MoECommMethod] = None,
         num_actual_tokens: Optional[int] = None,
 ):
     """A context manager that stores the current forward context,
@@ -66,6 +69,7 @@ def set_ascend_forward_context(
                              num_tokens=num_tokens,
                              num_tokens_across_dp=num_tokens_across_dp):
         forward_context = get_forward_context()
+        forward_context.moe_comm_method = moe_comm_method
         forward_context.with_prefill = with_prefill
         ep_size = (get_ep_group().world_size if
                    vllm_config.parallel_config.enable_expert_parallel else 1)
@@ -97,16 +101,17 @@ def set_ascend_forward_context(
         if num_tokens is not None:
             if num_actual_tokens is None:
                 num_actual_tokens = num_tokens
-            tp_world_size = get_tp_group().world_size
+            tp_world_size = get_tensor_model_parallel_world_size()
             # NOTE: token num which need to pad to when mc2
             forward_context.padded_num_tokens = math.ceil(
                 max_tokens_across_dp / tp_world_size) * tp_world_size
 
-            mc2_mask = torch.zeros(forward_context.padded_num_tokens,
-                                   dtype=torch.bool,
-                                   device=NPUPlatform.device_type)
-            mc2_mask[:num_actual_tokens] = True
-            forward_context.mc2_mask = mc2_mask
+            if reserved_mc2_mask is not None:
+                mc2_mask = reserved_mc2_mask[:forward_context.
+                                             padded_num_tokens]
+                mc2_mask[:num_actual_tokens] = True
+                mc2_mask[num_actual_tokens:] = False
+                forward_context.mc2_mask = mc2_mask
 
     try:
         yield
```
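To make the padding arithmetic in the last hunk concrete, here is a small standalone sketch with made-up numbers; the 4096-entry buffer size is an arbitrary illustration, and the comment on why the buffer is pre-allocated is an inference from the PR's ACL Graph focus:

```python
import math

import torch

tp_world_size = 4          # illustrative TP degree
num_actual_tokens = 9      # real tokens this step
max_tokens_across_dp = 10  # busiest DP rank this step

# Round the batch up to a multiple of the TP world size, as the hunk
# above does before building the MC2 mask.
padded_num_tokens = math.ceil(
    max_tokens_across_dp / tp_world_size) * tp_world_size
assert padded_num_tokens == 12

# Slice the mask from a long-lived pre-allocated buffer instead of
# calling torch.zeros each step; reusing one buffer keeps the address
# stable, which is presumably what graph capture/replay requires.
reserved_mc2_mask = torch.zeros(4096, dtype=torch.bool)
mc2_mask = reserved_mc2_mask[:padded_num_tokens]
mc2_mask[:num_actual_tokens] = True
mc2_mask[num_actual_tokens:] = False  # clear entries left from prior steps
```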
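Finally, a hypothetical consumer, showing how a MoE layer could pick the strategy back up from the forward context populated above. `fused_moe_forward` and the `prepare`/`finalize` names come from the first sketch, not from this diff:

```python
import torch
from vllm.forward_context import get_forward_context


def fused_moe_forward(hidden_states: torch.Tensor,
                      router_logits: torch.Tensor) -> torch.Tensor:
    # set_ascend_forward_context stored the strategy on the context, so
    # the layer never hard-codes a communication scheme.
    method = get_forward_context().moe_comm_method
    hidden_states, router_logits = method.prepare(hidden_states,
                                                  router_logits)
    # ... run the fused experts kernel on the dispatched tokens ...
    return method.finalize(hidden_states)
```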