Support single batch overlap (#10422)
@@ -28,6 +28,7 @@ from torch import nn
 from tqdm import tqdm
 from transformers import PretrainedConfig
 
+from sglang.srt import single_batch_overlap
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -101,6 +102,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.single_batch_overlap import SboFlags
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
     model_forward_maybe_tbo,
@@ -806,7 +808,8 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
+            if not SboFlags.fuse_shared_experts_inside_sbo():
+                shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
                 hidden_states,
                 router_logits,
@@ -820,12 +823,18 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states.device
             )
 
-        final_hidden_states = self.experts(
+        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            # SBO args
+            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
+            experts=self.experts,
+            alt_stream=self.alt_stream,
         )
+        if sbo_shared_output is not None:
+            shared_output = sbo_shared_output
 
         if shared_output is not None:
             x = shared_output
@@ -843,7 +852,7 @@ class DeepseekV2MoE(nn.Module):
     def _forward_shared_experts(
         self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
     ):
-        if self.num_fused_shared_experts == 0:
+        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
            return self.shared_experts(
                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
            )
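For orientation: the diff replaces the direct self.experts(...) call with single_batch_overlap.execute_sbo(...), which additionally receives a closure for the shared experts and an alternate CUDA stream. The sketch below is an assumption about the general shape of such a helper, not SGLang's actual execute_sbo: it queues the shared experts on the side stream so their kernels can overlap with the routed-expert computation on the current stream, then re-synchronizes and returns both outputs. Everything not visible in the diff (the fuse_shared_experts parameter, the stream handshaking, the CPU fallback) is illustrative.

# Hypothetical sketch of an execute_sbo-style helper (an assumption,
# not SGLang's implementation): overlap the shared experts with the
# routed experts by running them on an alternate CUDA stream.
from typing import Callable, Optional, Tuple

import torch


def execute_sbo_sketch(
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    experts: Callable[..., torch.Tensor],
    forward_shared_experts: Callable[[], Optional[torch.Tensor]],
    alt_stream: Optional[torch.cuda.Stream] = None,
    fuse_shared_experts: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    sbo_shared_output = None
    overlap = (
        fuse_shared_experts and alt_stream is not None and hidden_states.is_cuda
    )
    if overlap:
        # Order the side stream after work already queued on the current
        # stream, so hidden_states is fully produced before it is read.
        alt_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(alt_stream):
            sbo_shared_output = forward_shared_experts()
    elif fuse_shared_experts:
        # No usable side stream: fall back to sequential execution.
        sbo_shared_output = forward_shared_experts()

    # Routed experts run on the current stream, concurrently with the
    # shared-expert kernels queued on alt_stream.
    final_hidden_states = experts(
        hidden_states=hidden_states,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
    )

    if overlap:
        # Re-join before anything on this stream consumes sbo_shared_output.
        torch.cuda.current_stream().wait_stream(alt_stream)
    return final_hidden_states, sbo_shared_output

Under this reading, SboFlags.fuse_shared_experts_inside_sbo() decides which side computes the shared experts: when the flag is on, the caller skips _forward_shared_experts up front and takes the helper's second return value instead, which is what the "if sbo_shared_output is not None" branch in the hunk above does. The shared-expert GEMMs can then presumably hide behind the routed experts' dispatch and grouped GEMMs within a single batch, the single-batch counterpart to the two_batch_overlap machinery already imported in this file.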