Support single batch overlap (#10422)

Author: fzyzcjy
Date: 2025-10-02 18:04:36 +08:00
Committed by: GitHub
Parent: 0b9dfba787
Commit: 5e786cca3a

9 changed files with 268 additions and 20 deletions

@@ -28,6 +28,7 @@ from torch import nn
 from tqdm import tqdm
 from transformers import PretrainedConfig
+from sglang.srt import single_batch_overlap
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -101,6 +102,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.single_batch_overlap import SboFlags
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
     model_forward_maybe_tbo,
@@ -806,7 +808,8 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
+            if not SboFlags.fuse_shared_experts_inside_sbo():
+                shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
                 hidden_states,
                 router_logits,
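
The new condition defers the eager shared-expert forward whenever the shared experts will instead be fused into the SBO path (shown in the next hunk), so the work is not done twice. A hypothetical sketch of what such a flag helper might look like, purely for orientation; the real `SboFlags` in `sglang.srt.single_batch_overlap` is not part of this diff and the condition below is an assumption:

```python
# Hypothetical sketch only: the real SboFlags lives in
# sglang.srt.single_batch_overlap and its exact conditions are not shown here.
class SboFlagsSketch:
    # Assumed to mirror a server argument that turns single batch overlap on.
    enable_single_batch_overlap: bool = False

    @classmethod
    def fuse_shared_experts_inside_sbo(cls) -> bool:
        # When SBO is enabled, the shared experts are computed inside
        # execute_sbo (overlapped with the routed experts), so the caller
        # skips its own eager shared-expert forward.
        return cls.enable_single_batch_overlap
```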
@@ -820,12 +823,18 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states.device
             )
-        final_hidden_states = self.experts(
+        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            # SBO args
+            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
+            experts=self.experts,
+            alt_stream=self.alt_stream,
         )
+        if sbo_shared_output is not None:
+            shared_output = sbo_shared_output
         if shared_output is not None:
             x = shared_output
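
The direct `self.experts(...)` call is replaced by `single_batch_overlap.execute_sbo(...)`, which returns the routed-expert output plus an optional shared-expert output. A minimal sketch of what such a helper could do, assuming the overlap is achieved by running the shared experts on `alt_stream` while the routed experts run on the current stream; the real `execute_sbo` (which also coordinates with the DeepEP dispatch/combine pipeline) is not shown in this diff, so treat the body as illustrative:

```python
from typing import Callable, Optional, Tuple

import torch


def execute_sbo_sketch(
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    forward_batch,
    forward_shared_experts: Callable[[], Optional[torch.Tensor]],
    experts: Callable[..., torch.Tensor],
    alt_stream: Optional["torch.cuda.Stream"] = None,
    fuse_shared_experts: bool = True,  # assumed stand-in for SboFlags.fuse_shared_experts_inside_sbo()
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Sketch: run routed and shared experts concurrently for one batch.

    Returns (final_hidden_states, shared_output). shared_output is None when
    the shared experts are not fused into SBO, which matches the caller's
    `if sbo_shared_output is not None` check: in that case the eagerly
    computed shared_output is kept.
    """
    shared_output = None
    if fuse_shared_experts and hidden_states.shape[0] > 0:
        if alt_stream is not None:
            # Launch the shared-expert GEMMs on the alternate stream so they
            # overlap with the routed-expert dispatch/compute below.
            alt_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(alt_stream):
                shared_output = forward_shared_experts()
        else:
            shared_output = forward_shared_experts()

    # Routed experts run on the current stream.
    final_hidden_states = experts(
        hidden_states=hidden_states,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        forward_batch=forward_batch,
    )

    if fuse_shared_experts and alt_stream is not None:
        # Rejoin the streams before the caller combines shared and routed outputs.
        torch.cuda.current_stream().wait_stream(alt_stream)

    return final_hidden_states, shared_output
```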
@@ -843,7 +852,7 @@ class DeepseekV2MoE(nn.Module):
     def _forward_shared_experts(
         self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
     ):
-        if self.num_fused_shared_experts == 0:
+        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
             return self.shared_experts(
                 hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
             )
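
The extra `hidden_states.shape[0] > 0` check matters under SBO: `_forward_shared_experts` is now also reached through the lambda handed to `execute_sbo`, whereas previously the only call sat inside the `if hidden_states.shape[0] > 0:` branch, so the guard preserves the old skip-on-empty-batch behaviour by returning `None` when this rank holds no tokens. A small self-contained illustration of the guarded behaviour (the toy module below is hypothetical, not sglang code):

```python
import torch


class ToyMoE:
    """Toy stand-in for DeepseekV2MoE, only to show the new empty-batch guard."""

    num_fused_shared_experts = 0

    def shared_experts(self, hidden_states):
        return hidden_states * 2.0  # placeholder for the real shared-expert MLP

    def _forward_shared_experts(self, hidden_states):
        # Mirrors the patched condition: skip entirely on an empty local batch.
        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
            return self.shared_experts(hidden_states)
        return None


moe = ToyMoE()
print(moe._forward_shared_experts(torch.empty(0, 8)))        # None: no tokens, no GEMM
print(moe._forward_shared_experts(torch.randn(4, 8)).shape)  # torch.Size([4, 8])
```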