Make single-batch overlap compatible with offloading (#11614)
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt import single_batch_overlap
|
||||||
from sglang.srt.layers.moe import (
|
from sglang.srt.layers.moe import (
|
||||||
get_deepep_mode,
|
get_deepep_mode,
|
||||||
get_moe_a2a_backend,
|
get_moe_a2a_backend,
|
||||||
@@ -167,18 +168,20 @@ class DeepEPMoE(FusedMoE):
|
|||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
|
forward_shared_experts=None,
|
||||||
|
alt_stream=None,
|
||||||
):
|
):
|
||||||
dispatch_output = self.dispatch(
|
# We have to call SBO inside MoE to be compatible with hooks used in offloading
|
||||||
hidden_states, topk_idx, topk_weights, forward_batch
|
return single_batch_overlap.execute_sbo(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
topk_idx=topk_idx,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
# SBO args
|
||||||
|
experts=self,
|
||||||
|
forward_shared_experts=forward_shared_experts,
|
||||||
|
alt_stream=alt_stream,
|
||||||
)
|
)
|
||||||
hidden_states = self.moe_impl(dispatch_output)
|
|
||||||
hidden_states = self.combine(
|
|
||||||
hidden_states,
|
|
||||||
dispatch_output.topk_idx,
|
|
||||||
dispatch_output.topk_weights,
|
|
||||||
forward_batch,
|
|
||||||
)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -872,7 +872,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
if not SboFlags.fuse_shared_experts_inside_sbo():
|
if not self._fuse_shared_experts_inside_sbo:
|
||||||
shared_output = self._forward_shared_experts(hidden_states)
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
topk_weights, topk_idx, _ = self.topk(
|
topk_weights, topk_idx, _ = self.topk(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -887,18 +887,27 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
hidden_states.device
|
hidden_states.device
|
||||||
)
|
)
|
||||||
|
|
||||||
final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
|
if self._fuse_shared_experts_inside_sbo:
|
||||||
|
shared_output = None
|
||||||
|
|
||||||
|
def _forward_shared_experts_and_put_results():
|
||||||
|
nonlocal shared_output
|
||||||
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
|
|
||||||
|
final_hidden_states = self.experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
# SBO args
|
**(
|
||||||
forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
|
dict(
|
||||||
experts=self.experts,
|
forward_shared_experts=_forward_shared_experts_and_put_results,
|
||||||
alt_stream=self.alt_stream,
|
alt_stream=self.alt_stream,
|
||||||
|
)
|
||||||
|
if self._fuse_shared_experts_inside_sbo
|
||||||
|
else {}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if sbo_shared_output is not None:
|
|
||||||
shared_output = sbo_shared_output
|
|
||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
x = shared_output
|
x = shared_output
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class CombineOverlapArgs:
|
|||||||
wait_event: torch.cuda.Event
|
wait_event: torch.cuda.Event
|
||||||
num_sms: int
|
num_sms: int
|
||||||
signal: Optional[torch.Tensor] = None
|
signal: Optional[torch.Tensor] = None
|
||||||
threshold: int = -1
|
threshold: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -61,8 +61,6 @@ def execute_sbo(
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
alt_stream: Optional = None,
|
alt_stream: Optional = None,
|
||||||
):
|
):
|
||||||
shared_output = None
|
|
||||||
|
|
||||||
dispatch_output = experts.dispatch(
|
dispatch_output = experts.dispatch(
|
||||||
hidden_states, topk_idx, topk_weights, forward_batch
|
hidden_states, topk_idx, topk_weights, forward_batch
|
||||||
)
|
)
|
||||||
@@ -82,7 +80,7 @@ def execute_sbo(
|
|||||||
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
||||||
meta_overlap_args["compute_num_sms"]
|
meta_overlap_args["compute_num_sms"]
|
||||||
):
|
):
|
||||||
shared_output = forward_shared_experts()
|
forward_shared_experts()
|
||||||
|
|
||||||
hidden_states = experts.combine(
|
hidden_states = experts.combine(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -92,7 +90,7 @@ def execute_sbo(
|
|||||||
overlap_args=combine_overlap_args,
|
overlap_args=combine_overlap_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states, shared_output
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def _compute_overlap_args(dispatch_output, alt_stream):
|
def _compute_overlap_args(dispatch_output, alt_stream):
|
||||||
|
|||||||
Reference in New Issue
Block a user