[Main2Main][BugFix] Add shared_experts check for AscendSharedFusedMoE (#6335)
### What this PR does / why we need it?
PR https://github.com/vllm-project/vllm/pull/32082 in vLLM makes
Qwen3-Moe models also go into `SharedFusedMoE`, while current
implementation of our `AscendSharedFusedMoE` assumes shared_experts
always exist. This PR adds checking to
`multistream_overlap_shared_expert` and `multistream_overlap_gate` in
order to only enable these features when shared experts exist.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
All ci passed
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -415,8 +415,12 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
self.use_overlapped = use_overlapped
|
||||
self.shared_expert_stream = None
|
||||
ascend_config = get_ascend_config()
|
||||
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
|
||||
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
|
||||
self.multistream_overlap_shared_expert = \
|
||||
ascend_config.multistream_overlap_shared_expert and \
|
||||
self._shared_experts is not None
|
||||
self.multistream_overlap_gate = \
|
||||
ascend_config.multistream_overlap_gate and \
|
||||
self._shared_experts is not None
|
||||
if enable_sp():
|
||||
logger.info_once(
|
||||
"Sequence parallelism is enabled, shared experts are replicated for best performance."
|
||||
@@ -424,19 +428,20 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
self._gate = gate
|
||||
|
||||
# Wrap the quant_method's process_weights_after_loading to validate that
|
||||
# splitting shared expert computation (gate_up projection + activation,
|
||||
# then down projection) yields identical results to integrated
|
||||
# computation after weight loading.
|
||||
original_process_weights = self.quant_method.process_weights_after_loading
|
||||
if self.multistream_overlap_shared_expert:
|
||||
# Wrap the quant_method's process_weights_after_loading to validate that
|
||||
# splitting shared expert computation (gate_up projection + activation,
|
||||
# then down projection) yields identical results to integrated
|
||||
# computation after weight loading.
|
||||
original_process_weights = self.quant_method.process_weights_after_loading
|
||||
|
||||
@wraps(original_process_weights)
|
||||
def wrapped_process_weights(*args, **kwargs):
|
||||
result = original_process_weights(*args, **kwargs)
|
||||
self._validate_shared_expert_consistency()
|
||||
return result
|
||||
@wraps(original_process_weights)
|
||||
def wrapped_process_weights(*args, **kwargs):
|
||||
result = original_process_weights(*args, **kwargs)
|
||||
self._validate_shared_expert_consistency()
|
||||
return result
|
||||
|
||||
self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore
|
||||
self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore
|
||||
|
||||
def _shared_experts_part1(self, hidden_states: torch.Tensor):
|
||||
shared_gate_up, _ = self._shared_experts.gate_up_proj(
|
||||
@@ -516,6 +521,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
def _forward_shared_experts(self, hidden_states: torch.Tensor,
|
||||
fused_moe_evts: FusedMoEEvents):
|
||||
if self._shared_experts is None:
|
||||
return None
|
||||
|
||||
def maybe_wait_event(evt: torch.npu.Event | None):
|
||||
if evt is not None:
|
||||
|
||||
Reference in New Issue
Block a user