[Refactor][MoE] Reuse vLLM's all_reduce logic (#5189)
### What this PR does / why we need it?
Move the all_reduce logic to AscendFusedMoE.forward and reuse vLLM's logic.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
End-to-end (e2e) tests and unit tests (ut).
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -169,15 +169,12 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
self.assertEqual(final_result.shape[0], 2)
|
self.assertEqual(final_result.shape[0], 2)
|
||||||
|
|
||||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
|
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
|
||||||
@patch(
|
|
||||||
"vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
|
|
||||||
)
|
|
||||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
|
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
|
||||||
return_value=False)
|
return_value=False)
|
||||||
def test_allgather_prepare_finalize(self, mock_enable_sp,
|
def test_allgather_prepare_finalize(self, mock_enable_sp,
|
||||||
mock_get_forward_context,
|
mock_get_forward_context,
|
||||||
mock_tp_all_reduce, mock_get_dp_group):
|
mock_get_dp_group):
|
||||||
# Mock forward context
|
# Mock forward context
|
||||||
mock_context = MagicMock()
|
mock_context = MagicMock()
|
||||||
mock_context.max_tokens_across_dp = 6
|
mock_context.max_tokens_across_dp = 6
|
||||||
@@ -222,7 +219,5 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(result.shape[0], 3)
|
self.assertEqual(result.shape[0], 3)
|
||||||
|
|
||||||
# Test with TP all-reduce
|
|
||||||
mock_tp_all_reduce.return_value = result
|
|
||||||
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
||||||
self.assertEqual(result_with_tp.shape[0], 3)
|
self.assertEqual(result_with_tp.shape[0], 3)
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import os.path
|
|||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
@@ -292,32 +291,6 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
|
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
|
||||||
final_hidden_states)
|
final_hidden_states)
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
og_hidden_states = hidden_states.shape[-1]
|
|
||||||
if self.hidden_size != og_hidden_states:
|
|
||||||
hidden_states = F.pad(
|
|
||||||
hidden_states,
|
|
||||||
(0, self.hidden_size - og_hidden_states),
|
|
||||||
mode="constant",
|
|
||||||
value=0.0,
|
|
||||||
)
|
|
||||||
if self.shared_experts is None:
|
|
||||||
fused_output = torch.ops.vllm.moe_forward(hidden_states,
|
|
||||||
router_logits,
|
|
||||||
self.layer_name)
|
|
||||||
return fused_output[..., :og_hidden_states]
|
|
||||||
else:
|
|
||||||
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
|
|
||||||
hidden_states, router_logits, self.layer_name)
|
|
||||||
return (
|
|
||||||
shared_output[..., :og_hidden_states],
|
|
||||||
fused_output[..., :og_hidden_states],
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_impl(self, hidden_states: torch.Tensor,
|
def forward_impl(self, hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor):
|
router_logits: torch.Tensor):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_dp_group, get_pcp_group, get_tensor_model_parallel_rank,
|
get_dp_group, get_pcp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
@@ -470,8 +469,4 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
|||||||
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
|
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
|
||||||
hidden_states = get_pcp_group().reduce_scatter(hidden_states,
|
hidden_states = get_pcp_group().reduce_scatter(hidden_states,
|
||||||
dim=0)
|
dim=0)
|
||||||
if reduce_results and (self.moe_config.tp_size > 1
|
|
||||||
or self.moe_config.ep_size > 1):
|
|
||||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
Reference in New Issue
Block a user