[Feature] Support fine-grained shared expert overlap (#5482)
Fine-grained control over how shared expert computation overlaps with the routed experts' dispatch and combine communication, to prevent resource contention between the two.
- vLLM version: v0.13.0
- vLLM main: 5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
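In short: the main stream records an event right before the routed experts' dispatch collective and another right before the combine collective, and the shared-expert work on a separate stream waits on exactly those events, so its gate_up/activation overlaps the dispatch and its down projection overlaps the combine. A minimal sketch of that pattern, using CUDA streams and events as a stand-in for the torch.npu equivalents used in this repo; `routed_dispatch`, `routed_mlp`, `routed_combine`, `shared_part1` and `shared_part2` are illustrative placeholders, not APIs from this change:

```python
import torch

def overlapped_moe_forward(hidden_states, routed_dispatch, routed_mlp,
                           routed_combine, shared_part1, shared_part2):
    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()

    # Events recorded on the main stream mark the points the side stream keys off.
    before_routed = main.record_event()
    before_dispatch = main.record_event()
    dispatched = routed_dispatch(hidden_states)      # token dispatch (communication)
    mlp_out = routed_mlp(dispatched)                 # routed expert MLP (computation)
    before_combine = main.record_event()
    routed_out = routed_combine(mlp_out)             # token combine (communication)

    with torch.cuda.stream(side):
        side.wait_event(before_routed)               # hidden_states are ready
        side.wait_event(before_dispatch)             # overlap gate_up/act with dispatch
        gate_up = shared_part1(hidden_states)
        side.wait_event(before_combine)              # overlap down_proj with combine
        shared_out = shared_part2(hidden_states, gate_up)

    main.wait_stream(side)                           # rejoin before returning
    return shared_out, routed_out
```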
@@ -163,6 +163,7 @@ class TestMoECommMethod(TestBase):
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
)
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
@patch("torch.npu.current_stream", MagicMock())
def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context):

@@ -116,26 +116,6 @@ class TestTokenDispatcherWithMC2(TestBase):
mock_dispatch.assert_called_once()
self.assertEqual(output.group_list_type, 0) # group_list_type == 0

def test_token_dispatch_with_shared_experts_and_quant(self):
self.shared_experts = MagicMock()
self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
torch.tensor(1.0))
self.shared_experts.act_fn.return_value = torch.randn(10, 128)
self.dispatcher.with_quant = False
self.dispatcher.shared_act = torch.randn(10, 128)
self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
self.hidden_states = torch.randn(10, 128)
self.topk_weights = torch.randn(10, 1)

with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5 + (None, None)):
self.dispatcher.token_dispatch(self.hidden_states,
self.topk_weights,
torch.randint(0, 8, (10, 1)),
torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7]),
shared_experts=self.shared_experts)

def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
hidden_states = torch.randn(10, 128)

@@ -14,9 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Optional
from dataclasses import dataclass, field
from functools import wraps
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
@@ -47,7 +50,20 @@ from vllm_ascend.quantization.w8a8_dynamic import \
from vllm_ascend.utils import (AscendDeviceType, enable_sp,
get_ascend_device_type, maybe_trans_nz,
npu_stream_switch, shared_expert_dp_enabled,
shared_experts_calculation_stream)
shared_experts_calculation_stream, vllm_version_is)

@dataclass
class FusedMoEResult:
routed_out: torch.Tensor
before_dispatch_evt: torch.npu.Event | None = None
before_combine_evt: torch.npu.Event | None = None


@dataclass
class FusedMoEEvents:
before_routed_experts: torch.npu.Event
before_dispatch: torch.npu.Event | None = field(default=None)
before_combine: torch.npu.Event | None = field(default=None)


class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -90,7 +106,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
**kwargs) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -137,7 +152,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
topk_ids=topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
@@ -268,13 +282,13 @@ class AscendFusedMoE(FusedMoE):
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)

def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def forward_impl( # type: ignore[override]
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
return_with_event: bool = False) -> torch.Tensor | FusedMoEResult:
assert self.quant_method is not None

# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None

forward_context = get_forward_context()

# Load balancing for token distribution among experts in dummy_run
@@ -359,9 +373,6 @@ class AscendFusedMoE(FusedMoE):
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
shared_experts=None,
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num,
@@ -380,7 +391,14 @@ class AscendFusedMoE(FusedMoE):
reduce_results=self.reduce_results,
context_metadata=context_metadata)

return routed_out
if return_with_event:
return FusedMoEResult(
routed_out=routed_out,
before_dispatch_evt=fused_experts_results.before_dispatch_evt,
before_combine_evt=fused_experts_results.before_combine_evt)
else:
# The vLLM FusedMoE forward_impl does not return events.
return routed_out

class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
@@ -407,6 +425,74 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):

self._gate = gate

# Wrap the quant_method's process_weights_after_loading to validate that
# splitting shared expert computation (gate_up projection + activation,
# then down projection) yields identical results to integrated
# computation after weight loading.
original_process_weights = self.quant_method.process_weights_after_loading

@wraps(original_process_weights)
def wrapped_process_weights(*args, **kwargs):
result = original_process_weights(*args, **kwargs)
self._validate_shared_expert_consistency()
return result

self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore

def _shared_experts_part1(self, hidden_states: torch.Tensor):
shared_gate_up, _ = self._shared_experts.gate_up_proj(
hidden_states) # type: ignore
return shared_gate_up

def _shared_experts_part2(self, hidden_states: torch.Tensor,
shared_gate_up: torch.Tensor):
shared_act = self._shared_experts.act_fn(
shared_gate_up) # type: ignore
shared_out, _ = self._shared_experts.down_proj(
shared_act) # type: ignore

# Qwen3-Next specific gating mechanism
if hasattr(self._shared_experts, "expert_gate") and \
self._shared_experts.expert_gate is not None:
if vllm_version_is('0.13.0'):
# TODO(jianzs): remove this branch after vLLM new version is
# released
gate_out = self._shared_experts.expert_gate(hidden_states) # type: ignore
else:
gate_out, _ = self._shared_experts.expert_gate(hidden_states) # type: ignore
shared_out = F.sigmoid(gate_out) * shared_out
return shared_out

def _validate_shared_expert_consistency(self):
"""Validate that split shared expert computation matches integrated
computation."""
test_input = torch.rand(
10, self.hidden_size, device='npu', dtype=self.moe_config.in_dtype
) * 2 - 1 # Random input for testing, scoped to [-1, 1]

integrated_out = self._shared_experts(test_input)
part1_out = self._shared_experts_part1(test_input)
split_out = self._shared_experts_part2(test_input, part1_out)

if not torch.allclose(integrated_out, split_out):
diff = (integrated_out - split_out).abs()
logger.error(
"SharedFusedMoE shared experts split computation does not "
"match the integrated computation.")
logger.error(f"Max absolute difference: {diff.max().item()}")
logger.error("Integrated output - sum: %s, norm: %s",
integrated_out.sum().item(),
integrated_out.norm().item())
logger.error("Split output - sum: %s, norm: %s",
split_out.sum().item(),
split_out.norm().item())
raise ValueError(
"SharedFusedMoE shared experts split computation does not "
"match the integrated computation.")
logger.info_once(
"SharedFusedMoE shared experts split computation matches the "
"integrated computation.")

@property
def gate(self) -> Optional[torch.nn.Module]:
return self._gate if self.use_overlapped else None
@@ -434,44 +520,67 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
)
return shared_out, fused_out

def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
shared_out = None
if not self.multistream_overlap_gate:
# Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
shared_experts_calculation_stream(
).wait_stream( # type: ignore
torch.npu.current_stream())
with npu_stream_switch(
shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert):
# Use a separate stream to run shared experts.
shared_out = self._shared_experts(hidden_states)
else:
def _forward_shared_experts(self, hidden_states: torch.Tensor,
fused_moe_evts: FusedMoEEvents):

def maybe_wait_event(evt: torch.npu.Event | None):
if evt is not None:
torch.npu.current_stream().wait_event(evt)

with npu_stream_switch(shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert):
# Ensure the shared experts wait for hidden_states to be ready.
torch.npu.current_stream().wait_event(
fused_moe_evts.before_routed_experts)
# Execute the gate projection and activation concurrently with the
# dispatch communication.
maybe_wait_event(fused_moe_evts.before_dispatch)
part1_out = self._shared_experts_part1(hidden_states)
# Execute the down projection concurrently with the combine
# communication.
maybe_wait_event(fused_moe_evts.before_combine)
shared_out = self._shared_experts_part2(hidden_states, part1_out)

# Make sure the default stream waits for the shared experts stream to
# finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream())

# NOTE: This is exactly the opposite of
# `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out

def forward_impl( # type: ignore[override]
self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if self.multistream_overlap_gate:
set_flash_common3_context(shared_experts=self._shared_experts)

routed_out = AscendFusedMoE.forward_impl(
before_routed_experts = torch.npu.current_stream().record_event()
fused_moe_results = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
return_with_event=True,
)
routed_out = fused_moe_results.routed_out

if not self.multistream_overlap_gate:
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream())

# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
if self.multistream_overlap_gate:
fc3_context = get_flash_common3_context()
assert fc3_context is not None
shared_out = fc3_context.shared_out
else:
shared_out = self._forward_shared_experts(
hidden_states,
FusedMoEEvents(
before_routed_experts=before_routed_experts,
before_dispatch=fused_moe_results.before_dispatch_evt,
before_combine=fused_moe_results.before_combine_evt,
))

return shared_out, routed_out

@@ -17,7 +17,7 @@ from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Dict, Optional

import torch
from vllm.forward_context import get_forward_context
@@ -51,6 +51,11 @@ def setup_moe_comm_method(moe_config):
@dataclass
class FusedExpertsResult:
routed_out: torch.Tensor
# This field is for shared experts and should be set by the MoE
# communication method that supports shared experts in parallel with routed
# experts.
before_dispatch_evt: torch.npu.Event | None = None
before_combine_evt: torch.npu.Event | None = None
# For dynamic_eplb
group_list_type: int | None = None
expert_tokens: torch.Tensor | None = None
@@ -108,10 +113,6 @@ class MoECommMethod(ABC):
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
@@ -126,6 +127,7 @@ class MoECommMethod(ABC):
moe_comm_method = get_forward_context().moe_comm_method
assert moe_comm_method is not None, "Missing communication context"

before_dispatch_evt = torch.npu.current_stream().record_event()
dispatch_results = self.token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
@@ -134,9 +136,6 @@ class MoECommMethod(ABC):
log2phy=log2phy,
global_redundant_expert_num=self.moe_config.
global_redundant_expert_num,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8,
@@ -162,12 +161,15 @@ class MoECommMethod(ABC):
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)

before_combine_evt = torch.npu.current_stream().record_event()
combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output,
context_metadata=dispatch_results.context_metadata)

return FusedExpertsResult(
routed_out=combine_results.routed_out,
before_dispatch_evt=before_dispatch_evt,
before_combine_evt=before_combine_evt,
group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list)
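For readability, a condensed sketch of the control flow the hunks above add to `fused_experts`: record an event just before each collective and return both events alongside the routed output so a shared-expert stream can line up against them. CUDA events stand in for the torch.npu ones, and `dispatch`, `apply_mlp`, and `combine` are illustrative callables, not this repo's API:

```python
from dataclasses import dataclass
from typing import Callable, Optional
import torch

@dataclass
class ExpertsResult:  # mirrors FusedExpertsResult above, trimmed for the sketch
    routed_out: torch.Tensor
    before_dispatch_evt: Optional[torch.cuda.Event] = None
    before_combine_evt: Optional[torch.cuda.Event] = None

def fused_experts_sketch(hidden_states: torch.Tensor, dispatch: Callable,
                         apply_mlp: Callable, combine: Callable) -> ExpertsResult:
    # Mark the point just before each collective so a shared-expert stream can
    # align its sub-steps with the dispatch and combine phases.
    before_dispatch_evt = torch.cuda.current_stream().record_event()
    dispatched = dispatch(hidden_states)        # token dispatch (communication)
    mlp_out = apply_mlp(dispatched)             # routed expert MLP (computation)
    before_combine_evt = torch.cuda.current_stream().record_event()
    routed_out = combine(mlp_out)               # token combine (communication)
    return ExpertsResult(routed_out, before_dispatch_evt, before_combine_evt)
```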
@@ -284,10 +286,6 @@ class FusedMC2CommImpl(MoECommMethod):
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,

@@ -151,8 +151,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None

padded_hidden_states_shape = hidden_states.shape
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
@@ -162,6 +162,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape

if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
@@ -174,7 +175,9 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]

context_metadata = {"split_hidden_states": split_hidden_states}
context_metadata = {
"padded_hidden_states_shape": padded_hidden_states_shape
}

return hidden_states, router_logits, None, context_metadata

@@ -190,14 +193,25 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):

Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
assert context_metadata is not None

split_hidden_states = context_metadata["split_hidden_states"]
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
assert context_metadata is not None
# Cannot reuse `split_hidden_states` from prepare phase as it
# may share memory with original hidden_states. Since shared
# experts may use the original tensor, reusing it would cause
# in-place modification during all_gather, corrupting the data.
padded_hidden_states_shape = context_metadata[
"padded_hidden_states_shape"]
gathered_hidden_states = torch.empty(
padded_hidden_states_shape,
device=hidden_states.device,
dtype=hidden_states.dtype)
split_hidden_states = torch.tensor_split(
gathered_hidden_states, self.tp_size, dim=0)
dist.all_gather(list(split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(split_hidden_states, dim=0)
hidden_states = gathered_hidden_states

if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
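The comment above flags the underlying hazard: torch.tensor_split returns views, so gathering into the chunks created during prepare would overwrite the original hidden_states that the shared experts may still be reading. A small single-process sketch of the hazard and of the fix (gather into a freshly allocated buffer); plain copies stand in for the dist.all_gather call:

```python
import torch

hidden_states = torch.arange(8.0)              # stands in for the original activations
chunks = torch.tensor_split(hidden_states, 4)  # views into hidden_states, not copies
chunks[0].fill_(-1.0)                          # an in-place write, as all_gather would do
print(hidden_states[:2])                       # tensor([-1., -1.]) -- original corrupted

# The fix mirrored above: gather into a fresh buffer instead of the prepare-phase views.
hidden_states = torch.arange(8.0)
local_shard = hidden_states[:2].clone()        # this rank's shard
gathered = torch.empty_like(hidden_states)
out_chunks = torch.tensor_split(gathered, 4)
for dst in out_chunks:                         # dist.all_gather(list(out_chunks), local_shard, group) in the real code
    dst.copy_(local_shard)
print(hidden_states[:2])                       # unchanged: tensor([0., 1.])
```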
@@ -249,7 +263,6 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
if self.tp_size > 1:
@@ -257,6 +270,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
mc2_mask = split_mc2_mask[self.tp_rank]

padded_hidden_states_shape = hidden_states.shape
if not self.replace_allreduce:
self.num_tokens, _ = hidden_states.shape
target_pad_length = forward_context.padded_num_tokens
@@ -268,6 +282,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape

# Slice across TP ranks
if self.tp_size > 1 and not self.enable_shared_expert_dp:
@@ -280,7 +295,9 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]

context_metadata = {"split_hidden_states": split_hidden_states}
context_metadata = {
"padded_hidden_states_shape": padded_hidden_states_shape,
}

return hidden_states, router_logits, mc2_mask, context_metadata

@@ -22,7 +22,7 @@
# limitations under the License.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Optional

import torch
import torch_npu
@@ -82,9 +82,6 @@ class MoETokenDispatcher(ABC):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
@@ -193,9 +190,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
@@ -226,12 +220,10 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"ep_recv_counts": ep_recv_counts,
"tp_recv_counts": tp_recv_counts,
"assist_info_for_combine": assist_info_for_combine,
"shared_experts": shared_experts,
"expand_scales": expand_scales
}

group_list_type = 0

return TokenDispatchResult(hidden_states=expand_x,
dynamic_scale=dynamic_scale,
group_list=expert_token_nums,
@@ -297,7 +289,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

return TokenCombineResult(routed_out=combined_output)
return TokenCombineResult(routed_out=combined_output, )


class TokenDispatcherWithAllGather(MoETokenDispatcher):
@@ -319,9 +311,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
@@ -442,9 +431,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,

@@ -204,9 +204,6 @@ class AscendW4A16FusedMoEMethod:
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
@@ -229,24 +226,21 @@ class AscendW4A16FusedMoEMethod:
topk_weights = topk_weights.to(x.dtype)

moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight_packed,
w2=layer.w2_weight_packed,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_offset=layer.w13_weight_offset,
w2_offset=layer.w2_weight_offset,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int4_w4a16=True,
expert_map=expert_map,
log2phy=log2phy,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
return moe_comm_method.fused_experts(hidden_states=x,
w1=layer.w13_weight_packed,
w2=layer.w2_weight_packed,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_offset=layer.w13_weight_offset,
w2_offset=layer.w2_weight_offset,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int4_w4a16=True,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get(
"mc2_mask", None))

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.transpose_weight:

@@ -341,9 +341,6 @@ class AscendW4A8DynamicFusedMoEMethod:
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
@@ -390,9 +387,6 @@ class AscendW4A8DynamicFusedMoEMethod:
use_int4_w4a8=True,
expert_map=expert_map,
log2phy=log2phy,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))


@@ -190,9 +190,6 @@ class AscendW8A8DynamicFusedMoEMethod:
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
pertoken_scale: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
@@ -280,9 +277,6 @@ class AscendW8A8DynamicFusedMoEMethod:
use_int8_w8a8=True,
expert_map=expert_map,
log2phy=log2phy,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
if zero_expert_num > 0 and zero_expert_type is not None: