[Main2Main][BugFix] Add shared_experts check for AscendSharedFusedMoE (#6335)
### What this PR does / why we need it?
PR https://github.com/vllm-project/vllm/pull/32082 in vLLM makes
Qwen3-Moe models also go into `SharedFusedMoE`, while current
implementation of our `AscendSharedFusedMoE` assumes shared_experts
always exist. This PR adds checking to
`multistream_overlap_shared_expert` and `multistream_overlap_gate` in
order to only enable these features when shared experts exist.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
All ci passed
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -415,8 +415,12 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
self.use_overlapped = use_overlapped
|
self.use_overlapped = use_overlapped
|
||||||
self.shared_expert_stream = None
|
self.shared_expert_stream = None
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
|
self.multistream_overlap_shared_expert = \
|
||||||
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
|
ascend_config.multistream_overlap_shared_expert and \
|
||||||
|
self._shared_experts is not None
|
||||||
|
self.multistream_overlap_gate = \
|
||||||
|
ascend_config.multistream_overlap_gate and \
|
||||||
|
self._shared_experts is not None
|
||||||
if enable_sp():
|
if enable_sp():
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Sequence parallelism is enabled, shared experts are replicated for best performance."
|
"Sequence parallelism is enabled, shared experts are replicated for best performance."
|
||||||
@@ -424,19 +428,20 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
|
|
||||||
self._gate = gate
|
self._gate = gate
|
||||||
|
|
||||||
# Wrap the quant_method's process_weights_after_loading to validate that
|
if self.multistream_overlap_shared_expert:
|
||||||
# splitting shared expert computation (gate_up projection + activation,
|
# Wrap the quant_method's process_weights_after_loading to validate that
|
||||||
# then down projection) yields identical results to integrated
|
# splitting shared expert computation (gate_up projection + activation,
|
||||||
# computation after weight loading.
|
# then down projection) yields identical results to integrated
|
||||||
original_process_weights = self.quant_method.process_weights_after_loading
|
# computation after weight loading.
|
||||||
|
original_process_weights = self.quant_method.process_weights_after_loading
|
||||||
|
|
||||||
@wraps(original_process_weights)
|
@wraps(original_process_weights)
|
||||||
def wrapped_process_weights(*args, **kwargs):
|
def wrapped_process_weights(*args, **kwargs):
|
||||||
result = original_process_weights(*args, **kwargs)
|
result = original_process_weights(*args, **kwargs)
|
||||||
self._validate_shared_expert_consistency()
|
self._validate_shared_expert_consistency()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore
|
self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore
|
||||||
|
|
||||||
def _shared_experts_part1(self, hidden_states: torch.Tensor):
|
def _shared_experts_part1(self, hidden_states: torch.Tensor):
|
||||||
shared_gate_up, _ = self._shared_experts.gate_up_proj(
|
shared_gate_up, _ = self._shared_experts.gate_up_proj(
|
||||||
@@ -516,6 +521,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
|
|
||||||
def _forward_shared_experts(self, hidden_states: torch.Tensor,
|
def _forward_shared_experts(self, hidden_states: torch.Tensor,
|
||||||
fused_moe_evts: FusedMoEEvents):
|
fused_moe_evts: FusedMoEEvents):
|
||||||
|
if self._shared_experts is None:
|
||||||
|
return None
|
||||||
|
|
||||||
def maybe_wait_event(evt: torch.npu.Event | None):
|
def maybe_wait_event(evt: torch.npu.Event | None):
|
||||||
if evt is not None:
|
if evt is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user