[Feature] Support fine-grained shared expert overlap (#5482)

Fine-grained control over shared expert overlap to prevent resource
contention.

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
Jade Zheng
2026-01-17 11:53:22 +08:00
committed by GitHub
parent 48e10de8c9
commit 22f253142a
9 changed files with 203 additions and 130 deletions

View File

@@ -163,6 +163,7 @@ class TestMoECommMethod(TestBase):
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather" "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
) )
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp") @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
@patch("torch.npu.current_stream", MagicMock())
def test_fused_experts_method(self, mock_unified_apply_mlp, def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize, mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context): mock_get_forward_context):

View File

@@ -116,26 +116,6 @@ class TestTokenDispatcherWithMC2(TestBase):
mock_dispatch.assert_called_once() mock_dispatch.assert_called_once()
self.assertEqual(output.group_list_type, 0) # group_list_type == 0 self.assertEqual(output.group_list_type, 0) # group_list_type == 0
def test_token_dispatch_with_shared_experts_and_quant(self):
    """Smoke-test ``token_dispatch`` when ``shared_experts`` is supplied.

    The NPU dispatch kernel is patched out, so this only exercises the
    Python-side call path. Nothing about the result is asserted; the test
    passes as long as the call does not raise.
    """
    # NOTE(review): the test name says "and_quant" but ``with_quant`` is
    # set to False below -- confirm which configuration this is meant to
    # cover.
    # Stand-in for the shared-experts module with canned projection outputs.
    self.shared_experts = MagicMock()
    self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
                                                     torch.tensor(1.0))
    self.shared_experts.act_fn.return_value = torch.randn(10, 128)

    # Dispatcher state set up for the call under test.
    self.dispatcher.with_quant = False
    self.dispatcher.shared_act = torch.randn(10, 128)
    self.dispatcher.swiglu_out_scale = torch.tensor(1.0)

    self.hidden_states = torch.randn(10, 128)
    self.topk_weights = torch.randn(10, 1)

    # The patched kernel returns a 7-tuple: five tensors followed by two
    # ``None`` placeholders.
    with patch("torch_npu.npu_moe_distribute_dispatch_v2",
               return_value=(torch.randn(10, 128), ) * 5 + (None, None)):
        self.dispatcher.token_dispatch(self.hidden_states,
                                       self.topk_weights,
                                       torch.randint(0, 8, (10, 1)),
                                       torch.tensor(
                                           [0, 1, 2, 3, 4, 5, 6, 7]),
                                       shared_experts=self.shared_experts)
def test_get_combine_mc_kwargs_with_quant(self): def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True self.dispatcher.with_quant = True
hidden_states = torch.randn(10, 128) hidden_states = torch.randn(10, 128)

View File

@@ -14,9 +14,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from typing import Any, Callable, Optional from dataclasses import dataclass, field
from functools import wraps
from typing import Callable, Optional
import torch import torch
import torch.nn.functional as F
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
@@ -47,7 +50,20 @@ from vllm_ascend.quantization.w8a8_dynamic import \
from vllm_ascend.utils import (AscendDeviceType, enable_sp, from vllm_ascend.utils import (AscendDeviceType, enable_sp,
get_ascend_device_type, maybe_trans_nz, get_ascend_device_type, maybe_trans_nz,
npu_stream_switch, shared_expert_dp_enabled, npu_stream_switch, shared_expert_dp_enabled,
shared_experts_calculation_stream) shared_experts_calculation_stream, vllm_version_is)
@dataclass
class FusedMoEResult:
    """Routed-expert output bundled with the events recorded around it.

    Returned by ``AscendFusedMoE.forward_impl`` when it is called with
    ``return_with_event=True``.
    """

    # Output of the routed experts.
    routed_out: torch.Tensor
    # Event copied from the fused-experts result's ``before_dispatch_evt``;
    # ``None`` when no such event is available.
    before_dispatch_evt: torch.npu.Event | None = None
    # Event copied from the fused-experts result's ``before_combine_evt``;
    # ``None`` when no such event is available.
    before_combine_evt: torch.npu.Event | None = None
@dataclass
class FusedMoEEvents:
    """Stream events the shared-expert path synchronises against.

    An instance is built in ``AscendSharedFusedMoE.forward_impl`` and handed
    to ``_forward_shared_experts``.
    """

    # Recorded on the current stream just before the routed experts are
    # launched; always present.
    before_routed_experts: torch.npu.Event
    # Optional event taken from the routed-experts result; ``None`` means
    # there is nothing to wait on for the dispatch phase.
    before_dispatch: torch.npu.Event | None = None
    # Optional event taken from the routed-experts result; ``None`` means
    # there is nothing to wait on for the combine phase.
    before_combine: torch.npu.Event | None = None
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -90,7 +106,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
**kwargs) -> torch.Tensor: **kwargs) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None) zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -137,7 +152,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
topk_ids=topk_ids, topk_ids=topk_ids,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
shared_experts=shared_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb, dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None)) mc2_mask=kwargs.get("mc2_mask", None))
@@ -268,13 +282,13 @@ class AscendFusedMoE(FusedMoE):
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel( return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
final_hidden_states) final_hidden_states)
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl( # type: ignore[override]
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
return_with_event: bool = False) -> torch.Tensor | FusedMoEResult:
assert self.quant_method is not None assert self.quant_method is not None
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
forward_context = get_forward_context() forward_context = get_forward_context()
# Load balancing for token distribution among experts in dummy_run # Load balancing for token distribution among experts in dummy_run
@@ -359,9 +373,6 @@ class AscendFusedMoE(FusedMoE):
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation, activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
shared_experts=None,
enable_force_load_balance=enable_force_load_balance, enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy, log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num, global_redundant_expert_num=self.global_redundant_expert_num,
@@ -380,6 +391,13 @@ class AscendFusedMoE(FusedMoE):
reduce_results=self.reduce_results, reduce_results=self.reduce_results,
context_metadata=context_metadata) context_metadata=context_metadata)
if return_with_event:
return FusedMoEResult(
routed_out=routed_out,
before_dispatch_evt=fused_experts_results.before_dispatch_evt,
before_combine_evt=fused_experts_results.before_combine_evt)
else:
# The vLLM FusedMoE forward_impl does not return events.
return routed_out return routed_out
@@ -407,6 +425,74 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
self._gate = gate self._gate = gate
# Wrap the quant_method's process_weights_after_loading to validate that
# splitting shared expert computation (gate_up projection + activation,
# then down projection) yields identical results to integrated
# computation after weight loading.
original_process_weights = self.quant_method.process_weights_after_loading
@wraps(original_process_weights)
def wrapped_process_weights(*args, **kwargs):
result = original_process_weights(*args, **kwargs)
self._validate_shared_expert_consistency()
return result
self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore
def _shared_experts_part1(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Run the first half of the shared experts: the gate/up projection.

    Args:
        hidden_states: Input passed to ``self._shared_experts.gate_up_proj``.

    Returns:
        The first element of the tuple returned by ``gate_up_proj``; the
        second element is discarded.
    """
    shared_gate_up, _ = self._shared_experts.gate_up_proj(
        hidden_states)  # type: ignore
    return shared_gate_up
def _shared_experts_part2(self, hidden_states: torch.Tensor,
                          shared_gate_up: torch.Tensor):
    """Run the second half of the shared experts: activation + down proj.

    Args:
        hidden_states: The original shared-expert input; only used when the
            shared experts carry an ``expert_gate``.
        shared_gate_up: Output of ``_shared_experts_part1``.

    Returns:
        The down-projection output, multiplied by ``sigmoid(expert_gate(
        hidden_states))`` when an expert gate is present.
    """
    shared_act = self._shared_experts.act_fn(
        shared_gate_up)  # type: ignore
    shared_out, _ = self._shared_experts.down_proj(
        shared_act)  # type: ignore

    # Qwen3-Next specific gating mechanism
    expert_gate = getattr(self._shared_experts, "expert_gate", None)
    if expert_gate is not None:
        if vllm_version_is('0.13.0'):
            # TODO(jianzs): remove this branch after vLLM new version is
            # released
            gate_out = expert_gate(hidden_states)  # type: ignore
        else:
            # Other vLLM versions return a tuple from the gate; only the
            # first element is used.
            gate_out, _ = expert_gate(hidden_states)  # type: ignore
        shared_out = F.sigmoid(gate_out) * shared_out
    return shared_out
def _validate_shared_expert_consistency(self):
    """Validate that split shared expert computation matches integrated
    computation.

    A random probe batch is run once through ``self._shared_experts``
    directly and once through ``_shared_experts_part1`` followed by
    ``_shared_experts_part2``; the two results are compared with
    ``torch.allclose`` (default tolerances).

    Raises:
        ValueError: If the two outputs are not close. Diagnostic details
            are logged at error level before raising.
    """
    # Random input for testing, scaled from [0, 1) to [-1, 1).
    test_input = torch.rand(
        10, self.hidden_size, device='npu', dtype=self.moe_config.in_dtype
    ) * 2 - 1

    integrated_out = self._shared_experts(test_input)
    part1_out = self._shared_experts_part1(test_input)
    split_out = self._shared_experts_part2(test_input, part1_out)

    if not torch.allclose(integrated_out, split_out):
        diff = (integrated_out - split_out).abs()
        logger.error(
            "SharedFusedMoE shared experts split computation does not "
            "match the integrated computation.")
        # Lazy %-style arguments, consistent with the log calls below.
        logger.error("Max absolute difference: %s", diff.max().item())
        logger.error("Integrated output - sum: %s, norm: %s",
                     integrated_out.sum().item(),
                     integrated_out.norm().item())
        logger.error("Split output - sum: %s, norm: %s",
                     split_out.sum().item(),
                     split_out.norm().item())
        raise ValueError(
            "SharedFusedMoE shared experts split computation does not "
            "match the integrated computation.")

    logger.info_once(
        "SharedFusedMoE shared experts split computation matches the "
        "integrated computation.")
@property @property
def gate(self) -> Optional[torch.nn.Module]: def gate(self) -> Optional[torch.nn.Module]:
return self._gate if self.use_overlapped else None return self._gate if self.use_overlapped else None
@@ -434,44 +520,67 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
) )
return shared_out, fused_out return shared_out, fused_out
def forward_impl(self, hidden_states: torch.Tensor, def _forward_shared_experts(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): fused_moe_evts: FusedMoEEvents):
shared_out = None
if not self.multistream_overlap_gate: def maybe_wait_event(evt: torch.npu.Event | None):
# Make sure the shared experts stream begins after hidden_states are ready. if evt is not None:
if self.multistream_overlap_shared_expert: torch.npu.current_stream().wait_event(evt)
shared_experts_calculation_stream(
).wait_stream( # type: ignore with npu_stream_switch(shared_experts_calculation_stream(),
torch.npu.current_stream())
with npu_stream_switch(
shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert): enabled=self.multistream_overlap_shared_expert):
# Use a separate stream to run shared experts. # Ensure the shared experts wait for hidden_states to be ready.
shared_out = self._shared_experts(hidden_states) torch.npu.current_stream().wait_event(
else: fused_moe_evts.before_routed_experts)
set_flash_common3_context(shared_experts=self._shared_experts) # Execute the gate projection and activation concurrently with the
# dispatch communication.
maybe_wait_event(fused_moe_evts.before_dispatch)
part1_out = self._shared_experts_part1(hidden_states)
# Execute the down projection concurrently with the combine
# communication.
maybe_wait_event(fused_moe_evts.before_combine)
shared_out = self._shared_experts_part2(hidden_states, part1_out)
routed_out = AscendFusedMoE.forward_impl( # Make sure the default stream waits for the shared experts stream to
self, # finish.
hidden_states=hidden_states,
router_logits=router_logits,
)
if not self.multistream_overlap_gate:
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert: if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream( torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream()) shared_experts_calculation_stream())
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` # NOTE: This is exactly the opposite of
# `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context() forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \ if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled(): and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out) shared_out = tensor_model_parallel_all_reduce(shared_out)
else: return shared_out
def forward_impl(  # type: ignore[override]
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
    """Run the routed and shared experts for one forward pass.

    Args:
        hidden_states: Input activations for the MoE layer.
        router_logits: Router scores forwarded to the routed experts.

    Returns:
        A ``(shared_out, routed_out)`` tuple.
    """
    if self.multistream_overlap_gate:
        # In this mode the shared experts are registered in the
        # flash-common3 context and their output is read back from it
        # below instead of being computed here.
        set_flash_common3_context(shared_experts=self._shared_experts)

    # Recorded before the routed experts are launched so the shared-expert
    # path can wait until ``hidden_states`` is ready.
    before_routed_experts = torch.npu.current_stream().record_event()
    fused_moe_results = AscendFusedMoE.forward_impl(
        self,
        hidden_states=hidden_states,
        router_logits=router_logits,
        return_with_event=True,
    )
    routed_out = fused_moe_results.routed_out

    if self.multistream_overlap_gate:
        fc3_context = get_flash_common3_context()
        assert fc3_context is not None
        shared_out = fc3_context.shared_out
    else:
        # Hand the recorded events to the shared-expert path so it can
        # synchronise with the dispatch and combine phases.
        shared_out = self._forward_shared_experts(
            hidden_states,
            FusedMoEEvents(
                before_routed_experts=before_routed_experts,
                before_dispatch=fused_moe_results.before_dispatch_evt,
                before_combine=fused_moe_results.before_combine_evt,
            ))

    return shared_out, routed_out

View File

@@ -17,7 +17,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional from typing import Dict, Optional
import torch import torch
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
@@ -51,6 +51,11 @@ def setup_moe_comm_method(moe_config):
@dataclass @dataclass
class FusedExpertsResult: class FusedExpertsResult:
routed_out: torch.Tensor routed_out: torch.Tensor
# This field is for shared experts and should be set by the MoE
# communication method that supports shared experts in parallel with routed
# experts.
before_dispatch_evt: torch.npu.Event | None = None
before_combine_evt: torch.npu.Event | None = None
# For dynamic_eplb # For dynamic_eplb
group_list_type: int | None = None group_list_type: int | None = None
expert_tokens: torch.Tensor | None = None expert_tokens: torch.Tensor | None = None
@@ -108,10 +113,6 @@ class MoECommMethod(ABC):
w2_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None, w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None, w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance # For load balance
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
need_trans: bool = False, need_trans: bool = False,
@@ -126,6 +127,7 @@ class MoECommMethod(ABC):
moe_comm_method = get_forward_context().moe_comm_method moe_comm_method = get_forward_context().moe_comm_method
assert moe_comm_method is not None, "Missing communication context" assert moe_comm_method is not None, "Missing communication context"
before_dispatch_evt = torch.npu.current_stream().record_event()
dispatch_results = self.token_dispatcher.token_dispatch( dispatch_results = self.token_dispatcher.token_dispatch(
hidden_states=hidden_states, hidden_states=hidden_states,
topk_weights=topk_weights, topk_weights=topk_weights,
@@ -134,9 +136,6 @@ class MoECommMethod(ABC):
log2phy=log2phy, log2phy=log2phy,
global_redundant_expert_num=self.moe_config. global_redundant_expert_num=self.moe_config.
global_redundant_expert_num, global_redundant_expert_num,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=mc2_mask, mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8, with_quant=use_int8_w8a8 or use_int4_w4a8,
@@ -162,12 +161,15 @@ class MoECommMethod(ABC):
need_trans=need_trans, need_trans=need_trans,
dynamic_eplb=dynamic_eplb) dynamic_eplb=dynamic_eplb)
before_combine_evt = torch.npu.current_stream().record_event()
combine_results = self.token_dispatcher.token_combine( combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output, hidden_states=mlp_output,
context_metadata=dispatch_results.context_metadata) context_metadata=dispatch_results.context_metadata)
return FusedExpertsResult( return FusedExpertsResult(
routed_out=combine_results.routed_out, routed_out=combine_results.routed_out,
before_dispatch_evt=before_dispatch_evt,
before_combine_evt=before_combine_evt,
group_list_type=dispatch_results.group_list_type, group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list) expert_tokens=dispatch_results.group_list)
@@ -284,10 +286,6 @@ class FusedMC2CommImpl(MoECommMethod):
w2_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None, w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None, w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance # For load balance
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
need_trans: bool = False, need_trans: bool = False,

View File

@@ -151,8 +151,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
""" """
self.replace_allreduce = replace_allreduce self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
padded_hidden_states_shape = hidden_states.shape
if not (self.replace_allreduce or self.enable_shared_expert_dp): if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic) pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
@@ -162,6 +162,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
(0, 0, 0, pad_size)) (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits, router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size)) (0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape
if self.tp_size > 1: if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states, split_hidden_states = torch.tensor_split(hidden_states,
@@ -174,7 +175,9 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
hidden_states = split_hidden_states[self.tp_rank] hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank] router_logits = split_router_logits[self.tp_rank]
context_metadata = {"split_hidden_states": split_hidden_states} context_metadata = {
"padded_hidden_states_shape": padded_hidden_states_shape
}
return hidden_states, router_logits, None, context_metadata return hidden_states, router_logits, None, context_metadata
@@ -190,14 +193,25 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
""" """
assert context_metadata is not None
split_hidden_states = context_metadata["split_hidden_states"]
if not (self.enable_shared_expert_dp or self.replace_allreduce): if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1: if self.tp_size > 1:
assert context_metadata is not None
# Cannot reuse `split_hidden_states` from prepare phase as it
# may share memory with original hidden_states. Since shared
# experts may use the original tensor, reusing it would cause
# in-place modification during all_gather, corrupting the data.
padded_hidden_states_shape = context_metadata[
"padded_hidden_states_shape"]
gathered_hidden_states = torch.empty(
padded_hidden_states_shape,
device=hidden_states.device,
dtype=hidden_states.dtype)
split_hidden_states = torch.tensor_split(
gathered_hidden_states, self.tp_size, dim=0)
dist.all_gather(list(split_hidden_states), hidden_states, dist.all_gather(list(split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group) self.moe_config.tp_group.device_group)
hidden_states = torch.cat(split_hidden_states, dim=0) hidden_states = gathered_hidden_states
if self.num_tokens < hidden_states.shape[0]: if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens] hidden_states = hidden_states[:self.num_tokens]
@@ -249,7 +263,6 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
""" """
self.replace_allreduce = replace_allreduce self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
forward_context = get_forward_context() forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask mc2_mask = forward_context.mc2_mask
if self.tp_size > 1: if self.tp_size > 1:
@@ -257,6 +270,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0) split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
mc2_mask = split_mc2_mask[self.tp_rank] mc2_mask = split_mc2_mask[self.tp_rank]
padded_hidden_states_shape = hidden_states.shape
if not self.replace_allreduce: if not self.replace_allreduce:
self.num_tokens, _ = hidden_states.shape self.num_tokens, _ = hidden_states.shape
target_pad_length = forward_context.padded_num_tokens target_pad_length = forward_context.padded_num_tokens
@@ -268,6 +282,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
(0, 0, 0, pad_size)) (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits, router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size)) (0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape
# Slice across TP ranks # Slice across TP ranks
if self.tp_size > 1 and not self.enable_shared_expert_dp: if self.tp_size > 1 and not self.enable_shared_expert_dp:
@@ -280,7 +295,9 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
hidden_states = split_hidden_states[self.tp_rank] hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank] router_logits = split_router_logits[self.tp_rank]
context_metadata = {"split_hidden_states": split_hidden_states} context_metadata = {
"padded_hidden_states_shape": padded_hidden_states_shape,
}
return hidden_states, router_logits, mc2_mask, context_metadata return hidden_states, router_logits, mc2_mask, context_metadata

View File

@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Optional from typing import Optional
import torch import torch
import torch_npu import torch_npu
@@ -82,9 +82,6 @@ class MoETokenDispatcher(ABC):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None, mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
with_quant: bool = False, with_quant: bool = False,
@@ -193,9 +190,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None, mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
with_quant: bool = False, with_quant: bool = False,
@@ -226,12 +220,10 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"ep_recv_counts": ep_recv_counts, "ep_recv_counts": ep_recv_counts,
"tp_recv_counts": tp_recv_counts, "tp_recv_counts": tp_recv_counts,
"assist_info_for_combine": assist_info_for_combine, "assist_info_for_combine": assist_info_for_combine,
"shared_experts": shared_experts,
"expand_scales": expand_scales "expand_scales": expand_scales
} }
group_list_type = 0 group_list_type = 0
return TokenDispatchResult(hidden_states=expand_x, return TokenDispatchResult(hidden_states=expand_x,
dynamic_scale=dynamic_scale, dynamic_scale=dynamic_scale,
group_list=expert_token_nums, group_list=expert_token_nums,
@@ -297,7 +289,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \ combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
return TokenCombineResult(routed_out=combined_output) return TokenCombineResult(routed_out=combined_output, )
class TokenDispatcherWithAllGather(MoETokenDispatcher): class TokenDispatcherWithAllGather(MoETokenDispatcher):
@@ -319,9 +311,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None, mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
with_quant: bool = False, with_quant: bool = False,
@@ -442,9 +431,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None, mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
with_quant: bool = False, with_quant: bool = False,

View File

@@ -204,9 +204,6 @@ class AscendW4A16FusedMoEMethod:
enable_force_load_balance: bool = True, enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
assert router_logits.shape[ assert router_logits.shape[
@@ -229,8 +226,7 @@ class AscendW4A16FusedMoEMethod:
topk_weights = topk_weights.to(x.dtype) topk_weights = topk_weights.to(x.dtype)
moe_comm_method = get_forward_context().moe_comm_method moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts( return moe_comm_method.fused_experts(hidden_states=x,
hidden_states=x,
w1=layer.w13_weight_packed, w1=layer.w13_weight_packed,
w2=layer.w2_weight_packed, w2=layer.w2_weight_packed,
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
@@ -242,11 +238,9 @@ class AscendW4A16FusedMoEMethod:
use_int4_w4a16=True, use_int4_w4a16=True,
expert_map=expert_map, expert_map=expert_map,
log2phy=log2phy, log2phy=log2phy,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
dynamic_eplb=self.dynamic_eplb, dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None)) mc2_mask=kwargs.get(
"mc2_mask", None))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.transpose_weight: if self.transpose_weight:

View File

@@ -341,9 +341,6 @@ class AscendW4A8DynamicFusedMoEMethod:
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
assert router_logits.shape[ assert router_logits.shape[
@@ -390,9 +387,6 @@ class AscendW4A8DynamicFusedMoEMethod:
use_int4_w4a8=True, use_int4_w4a8=True,
expert_map=expert_map, expert_map=expert_map,
log2phy=log2phy, log2phy=log2phy,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
dynamic_eplb=self.dynamic_eplb, dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None)) mc2_mask=kwargs.get("mc2_mask", None))

View File

@@ -190,9 +190,6 @@ class AscendW8A8DynamicFusedMoEMethod:
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
pertoken_scale: Optional[Any] = None, pertoken_scale: Optional[Any] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
@@ -280,9 +277,6 @@ class AscendW8A8DynamicFusedMoEMethod:
use_int8_w8a8=True, use_int8_w8a8=True,
expert_map=expert_map, expert_map=expert_map,
log2phy=log2phy, log2phy=log2phy,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
dynamic_eplb=self.dynamic_eplb, dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None)) mc2_mask=kwargs.get("mc2_mask", None))
if zero_expert_num > 0 and zero_expert_type is not None: if zero_expert_num > 0 and zero_expert_type is not None: