[v0.11.0] [Bugfix] [MoE] Fix error in DeepSeek when using allgather (#3827)
### What this PR does / why we need it? After refactoring vllm_ascend/models and FusedMoE, we are unable to pass `gate` from deepseekv2.py to `AscendFusedMoE.forward`, which will result in an error when running DeepSeek V3/R1 with allgather. Hence, this PR removes `gate`-related computations from the FusedMoE module in eager/aclgraph mode. ### Does this PR introduce _any_ user-facing change? `rm_router_logits` is deprecated in eager/aclgraph mode. ### How was this patch tested? e2e & ut Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -191,13 +191,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
|||||||
hidden_states = torch.randn(3, 8)
|
hidden_states = torch.randn(3, 8)
|
||||||
router_logits = torch.randn(3, 2)
|
router_logits = torch.randn(3, 2)
|
||||||
|
|
||||||
# Mock the gate function for rm_router_logits=False case
|
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
|
||||||
mock_gate = MagicMock()
|
|
||||||
mock_gate.return_value = (router_logits.repeat(2, 1), None)
|
|
||||||
|
|
||||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
|
||||||
router_logits,
|
|
||||||
gate=mock_gate)
|
|
||||||
|
|
||||||
# After all-gather with DP=2, should double the batch size
|
# After all-gather with DP=2, should double the batch size
|
||||||
self.assertEqual(h_out.shape[0], 12)
|
self.assertEqual(h_out.shape[0], 12)
|
||||||
@@ -258,14 +252,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
|||||||
hidden_states = torch.randn(3, 8)
|
hidden_states = torch.randn(3, 8)
|
||||||
router_logits = torch.randn(3, 2)
|
router_logits = torch.randn(3, 2)
|
||||||
|
|
||||||
# Mock gate for router logits recomputation
|
|
||||||
mock_gate = MagicMock()
|
|
||||||
mock_gate.return_value = (torch.randn(7, 2), None)
|
|
||||||
|
|
||||||
# Run prepare
|
# Run prepare
|
||||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
|
||||||
router_logits,
|
|
||||||
gate=mock_gate)
|
|
||||||
|
|
||||||
# Should be global tensor: [7, 8] and [7, 2]
|
# Should be global tensor: [7, 8] and [7, 2]
|
||||||
self.assertEqual(h_out.shape, (7, 8))
|
self.assertEqual(h_out.shape, (7, 8))
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Verify prepare was called with correct arguments
|
# Verify prepare was called with correct arguments
|
||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
hidden_states, router_logits, False, False, None)
|
hidden_states, router_logits, False, False)
|
||||||
|
|
||||||
# Test finalize method
|
# Test finalize method
|
||||||
comm_impl.finalize(h_out, reduce_results=True)
|
comm_impl.finalize(h_out, reduce_results=True)
|
||||||
@@ -108,7 +108,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Verify prepare was called with correct arguments
|
# Verify prepare was called with correct arguments
|
||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
hidden_states, router_logits, False, False, None)
|
hidden_states, router_logits, False, False)
|
||||||
|
|
||||||
# Test finalize method
|
# Test finalize method
|
||||||
comm_impl.finalize(h_out, reduce_results=True)
|
comm_impl.finalize(h_out, reduce_results=True)
|
||||||
@@ -153,7 +153,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Verify prepare was called with correct arguments
|
# Verify prepare was called with correct arguments
|
||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
hidden_states, router_logits, False, False, None)
|
hidden_states, router_logits, False, False)
|
||||||
|
|
||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||||
|
|
||||||
from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
|
from vllm_ascend.utils import enable_sp
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEPrepareAndFinalize(ABC):
|
class FusedMoEPrepareAndFinalize(ABC):
|
||||||
@@ -43,31 +43,26 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
|
|
||||||
def __init__(self, moe_config: FusedMoEConfig):
|
def __init__(self, moe_config: FusedMoEConfig):
|
||||||
self.moe_config = moe_config
|
self.moe_config = moe_config
|
||||||
is_deepseek_v3_r1 = self.moe_config.original_num_experts == 256
|
|
||||||
self.rm_router_logits = get_rm_router_logits_state(
|
|
||||||
self.moe_config.ep_size, self.moe_config.dp_size,
|
|
||||||
is_deepseek_v3_r1)
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def prepare(self,
|
def prepare(
|
||||||
hidden_states: torch.Tensor,
|
self,
|
||||||
router_logits: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
router_logits: torch.Tensor,
|
||||||
replace_allreduce: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
replace_allreduce: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Prepare tensors before MoE computation. May involve:
|
Prepare tensors before MoE computation. May involve:
|
||||||
- Padding to align communication boundaries
|
- Padding to align communication boundaries
|
||||||
- Slicing across tensor-parallel ranks
|
- Slicing across tensor-parallel ranks
|
||||||
- Broadcasting across data-parallel ranks
|
- Broadcasting across data-parallel ranks
|
||||||
- Recomputing router logits if needed
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
|
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
|
||||||
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
||||||
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
||||||
replace_allreduce (bool): Bypass default all-reduce behavior
|
replace_allreduce (bool): Bypass default all-reduce behavior
|
||||||
gate (nn.Module, optional): Gate network to recompute router_logits if needed
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of:
|
Tuple of:
|
||||||
@@ -116,12 +111,13 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
def prepare(self,
|
def prepare(
|
||||||
hidden_states: torch.Tensor,
|
self,
|
||||||
router_logits: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
router_logits: torch.Tensor,
|
||||||
replace_allreduce: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
replace_allreduce: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Preparation steps:
|
Preparation steps:
|
||||||
1. Fetch `mc2_mask` and target padding length from forward context.
|
1. Fetch `mc2_mask` and target padding length from forward context.
|
||||||
@@ -214,12 +210,13 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
def prepare(self,
|
def prepare(
|
||||||
hidden_states: torch.Tensor,
|
self,
|
||||||
router_logits: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
router_logits: torch.Tensor,
|
||||||
replace_allreduce: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
replace_allreduce: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Preparation steps:
|
Preparation steps:
|
||||||
1. Pad hidden_states and router_logits to next multiple of TP size.
|
1. Pad hidden_states and router_logits to next multiple of TP size.
|
||||||
@@ -307,12 +304,13 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
|||||||
TP AG → Attn → TP RS → EP AG → MoE → EP RS
|
TP AG → Attn → TP RS → EP AG → MoE → EP RS
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def prepare(self,
|
def prepare(
|
||||||
hidden_states: torch.Tensor,
|
self,
|
||||||
router_logits: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
router_logits: torch.Tensor,
|
||||||
replace_allreduce: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
replace_allreduce: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Preparation steps:
|
Preparation steps:
|
||||||
AllGather hidden_states and router_logits to form global tensors.
|
AllGather hidden_states and router_logits to form global tensors.
|
||||||
@@ -325,7 +323,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
return self._prepare_with_dp_group(hidden_states, router_logits,
|
return self._prepare_with_dp_group(hidden_states, router_logits,
|
||||||
enable_shared_expert_dp,
|
enable_shared_expert_dp,
|
||||||
replace_allreduce, gate)
|
replace_allreduce)
|
||||||
|
|
||||||
def _prepare_with_ep_group(
|
def _prepare_with_ep_group(
|
||||||
self,
|
self,
|
||||||
@@ -340,12 +338,12 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
|||||||
return hidden_states, router_logits, None
|
return hidden_states, router_logits, None
|
||||||
|
|
||||||
def _prepare_with_dp_group(
|
def _prepare_with_dp_group(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Preparation steps:
|
Preparation steps:
|
||||||
1. Fetch max token count across DP group from forward context.
|
1. Fetch max token count across DP group from forward context.
|
||||||
@@ -365,18 +363,14 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
|||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
hidden_states = nn.functional.pad(hidden_states,
|
hidden_states = nn.functional.pad(hidden_states,
|
||||||
(0, 0, 0, pad_size))
|
(0, 0, 0, pad_size))
|
||||||
if not self.rm_router_logits:
|
router_logits = nn.functional.pad(router_logits,
|
||||||
router_logits = nn.functional.pad(router_logits,
|
(0, 0, 0, pad_size))
|
||||||
(0, 0, 0, pad_size))
|
|
||||||
|
|
||||||
# All-gather across DP group
|
# All-gather across DP group
|
||||||
hidden_states = self.moe_config.dp_group.all_gather(
|
hidden_states = self.moe_config.dp_group.all_gather(
|
||||||
hidden_states, 0)
|
hidden_states, 0)
|
||||||
if self.rm_router_logits:
|
router_logits = self.moe_config.dp_group.all_gather(
|
||||||
router_logits, _ = gate(hidden_states) # Recompute globally
|
router_logits, 0)
|
||||||
else:
|
|
||||||
router_logits = self.moe_config.dp_group.all_gather(
|
|
||||||
router_logits, 0)
|
|
||||||
|
|
||||||
return hidden_states, router_logits, None
|
return hidden_states, router_logits, None
|
||||||
|
|
||||||
@@ -472,12 +466,13 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
|||||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
get_dp_group().broadcast(buffer[start:end, :], idx)
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
def prepare(self,
|
def prepare(
|
||||||
hidden_states: torch.Tensor,
|
self,
|
||||||
router_logits: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
router_logits: torch.Tensor,
|
||||||
replace_allreduce: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
replace_allreduce: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Preparation steps:
|
Preparation steps:
|
||||||
1. Fetch cumulative token boundaries from forward context.
|
1. Fetch cumulative token boundaries from forward context.
|
||||||
@@ -493,11 +488,8 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
|||||||
).dp_metadata.cu_tokens_across_sp(1)
|
).dp_metadata.cu_tokens_across_sp(1)
|
||||||
hidden_states = self._naive_multicast(hidden_states,
|
hidden_states = self._naive_multicast(hidden_states,
|
||||||
self.cu_tokens_across_dp_cpu)
|
self.cu_tokens_across_dp_cpu)
|
||||||
if self.rm_router_logits:
|
router_logits = self._naive_multicast(router_logits,
|
||||||
router_logits, _ = gate(hidden_states)
|
self.cu_tokens_across_dp_cpu)
|
||||||
else:
|
|
||||||
router_logits = self._naive_multicast(
|
|
||||||
router_logits, self.cu_tokens_across_dp_cpu)
|
|
||||||
|
|
||||||
return hidden_states, router_logits, None
|
return hidden_states, router_logits, None
|
||||||
|
|
||||||
|
|||||||
@@ -63,15 +63,16 @@ class MoECommMethod(ABC):
|
|||||||
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
|
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare(self,
|
def prepare(
|
||||||
hidden_states: torch.Tensor,
|
self,
|
||||||
router_logits: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
router_logits: torch.Tensor,
|
||||||
replace_allreduce: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor]:
|
replace_allreduce: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
|
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
|
||||||
hidden_states, router_logits, enable_shared_expert_dp,
|
hidden_states, router_logits, enable_shared_expert_dp,
|
||||||
replace_allreduce, gate)
|
replace_allreduce)
|
||||||
self.mc2_mask = mc2_mask
|
self.mc2_mask = mc2_mask
|
||||||
return hidden_states, router_logits
|
return hidden_states, router_logits
|
||||||
|
|
||||||
|
|||||||
@@ -48,12 +48,12 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
|||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
||||||
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
|
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
|
||||||
from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
|
from vllm_ascend.torchair.utils import (get_all_reduce_merge_state,
|
||||||
|
get_rm_router_logits_state,
|
||||||
|
npu_stream_switch, npu_wait_tensor,
|
||||||
super_kernel)
|
super_kernel)
|
||||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||||
get_all_reduce_merge_state,
|
get_ascend_soc_version, is_310p,
|
||||||
get_ascend_soc_version,
|
|
||||||
get_rm_router_logits_state, is_310p,
|
|
||||||
is_hierarchical_communication_enabled)
|
is_hierarchical_communication_enabled)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
|
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
|
||||||
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
|
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||||
|
|
||||||
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
|
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
|
||||||
@@ -241,3 +243,33 @@ def torchair_ops_patch():
|
|||||||
|
|
||||||
def super_kernel(prefix: str, option: str, enabled: bool = True):
|
def super_kernel(prefix: str, option: str, enabled: bool = True):
|
||||||
return _super_kernel(prefix, option) if enabled else nullcontext()
|
return _super_kernel(prefix, option) if enabled else nullcontext()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(ttanzhiqiang): rm_router_logits
|
||||||
|
# dp>1 will trigger
|
||||||
|
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
|
||||||
|
def get_rm_router_logits_state(ep_size: int, dp_size: int,
|
||||||
|
is_deepseek_v3_r1: bool):
|
||||||
|
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||||
|
# only supports deepseek v3/r1
|
||||||
|
if dp_size > 1:
|
||||||
|
if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||||
|
and is_deepseek_v3_r1):
|
||||||
|
return True
|
||||||
|
elif ep_size == 1 and is_deepseek_v3_r1:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(ttanzhiqiang): all_reduce merge
|
||||||
|
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||||
|
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
|
||||||
|
def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
|
||||||
|
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||||
|
# only supports deepseek v3/r1
|
||||||
|
if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||||
|
and is_deepseek_v3_r1):
|
||||||
|
return True
|
||||||
|
elif ep_size == 1 and is_deepseek_v3_r1:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|||||||
@@ -520,36 +520,6 @@ class ProfileExecuteDuration:
|
|||||||
return durations
|
return durations
|
||||||
|
|
||||||
|
|
||||||
# TODO(ttanzhiqiang): rm_router_logits
|
|
||||||
# dp>1 will trigger
|
|
||||||
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
|
|
||||||
def get_rm_router_logits_state(ep_size: int, dp_size: int,
|
|
||||||
is_deepseek_v3_r1: bool):
|
|
||||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
|
||||||
# only supports deepseek v3/r1
|
|
||||||
if dp_size > 1:
|
|
||||||
if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
|
||||||
and is_deepseek_v3_r1):
|
|
||||||
return True
|
|
||||||
elif ep_size == 1 and is_deepseek_v3_r1:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(ttanzhiqiang): all_reduce merge
|
|
||||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
|
||||||
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
|
|
||||||
def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
|
|
||||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
|
||||||
# only supports deepseek v3/r1
|
|
||||||
if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
|
||||||
and is_deepseek_v3_r1):
|
|
||||||
return True
|
|
||||||
elif ep_size == 1 and is_deepseek_v3_r1:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||||
"""Register Ascend CustomOP
|
"""Register Ascend CustomOP
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user