diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 12d6a63e..50618ae8 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -415,8 +415,12 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): self.use_overlapped = use_overlapped self.shared_expert_stream = None ascend_config = get_ascend_config() - self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert - self.multistream_overlap_gate = ascend_config.multistream_overlap_gate + self.multistream_overlap_shared_expert = \ + ascend_config.multistream_overlap_shared_expert and \ + self._shared_experts is not None + self.multistream_overlap_gate = \ + ascend_config.multistream_overlap_gate and \ + self._shared_experts is not None if enable_sp(): logger.info_once( "Sequence parallelism is enabled, shared experts are replicated for best performance." @@ -424,19 +428,20 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): self._gate = gate - # Wrap the quant_method's process_weights_after_loading to validate that - # splitting shared expert computation (gate_up projection + activation, - # then down projection) yields identical results to integrated - # computation after weight loading. - original_process_weights = self.quant_method.process_weights_after_loading + if self.multistream_overlap_shared_expert: + # Wrap the quant_method's process_weights_after_loading to validate that + # splitting shared expert computation (gate_up projection + activation, + # then down projection) yields identical results to integrated + # computation after weight loading. + original_process_weights = self.quant_method.process_weights_after_loading - @wraps(original_process_weights) - def wrapped_process_weights(*args, **kwargs): - result = original_process_weights(*args, **kwargs) - self._validate_shared_expert_consistency() - return result + @wraps(original_process_weights) + def wrapped_process_weights(*args, **kwargs): + result = original_process_weights(*args, **kwargs) + self._validate_shared_expert_consistency() + return result - self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore + self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore def _shared_experts_part1(self, hidden_states: torch.Tensor): shared_gate_up, _ = self._shared_experts.gate_up_proj( @@ -516,6 +521,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): def _forward_shared_experts(self, hidden_states: torch.Tensor, fused_moe_evts: FusedMoEEvents): + if self._shared_experts is None: + return None def maybe_wait_event(evt: torch.npu.Event | None): if evt is not None: