[2/N][refactor] split torchair from fused_moe (#2503)
### What this PR does / why we need it?
After moving the torchair-related fused_moe code into torchair_fused_moe, split the torchair path out of the original fused_moe.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- vLLM version: main
- vLLM main: ab9f2cfd19
- vLLM version: v0.10.1.1
- vLLM main: 2a97ffc33d

Signed-off-by: hust17yixuan <303660421@qq.com>
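The shape of the split, as a minimal illustrative sketch (plain torch, simplified names and a stand-in expert computation, not the vllm-ascend API): the base fused-experts path drops its `is_torchair` flag and inline branches, while the torchair variant keeps its own copy under the torchair package.

```python
# Illustrative sketch only: names and the expert computation are simplified
# stand-ins, not the vllm-ascend implementation.
import torch


def fused_experts_before(hidden_states: torch.Tensor,
                         is_torchair: bool = False) -> torch.Tensor:
    """Old shape: torchair handling is an inline, flag-driven branch."""
    if is_torchair:
        # torchair graph mode needed extra dispatch/combine handling here
        pass
    return hidden_states * 2.0  # stand-in for the routed-expert computation


def fused_experts_after(hidden_states: torch.Tensor) -> torch.Tensor:
    """New shape: the base path carries no torchair knowledge at all."""
    return hidden_states * 2.0


if __name__ == "__main__":
    x = torch.ones(4, 8)
    assert torch.equal(fused_experts_before(x), fused_experts_after(x))
```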
@@ -50,7 +50,6 @@ from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
                                get_ascend_soc_version,
@@ -76,8 +75,6 @@ def unified_fused_experts(
     w1_scale_bias: torch.Tensor = None,
     w2_scale_bias: torch.Tensor = None,
     moe_comm_method: Optional[MoECommMethod] = None,
-    # For TorchAir graph
-    is_torchair: bool = False,
     # For Cube/Vector parallel
     shared_experts: Optional[Any] = None,
     quantized_x_for_share: Optional[Any] = None,
@@ -191,7 +188,6 @@ def fused_experts_with_mc2(
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None,
-    is_torchair: bool = False,
     mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     quant_mode = 0
@@ -199,8 +195,7 @@ def fused_experts_with_mc2(
     ep_world_size = moe_parallel_config.ep_size

     # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
-    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
-                       or is_torchair)
+    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3)

     # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
     a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
@@ -246,11 +241,8 @@ def fused_experts_with_mc2(
         0:5]

     if shared_experts is not None:
-        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(hidden_states, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-            npu_wait_tensor(shared_gate_up, expand_x)
-            shared_act = shared_experts.act_fn(shared_gate_up)
+        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+        shared_act = shared_experts.act_fn(shared_gate_up)

     w1 = w1.transpose(1, 2)

@@ -324,9 +316,7 @@ def fused_experts_with_mc2(
     if shared_experts is None:
         return hidden_states
     else:
-        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_act, down_out_list)
-            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+        shared_hidden_states, _ = shared_experts.down_proj(shared_act)
         return hidden_states, shared_hidden_states

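For reference, the multistream shared-expert overlap that the base path no longer carries (it stays with the torchair implementation) can be reassembled from the lines removed above. The wrapper function below is hypothetical glue; `npu_stream_switch` and `npu_wait_tensor` are the helpers the old code imported from `vllm_ascend.torchair.utils`, used exactly as in the removed lines:

```python
# Reassembled from the removed lines above; the wrapper function itself is
# hypothetical. Shared-expert GEMMs are issued on a secondary NPU stream so
# they overlap with the routed experts' MC2 dispatch/combine on the main
# stream; the npu_wait_tensor calls order the two streams.
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor


def overlapped_shared_experts(shared_experts, hidden_states, topk_weights,
                              expand_x, down_out_list):
    with npu_stream_switch("moe_secondary", 0):
        # Wait on topk_weights before the secondary stream reads hidden_states.
        npu_wait_tensor(hidden_states, topk_weights)
        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
        # Wait on the MC2 dispatch output before running the activation.
        npu_wait_tensor(shared_gate_up, expand_x)
        shared_act = shared_experts.act_fn(shared_gate_up)
        # Defer the down projection until the routed experts have finished.
        npu_wait_tensor(shared_act, down_out_list)
        shared_hidden_states, _ = shared_experts.down_proj(shared_act)
    return shared_hidden_states
```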
@@ -930,9 +920,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.max_model_len = vllm_config.model_config.max_model_len
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        get_ascend_config()

         try:
             device_group = get_mc2_group().device_group
@@ -1169,10 +1157,6 @@ class AscendFusedMoE(FusedMoE):
             self.ep_size,
             get_ep_group().rank_in_group, self.global_num_experts)

-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and \
-            self.torchair_graph_enabled
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
@@ -1278,23 +1262,10 @@ class AscendFusedMoE(FusedMoE):
         mc2_mask = forward_context.mc2_mask
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
         quantized_x_for_share, dynamic_scale_for_share = None, None
-        from vllm_ascend.quantization.w8a8_dynamic import \
-            AscendW8A8DynamicFusedMoEMethod
-        if self.enable_multistream_moe:
-            if not self.rm_router_logits:
-                router_logits, _ = gate(hidden_states)
-            if hasattr(self.quant_method, "quant_method") and \
-                isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod
-                           ) and fused_moe_state == FusedMoEState.MC2:
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
-                        hidden_states)

         if shared_experts:
-            if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
-                # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
-                shared_hidden_states = shared_experts(hidden_states)
+            # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
+            shared_hidden_states = shared_experts(hidden_states)

         mc2_mask = forward_context.mc2_mask

@@ -1339,16 +1310,15 @@ class AscendFusedMoE(FusedMoE):
         if self.dp_size > 1:
             if fused_moe_state == FusedMoEState.AllGather:
-                # NOTE: When in torchair graph, it has been padded in model_runner_v1
-                if not self.torchair_graph_enabled:
-                    max_tokens_across_dp = forward_context.max_tokens_across_dp
-                    if num_tokens < max_tokens_across_dp:
-                        hidden_states = nn.functional.pad(
-                            hidden_states,
-                            (0, 0, 0, max_tokens_across_dp - num_tokens))
-                        if not self.rm_router_logits:
-                            router_logits = nn.functional.pad(
-                                router_logits,
-                                (0, 0, 0, max_tokens_across_dp - num_tokens))
+                max_tokens_across_dp = forward_context.max_tokens_across_dp
+                if num_tokens < max_tokens_across_dp:
+                    hidden_states = nn.functional.pad(
+                        hidden_states,
+                        (0, 0, 0, max_tokens_across_dp - num_tokens))
+                    if not self.rm_router_logits:
+                        router_logits = nn.functional.pad(
+                            router_logits,
+                            (0, 0, 0, max_tokens_across_dp - num_tokens))
                 hidden_states = get_dp_group().all_gather(hidden_states, 0)
                 if self.rm_router_logits:
                     router_logits, _ = gate(hidden_states)
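The padding this hunk dedents exists because, in the AllGather path, every data-parallel rank must contribute an equally sized tensor: ranks with fewer tokens pad up to `max_tokens_across_dp` first (in torchair graph mode the padding had already happened in model_runner_v1, which is what the removed guard skipped). A minimal sketch of that padding step, with assumed shapes and a hypothetical helper name:

```python
# Minimal sketch with assumed (num_tokens, hidden_size) shapes and a
# hypothetical helper name; no real data-parallel group is involved.
import torch
import torch.nn.functional as F


def pad_for_all_gather(hidden_states: torch.Tensor,
                       max_tokens_across_dp: int) -> torch.Tensor:
    num_tokens = hidden_states.shape[0]
    if num_tokens < max_tokens_across_dp:
        # (0, 0) leaves the hidden dim untouched; the second pair pads the
        # token dim at the end, matching the nn.functional.pad call above.
        hidden_states = F.pad(hidden_states,
                              (0, 0, 0, max_tokens_across_dp - num_tokens))
    return hidden_states


x = torch.randn(3, 16)  # this rank produced 3 tokens
padded = pad_for_all_gather(x, max_tokens_across_dp=5)
assert padded.shape == (5, 16)
```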
@@ -1385,8 +1355,7 @@ class AscendFusedMoE(FusedMoE):
                 enable_force_load_balance=enable_force_load_balance,
                 log2phy=self.log2phy,
                 global_redundant_expert_num=self.global_redundant_expert_num,
-                shared_experts=shared_experts if self.torchair_graph_enabled
-                and self.enable_multistream_moe and not is_prefill else None,
+                shared_experts=None,
                 mc2_mask=mc2_mask,
                 token_dispatcher=self.token_dispatcher,
                 quantized_x_for_share=quantized_x_for_share,