Make single-batch overlap compatible with NextN (#11804)
This commit is contained in:
@@ -170,6 +170,7 @@ class DeepEPMoE(FusedMoE):
|
||||
forward_batch: ForwardBatch,
|
||||
forward_shared_experts=None,
|
||||
alt_stream=None,
|
||||
disable_sbo=False,
|
||||
):
|
||||
# We have to call SBO inside MoE to be compatible with hooks used in offloading
|
||||
return single_batch_overlap.execute_sbo(
|
||||
@@ -181,6 +182,7 @@ class DeepEPMoE(FusedMoE):
|
||||
experts=self,
|
||||
forward_shared_experts=forward_shared_experts,
|
||||
alt_stream=alt_stream,
|
||||
disable_sbo=disable_sbo,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
|
||||
@@ -902,6 +902,8 @@ class DeepseekV2MoE(nn.Module):
|
||||
dict(
|
||||
forward_shared_experts=_forward_shared_experts_and_put_results,
|
||||
alt_stream=self.alt_stream,
|
||||
# SBO is not yet implemented for NextN
|
||||
disable_sbo=self.is_nextn,
|
||||
)
|
||||
if self._fuse_shared_experts_inside_sbo
|
||||
else {}
|
||||
|
||||
@@ -60,13 +60,14 @@ def execute_sbo(
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
alt_stream: Optional = None,
|
||||
disable_sbo: bool = False,
|
||||
):
|
||||
dispatch_output = experts.dispatch(
|
||||
hidden_states, topk_idx, topk_weights, forward_batch
|
||||
)
|
||||
|
||||
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
|
||||
_compute_overlap_args(dispatch_output, alt_stream)
|
||||
_compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
|
||||
)
|
||||
|
||||
hidden_states = experts.moe_impl(
|
||||
@@ -75,7 +76,7 @@ def execute_sbo(
|
||||
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
|
||||
e.record()
|
||||
|
||||
if SboFlags.enable_combine_shared_two_stream_overlap():
|
||||
if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap():
|
||||
# TODO reduce sm for non-deepgemm
|
||||
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
||||
meta_overlap_args["compute_num_sms"]
|
||||
@@ -93,8 +94,8 @@ def execute_sbo(
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _compute_overlap_args(dispatch_output, alt_stream):
|
||||
if not (
|
||||
def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
|
||||
if disable_sbo or not (
|
||||
SboFlags.enable_combine_down_gemm_two_stream_overlap()
|
||||
or SboFlags.enable_combine_shared_two_stream_overlap()
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user