Make single-batch overlap compatible with offloading (#11614)

This commit is contained in:
fzyzcjy
2025-10-18 08:45:54 +08:00
committed by GitHub
parent dcb8f090ad
commit 33e9bbec35
3 changed files with 33 additions and 23 deletions

View File

@@ -872,7 +872,7 @@ class DeepseekV2MoE(nn.Module):
if hidden_states.shape[0] > 0:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
if not SboFlags.fuse_shared_experts_inside_sbo():
if not self._fuse_shared_experts_inside_sbo:
shared_output = self._forward_shared_experts(hidden_states)
topk_weights, topk_idx, _ = self.topk(
hidden_states,
@@ -887,18 +887,27 @@ class DeepseekV2MoE(nn.Module):
hidden_states.device
)
final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
if self._fuse_shared_experts_inside_sbo:
shared_output = None
def _forward_shared_experts_and_put_results():
nonlocal shared_output
shared_output = self._forward_shared_experts(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
# SBO args
forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
experts=self.experts,
alt_stream=self.alt_stream,
**(
dict(
forward_shared_experts=_forward_shared_experts_and_put_results,
alt_stream=self.alt_stream,
)
if self._fuse_shared_experts_inside_sbo
else {}
),
)
if sbo_shared_output is not None:
shared_output = sbo_shared_output
if shared_output is not None:
x = shared_output