[Feature] Support fine-grained shared expert overlap (#5482)

Add fine-grained control over shared expert overlap to prevent resource
contention between the shared and routed experts: the routed path records
NPU events before its dispatch and combine communications, and the shared
expert stream waits on them so that the gate_up projection and activation
run during dispatch while the down projection runs during combine.
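
As a rough illustration of the overlap pattern (a minimal sketch only, assuming
an Ascend device with torch_npu installed, which mirrors the torch.cuda-style
stream/event API under torch.npu; `shared`, `route`, `routed_experts`,
`dispatch`, and `combine` are illustrative stand-ins that return plain tensors,
not the real vllm-ascend objects):

import torch
import torch_npu  # noqa: F401  # registers the torch.npu namespace


def overlapped_forward(hidden_states, shared, route, routed_experts,
                       dispatch, combine):
    main = torch.npu.current_stream()
    side = torch.npu.Stream()  # dedicated stream for the shared experts

    # Events recorded on the main stream mark points the side stream may
    # proceed past: input ready, dispatch issued, combine issued.
    before_routed = main.record_event()

    topk_ids = route(hidden_states)              # gating / top-k selection
    before_dispatch = main.record_event()
    dispatched = dispatch(hidden_states, topk_ids)   # all-to-all / MC2 dispatch
    routed_hidden = routed_experts(dispatched)
    before_combine = main.record_event()
    routed_out = combine(routed_hidden)              # all-to-all / MC2 combine

    with torch.npu.stream(side):
        side.wait_event(before_routed)
        # gate_up projection + activation overlap the dispatch communication
        side.wait_event(before_dispatch)
        act = shared.act_fn(shared.gate_up_proj(hidden_states))
        # down projection overlaps the combine communication
        side.wait_event(before_combine)
        shared_out = shared.down_proj(act)

    # The default stream must not read shared_out before the side stream is done.
    main.wait_stream(side)
    return shared_out, routed_out

The point of splitting the shared experts into two halves is that each half
lands inside a communication window (dispatch and combine respectively),
instead of the whole shared MLP competing with the routed experts for compute.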

- vLLM version: v0.13.0
- vLLM main: 5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Jade Zheng
2026-01-17 11:53:22 +08:00
committed by GitHub
parent 48e10de8c9
commit 22f253142a
9 changed files with 203 additions and 130 deletions


@@ -14,9 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Optional
from dataclasses import dataclass, field
from functools import wraps
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
@@ -47,7 +50,20 @@ from vllm_ascend.quantization.w8a8_dynamic import \
from vllm_ascend.utils import (AscendDeviceType, enable_sp,
get_ascend_device_type, maybe_trans_nz,
npu_stream_switch, shared_expert_dp_enabled,
shared_experts_calculation_stream)
shared_experts_calculation_stream, vllm_version_is)

# Routed-expert output plus the events recorded just before the dispatch and
# combine communications, which other streams can wait on to overlap work.
@dataclass
class FusedMoEResult:
routed_out: torch.Tensor
before_dispatch_evt: torch.npu.Event | None = None
before_combine_evt: torch.npu.Event | None = None


# Synchronization points the shared-expert stream waits on.
@dataclass
class FusedMoEEvents:
before_routed_experts: torch.npu.Event
before_dispatch: torch.npu.Event | None = field(default=None)
before_combine: torch.npu.Event | None = field(default=None)

class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -90,7 +106,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
**kwargs) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -137,7 +152,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
topk_ids=topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
@@ -268,13 +282,13 @@ class AscendFusedMoE(FusedMoE):
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def forward_impl( # type: ignore[override]
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
return_with_event: bool = False) -> torch.Tensor | FusedMoEResult:
assert self.quant_method is not None
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
forward_context = get_forward_context()
# Load balancing for token distribution among experts in dummy_run
@@ -359,9 +373,6 @@ class AscendFusedMoE(FusedMoE):
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
shared_experts=None,
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num,
@@ -380,7 +391,14 @@ class AscendFusedMoE(FusedMoE):
reduce_results=self.reduce_results,
context_metadata=context_metadata)
return routed_out
if return_with_event:
return FusedMoEResult(
routed_out=routed_out,
before_dispatch_evt=fused_experts_results.before_dispatch_evt,
before_combine_evt=fused_experts_results.before_combine_evt)
else:
# The vLLM FusedMoE forward_impl does not return events.
return routed_out
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
@@ -407,6 +425,74 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
self._gate = gate
# Wrap the quant_method's process_weights_after_loading to validate that
# splitting shared expert computation (gate_up projection + activation,
# then down projection) yields identical results to integrated
# computation after weight loading.
original_process_weights = self.quant_method.process_weights_after_loading
@wraps(original_process_weights)
def wrapped_process_weights(*args, **kwargs):
result = original_process_weights(*args, **kwargs)
self._validate_shared_expert_consistency()
return result
self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore
def _shared_experts_part1(self, hidden_states: torch.Tensor):
shared_gate_up, _ = self._shared_experts.gate_up_proj(
hidden_states) # type: ignore
return shared_gate_up
def _shared_experts_part2(self, hidden_states: torch.Tensor,
shared_gate_up: torch.Tensor):
shared_act = self._shared_experts.act_fn(
shared_gate_up) # type: ignore
shared_out, _ = self._shared_experts.down_proj(
shared_act) # type: ignore
# Qwen3-Next specific gating mechanism
if hasattr(self._shared_experts, "expert_gate") and \
self._shared_experts.expert_gate is not None:
if vllm_version_is('0.13.0'):
# TODO(jianzs): remove this branch after the next vLLM version is
# released
gate_out = self._shared_experts.expert_gate(hidden_states) # type: ignore
else:
gate_out, _ = self._shared_experts.expert_gate(hidden_states) # type: ignore
shared_out = F.sigmoid(gate_out) * shared_out
return shared_out
def _validate_shared_expert_consistency(self):
"""Validate that split shared expert computation matches integrated
computation."""
test_input = torch.rand(
10, self.hidden_size, device='npu', dtype=self.moe_config.in_dtype
) * 2 - 1 # Random input for testing, scaled to [-1, 1]
integrated_out = self._shared_experts(test_input)
part1_out = self._shared_experts_part1(test_input)
split_out = self._shared_experts_part2(test_input, part1_out)
if not torch.allclose(integrated_out, split_out):
diff = (integrated_out - split_out).abs()
logger.error(
"SharedFusedMoE shared experts split computation does not "
"match the integrated computation.")
logger.error(f"Max absolute difference: {diff.max().item()}")
logger.error("Integrated output - sum: %s, norm: %s",
integrated_out.sum().item(),
integrated_out.norm().item())
logger.error("Split output - sum: %s, norm: %s",
split_out.sum().item(),
split_out.norm().item())
raise ValueError(
"SharedFusedMoE shared experts split computation does not "
"match the integrated computation.")
logger.info_once(
"SharedFusedMoE shared experts split computation matches the "
"integrated computation.")
@property
def gate(self) -> Optional[torch.nn.Module]:
return self._gate if self.use_overlapped else None
@@ -434,44 +520,67 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
)
return shared_out, fused_out
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
shared_out = None
if not self.multistream_overlap_gate:
# Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
shared_experts_calculation_stream(
).wait_stream( # type: ignore
torch.npu.current_stream())
with npu_stream_switch(
shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert):
# Use a separate stream to run shared experts.
shared_out = self._shared_experts(hidden_states)
else:
def _forward_shared_experts(self, hidden_states: torch.Tensor,
fused_moe_evts: FusedMoEEvents):
def maybe_wait_event(evt: torch.npu.Event | None):
if evt is not None:
torch.npu.current_stream().wait_event(evt)
with npu_stream_switch(shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert):
# Ensure the shared experts wait for hidden_states to be ready.
torch.npu.current_stream().wait_event(
fused_moe_evts.before_routed_experts)
# Execute the gate projection and activation concurrently with the
# dispatch communication.
maybe_wait_event(fused_moe_evts.before_dispatch)
part1_out = self._shared_experts_part1(hidden_states)
# Execute the down projection concurrently with the combine
# communication.
maybe_wait_event(fused_moe_evts.before_combine)
shared_out = self._shared_experts_part2(hidden_states, part1_out)
# Make sure the default stream waits for the shared experts stream to
# finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream())
# NOTE: This is exactly the opposite of
# `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out
def forward_impl( # type: ignore[override]
self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if self.multistream_overlap_gate:
set_flash_common3_context(shared_experts=self._shared_experts)
routed_out = AscendFusedMoE.forward_impl(
before_routed_experts = torch.npu.current_stream().record_event()
fused_moe_results = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
return_with_event=True,
)
routed_out = fused_moe_results.routed_out
if not self.multistream_overlap_gate:
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream())
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
if self.multistream_overlap_gate:
fc3_context = get_flash_common3_context()
assert fc3_context is not None
shared_out = fc3_context.shared_out
else:
shared_out = self._forward_shared_experts(
hidden_states,
FusedMoEEvents(
before_routed_experts=before_routed_experts,
before_dispatch=fused_moe_results.before_dispatch_evt,
before_combine=fused_moe_results.before_combine_evt,
))
return shared_out, routed_out