Make single-batch overlap compatible with offloading (#11614)
This commit is contained in:
@@ -872,7 +872,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
if hidden_states.shape[0] > 0:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
if not SboFlags.fuse_shared_experts_inside_sbo():
|
||||
if not self._fuse_shared_experts_inside_sbo:
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
topk_weights, topk_idx, _ = self.topk(
|
||||
hidden_states,
|
||||
@@ -887,18 +887,27 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states.device
|
||||
)
|
||||
|
||||
final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
|
||||
if self._fuse_shared_experts_inside_sbo:
|
||||
shared_output = None
|
||||
|
||||
def _forward_shared_experts_and_put_results():
|
||||
nonlocal shared_output
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
# SBO args
|
||||
forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
|
||||
experts=self.experts,
|
||||
alt_stream=self.alt_stream,
|
||||
**(
|
||||
dict(
|
||||
forward_shared_experts=_forward_shared_experts_and_put_results,
|
||||
alt_stream=self.alt_stream,
|
||||
)
|
||||
if self._fuse_shared_experts_inside_sbo
|
||||
else {}
|
||||
),
|
||||
)
|
||||
if sbo_shared_output is not None:
|
||||
shared_output = sbo_shared_output
|
||||
|
||||
if shared_output is not None:
|
||||
x = shared_output
|
||||
|
||||
Reference in New Issue
Block a user