Make single-batch overlap compatible with NextN (#11804)

This commit is contained in:
fzyzcjy
2025-10-19 16:10:44 +08:00
committed by GitHub
parent ea6275dfbc
commit ce399e154c
3 changed files with 9 additions and 4 deletions

View File

@@ -170,6 +170,7 @@ class DeepEPMoE(FusedMoE):
forward_batch: ForwardBatch,
forward_shared_experts=None,
alt_stream=None,
+disable_sbo=False,
):
# We have to call SBO inside MoE to be compatible with hooks used in offloading
return single_batch_overlap.execute_sbo(
@@ -181,6 +182,7 @@ class DeepEPMoE(FusedMoE):
experts=self,
forward_shared_experts=forward_shared_experts,
alt_stream=alt_stream,
+disable_sbo=disable_sbo,
)
def dispatch(

View File

@@ -902,6 +902,8 @@ class DeepseekV2MoE(nn.Module):
dict(
forward_shared_experts=_forward_shared_experts_and_put_results,
alt_stream=self.alt_stream,
+# SBO is not yet implemented for NextN
+disable_sbo=self.is_nextn,
)
if self._fuse_shared_experts_inside_sbo
else {}

View File

@@ -60,13 +60,14 @@ def execute_sbo(
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
alt_stream: Optional = None,
+disable_sbo: bool = False,
):
dispatch_output = experts.dispatch(
hidden_states, topk_idx, topk_weights, forward_batch
)
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
-_compute_overlap_args(dispatch_output, alt_stream)
+_compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
)
hidden_states = experts.moe_impl(
@@ -75,7 +76,7 @@ def execute_sbo(
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
e.record()
-if SboFlags.enable_combine_shared_two_stream_overlap():
+if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap():
# TODO reduce sm for non-deepgemm
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
meta_overlap_args["compute_num_sms"]
@@ -93,8 +94,8 @@ def execute_sbo(
return hidden_states
-def _compute_overlap_args(dispatch_output, alt_stream):
-if not (
+def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
+if disable_sbo or not (
SboFlags.enable_combine_down_gemm_two_stream_overlap()
or SboFlags.enable_combine_shared_two_stream_overlap()
):