[Refactor][MoE] Reuse vLLM's all_reduce logic (#5189)
### What this PR does / why we need it?
Move all_reduce logic to AscendFusedMoE.forward, reuse vLLM's logic.
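For context, the vLLM primitive being reused is `tensor_model_parallel_all_reduce` from `vllm.distributed`. Below is a minimal sketch of the pattern that was previously duplicated in each `PrepareAndFinalize` subclass; the free-standing `finalize` function is illustrative only, with names taken from the diff below:

```python
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import \
    get_tensor_model_parallel_world_size


def finalize(hidden_states: torch.Tensor,
             reduce_results: bool) -> torch.Tensor:
    # Previously each PrepareAndFinalize subclass carried its own copy of
    # this reduction; after this PR it runs once in AscendFusedMoE.forward.
    if reduce_results and get_tensor_model_parallel_world_size() > 1:
        hidden_states = tensor_model_parallel_all_reduce(hidden_states)
    return hidden_states
```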
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
```diff
@@ -22,7 +22,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch_npu
-from vllm.distributed import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_dp_group, get_pcp_group, get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size)
@@ -470,8 +469,4 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
             hidden_states = get_pcp_group().reduce_scatter(hidden_states,
                                                            dim=0)
-        if reduce_results and (self.moe_config.tp_size > 1
-                               or self.moe_config.ep_size > 1):
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
         return hidden_states
```
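With the block above removed, the reduction happens once in `AscendFusedMoE.forward`, per the PR summary. A minimal sketch of that destination follows, under the assumption that the forward path ends with a single shared reduce; the routing/dispatch steps are elided and the class body is illustrative, not the actual implementation:

```python
import torch
from vllm.distributed import tensor_model_parallel_all_reduce


class AscendFusedMoE:  # illustrative; the real class extends vLLM's FusedMoE
    def __init__(self, tp_size: int, ep_size: int, reduce_results: bool):
        self.tp_size = tp_size
        self.ep_size = ep_size
        self.reduce_results = reduce_results

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # ... expert routing, dispatch, and combine elided ...
        # Single call site for the all-reduce, reusing vLLM's primitive
        # instead of the per-subclass copies removed above.
        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
        return hidden_states
```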