[Main2Main][BugFix] Add shared_experts check for AscendSharedFusedMoE (#6335)

### What this PR does / why we need it?
vLLM PR https://github.com/vllm-project/vllm/pull/32082 makes
Qwen3-MoE models go through `SharedFusedMoE` as well, while the current
implementation of our `AscendSharedFusedMoE` assumes that
`shared_experts` always exists. This PR adds checks to
`multistream_overlap_shared_expert` and `multistream_overlap_gate` so
that these features are enabled only when shared experts exist.
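
For context, the failure mode without the check: a model without shared experts (e.g. Qwen3-MoE after the vLLM change) leaves `self._shared_experts` as `None`, and the overlap paths then dereference it. A minimal sketch of the guard pattern, using a simplified standalone class rather than the actual `AscendSharedFusedMoE`:

```python
# Hedged sketch of the guard added in this PR; simplified, not the real class.
class SharedFusedMoESketch:
    def __init__(self, shared_experts, ascend_config):
        # shared_experts may be None, e.g. for Qwen3-MoE after vLLM #32082.
        self._shared_experts = shared_experts
        # Enable the overlap features only when shared experts actually
        # exist; otherwise later code would dereference None.
        self.multistream_overlap_shared_expert = (
            ascend_config.multistream_overlap_shared_expert
            and self._shared_experts is not None)
        self.multistream_overlap_gate = (
            ascend_config.multistream_overlap_gate
            and self._shared_experts is not None)
```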
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
All CI passed.

- vLLM version: v0.14.1
- vLLM main: dc917cceb8

Signed-off-by: whx-sjtu <2952154980@qq.com>
Author: whx
Date: 2026-01-29 08:47:20 +08:00 (committed by GitHub)
Parent: f0ff2cc22d
Commit: 39f8af9d96


@@ -415,8 +415,12 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
         ascend_config = get_ascend_config()
-        self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
-        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
+        self.multistream_overlap_shared_expert = \
+            ascend_config.multistream_overlap_shared_expert and \
+            self._shared_experts is not None
+        self.multistream_overlap_gate = \
+            ascend_config.multistream_overlap_gate and \
+            self._shared_experts is not None
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -424,19 +428,20 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         self._gate = gate
-        # Wrap the quant_method's process_weights_after_loading to validate that
-        # splitting shared expert computation (gate_up projection + activation,
-        # then down projection) yields identical results to integrated
-        # computation after weight loading.
-        original_process_weights = self.quant_method.process_weights_after_loading
+        if self.multistream_overlap_shared_expert:
+            # Wrap the quant_method's process_weights_after_loading to validate that
+            # splitting shared expert computation (gate_up projection + activation,
+            # then down projection) yields identical results to integrated
+            # computation after weight loading.
+            original_process_weights = self.quant_method.process_weights_after_loading

-        @wraps(original_process_weights)
-        def wrapped_process_weights(*args, **kwargs):
-            result = original_process_weights(*args, **kwargs)
-            self._validate_shared_expert_consistency()
-            return result
+            @wraps(original_process_weights)
+            def wrapped_process_weights(*args, **kwargs):
+                result = original_process_weights(*args, **kwargs)
+                self._validate_shared_expert_consistency()
+                return result

-        self.quant_method.process_weights_after_loading = wrapped_process_weights  # type: ignore
+            self.quant_method.process_weights_after_loading = wrapped_process_weights  # type: ignore

     def _shared_experts_part1(self, hidden_states: torch.Tensor):
         shared_gate_up, _ = self._shared_experts.gate_up_proj(
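
The hunk above only moves the existing validation wrapper under the new flag check. For readers unfamiliar with the pattern, here is a self-contained sketch of the same `functools.wraps` monkey-patching technique (`QuantMethodSketch` and `install_consistency_check` are hypothetical names, not from the repo):

```python
from functools import wraps

class QuantMethodSketch:
    """Hypothetical stand-in for a quant method; not the real vLLM API."""

    def process_weights_after_loading(self, layer):
        print("weights processed")

def install_consistency_check(quant_method, validate):
    # Same wrapping technique as in the diff: run the original weight
    # processing, then a post-hoc consistency validation.
    original = quant_method.process_weights_after_loading

    @wraps(original)
    def wrapped(*args, **kwargs):
        result = original(*args, **kwargs)
        validate()
        return result

    quant_method.process_weights_after_loading = wrapped

qm = QuantMethodSketch()
install_consistency_check(qm, lambda: print("split/fused consistency validated"))
qm.process_weights_after_loading(layer=None)
# -> weights processed
# -> split/fused consistency validated
```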
@@ -516,6 +521,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
     def _forward_shared_experts(self, hidden_states: torch.Tensor,
                                 fused_moe_evts: FusedMoEEvents):
+        if self._shared_experts is None:
+            return None
         def maybe_wait_event(evt: torch.npu.Event | None):
             if evt is not None: