Make single-batch overlap compatible with NextN (#11804)
This commit is contained in:
@@ -170,6 +170,7 @@ class DeepEPMoE(FusedMoE):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
forward_shared_experts=None,
|
forward_shared_experts=None,
|
||||||
alt_stream=None,
|
alt_stream=None,
|
||||||
|
disable_sbo=False,
|
||||||
):
|
):
|
||||||
# We have to call SBO inside MoE to be compatible with hooks used in offloading
|
# We have to call SBO inside MoE to be compatible with hooks used in offloading
|
||||||
return single_batch_overlap.execute_sbo(
|
return single_batch_overlap.execute_sbo(
|
||||||
@@ -181,6 +182,7 @@ class DeepEPMoE(FusedMoE):
|
|||||||
experts=self,
|
experts=self,
|
||||||
forward_shared_experts=forward_shared_experts,
|
forward_shared_experts=forward_shared_experts,
|
||||||
alt_stream=alt_stream,
|
alt_stream=alt_stream,
|
||||||
|
disable_sbo=disable_sbo,
|
||||||
)
|
)
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
|
|||||||
@@ -902,6 +902,8 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
dict(
|
dict(
|
||||||
forward_shared_experts=_forward_shared_experts_and_put_results,
|
forward_shared_experts=_forward_shared_experts_and_put_results,
|
||||||
alt_stream=self.alt_stream,
|
alt_stream=self.alt_stream,
|
||||||
|
# SBO is not yet implemented for NextN
|
||||||
|
disable_sbo=self.is_nextn,
|
||||||
)
|
)
|
||||||
if self._fuse_shared_experts_inside_sbo
|
if self._fuse_shared_experts_inside_sbo
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -60,13 +60,14 @@ def execute_sbo(
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
alt_stream: Optional = None,
|
alt_stream: Optional = None,
|
||||||
|
disable_sbo: bool = False,
|
||||||
):
|
):
|
||||||
dispatch_output = experts.dispatch(
|
dispatch_output = experts.dispatch(
|
||||||
hidden_states, topk_idx, topk_weights, forward_batch
|
hidden_states, topk_idx, topk_weights, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
|
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
|
||||||
_compute_overlap_args(dispatch_output, alt_stream)
|
_compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = experts.moe_impl(
|
hidden_states = experts.moe_impl(
|
||||||
@@ -75,7 +76,7 @@ def execute_sbo(
|
|||||||
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
|
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
|
||||||
e.record()
|
e.record()
|
||||||
|
|
||||||
if SboFlags.enable_combine_shared_two_stream_overlap():
|
if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap():
|
||||||
# TODO reduce sm for non-deepgemm
|
# TODO reduce sm for non-deepgemm
|
||||||
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
||||||
meta_overlap_args["compute_num_sms"]
|
meta_overlap_args["compute_num_sms"]
|
||||||
@@ -93,8 +94,8 @@ def execute_sbo(
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def _compute_overlap_args(dispatch_output, alt_stream):
|
def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
|
||||||
if not (
|
if disable_sbo or not (
|
||||||
SboFlags.enable_combine_down_gemm_two_stream_overlap()
|
SboFlags.enable_combine_down_gemm_two_stream_overlap()
|
||||||
or SboFlags.enable_combine_shared_two_stream_overlap()
|
or SboFlags.enable_combine_shared_two_stream_overlap()
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user