Support overlapping two batches (#4068)

This commit is contained in:
fzyzcjy
2025-05-25 08:39:07 +08:00
committed by GitHub
parent f456037396
commit 0d47788025
13 changed files with 1145 additions and 129 deletions

View File

@@ -83,8 +83,10 @@ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchI
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.operations import execute_operations
from sglang.srt.operations_strategy import compute_layer_operations
from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
)
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
@@ -226,6 +228,7 @@ class DeepseekV2MoE(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
self.config = config
self.layer_id = layer_id
if self.tp_size > config.n_routed_experts:
@@ -300,7 +303,7 @@ class DeepseekV2MoE(nn.Module):
else None
)
self.deepep_dispatcher = DeepEPDispatcher(
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
@@ -309,13 +312,11 @@ class DeepseekV2MoE(nn.Module):
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
async_finish=True, # TODO
async_finish=True,
return_recv_hook=True,
)
# Read-only accessor: returns the "enable_deepep_moe" entry of
# global_server_args_dict on every access.
# NOTE(review): this text looks like a rendered diff with the +/- markers and
# indentation stripped. The assignment to `self._enable_deepep_moe` that
# follows this property cannot coexist with a setter-less property of the same
# name, so this property is presumably the removed (pre-change) form -- confirm
# against the repository.
@property
def _enable_deepep_moe(self):
return global_server_args_dict["enable_deepep_moe"]
self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
def get_moe_weights(self):
return [
@@ -423,7 +424,7 @@ class DeepseekV2MoE(nn.Module):
return None
def op_gate(self, state):
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
if is_non_idle_and_non_empty(
state.forward_batch.forward_mode, state.hidden_states_mlp_input
):
# router_logits: (num_tokens, n_experts)
@@ -432,115 +433,105 @@ class DeepseekV2MoE(nn.Module):
state.router_logits = None
# Sets `state.shared_output`: the result of `self.shared_experts(...)` on the
# MLP input when `n_share_experts_fusion == 0` and the guard condition holds,
# otherwise None.
# NOTE(review): this span appears to be a rendered diff with +/- markers and
# indentation stripped, so two versions of the method are interleaved:
#   * one guards with `(not self._enable_deepep_moe) or is_non_idle_and_non_empty(...)`
#     and reads `state.hidden_states_mlp_input` without removing it;
#   * the other first pops "hidden_states_mlp_input" from `state` and guards
#     only with `is_non_idle_and_non_empty(...)`.
# Presumably the `_enable_deepep_moe` form is the removed one -- confirm before
# treating either as current.
def op_shared_experts(self, state):
if (self.n_share_experts_fusion == 0) and (
(not self._enable_deepep_moe)
or is_non_idle_and_non_empty(
state.forward_batch.forward_mode, state.hidden_states_mlp_input
)
# Variant that takes ownership of the MLP input by popping it from `state`.
hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
if (self.n_share_experts_fusion == 0) and is_non_idle_and_non_empty(
state.forward_batch.forward_mode, hidden_states_mlp_input
):
state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
state.shared_output = self.shared_experts(hidden_states_mlp_input)
else:
state.shared_output = None
# Fills `state.topk_weights_local` and `state.topk_idx_local`.
#   * When `router_logits` is not None, both come from `select_experts(...)`
#     called with grouped top-k and this layer's routing parameters.
#   * Otherwise placeholders with zero rows are stored: an int tensor of shape
#     (0, top_k) filled with -1 for the indices and an uninitialised float32
#     tensor of shape (0, top_k) for the weights, both on `hidden_states.device`.
# NOTE(review): this span appears to be a rendered diff with +/- markers and
# indentation stripped. The if/else body occurs twice -- once under
# `if self._enable_deepep_moe:` and once unconditionally -- and `router_logits`
# is bound twice (plain attribute read vs. `state.pop`). Presumably the
# `_enable_deepep_moe`-guarded copy and the non-popping read are the removed
# lines -- confirm against the repository.
def op_select_experts(self, state):
router_logits = state.router_logits
router_logits = state.pop("router_logits")
hidden_states = state.hidden_states_mlp_input
# First copy of the body (nested under the DeepEP flag).
if self._enable_deepep_moe:
if router_logits is not None:
state.topk_weights_local, state.topk_idx_local = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
else:
state.topk_idx_local = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
state.topk_weights_local = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
# Second copy of the body (no DeepEP flag check).
if router_logits is not None:
state.topk_weights_local, state.topk_idx_local = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
else:
state.topk_idx_local = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
state.topk_weights_local = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
# When `self.ep_size > 1`, calls `self.deepep_dispatcher.dispatch_a(...)` with
# the MLP input, the local top-k indices/weights (both popped from `state`),
# the batch's forward mode, and `state.get("tbo_subbatch_index")`.
# Nothing is stored back on `state` here.
# NOTE(review): this span appears to be a rendered diff with +/- markers and
# indentation stripped. There are two alternative guards (with and without
# `self._enable_deepep_moe`) and two alternative `hidden_states=` arguments
# (popping "hidden_states_mlp_input" vs. reading the attribute in place).
# Presumably the first of each pair is the removed line -- confirm.
def op_dispatch_a(self, state):
if self._enable_deepep_moe and (self.ep_size > 1):
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
self.deepep_dispatcher.dispatch_a(
hidden_states=state.pop("hidden_states_mlp_input"),
hidden_states=state.hidden_states_mlp_input,
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
forward_mode=state.forward_batch.forward_mode,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
# When `self.ep_size > 1`, unpacks the eight values returned by
# `self.deepep_dispatcher.dispatch_b(...)` onto `state`:
# hidden_states_experts_input, topk_idx_dispatched, topk_weights_dispatched,
# reorder_topk_ids, num_recv_tokens_per_expert, seg_indptr, masked_m,
# expected_m.
# NOTE(review): this span appears to be a rendered diff with +/- markers and
# indentation stripped, so the unpacking occurs twice. One copy is guarded by
# `self._enable_deepep_moe` and calls `dispatch_b()` with no arguments; the
# other runs inside
# `get_global_expert_distribution_recorder().with_current_layer(self.layer_id)`
# and forwards `tbo_subbatch_index`. Presumably the first copy is the removed
# one -- confirm.
def op_dispatch_b(self, state):
if self._enable_deepep_moe and (self.ep_size > 1):
(
state.hidden_states_experts_input,
state.topk_idx_dispatched,
state.topk_weights_dispatched,
state.reorder_topk_ids,
state.num_recv_tokens_per_expert,
state.seg_indptr,
state.masked_m,
state.expected_m,
) = self.deepep_dispatcher.dispatch_b()
if self.ep_size > 1:
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
(
state.hidden_states_experts_input,
state.topk_idx_dispatched,
state.topk_weights_dispatched,
state.reorder_topk_ids,
state.num_recv_tokens_per_expert,
state.seg_indptr,
state.masked_m,
state.expected_m,
) = self.deepep_dispatcher.dispatch_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
# Runs `self.experts(...)` and stores the result in
# `state.hidden_states_experts_output`.
# The dispatched-path call pops the per-dispatch fields it consumes
# (hidden_states_experts_input, reorder_topk_ids, seg_indptr, masked_m,
# expected_m, num_recv_tokens_per_expert) but only reads
# `topk_idx_dispatched` / `topk_weights_dispatched`; op_combine_a is the
# method that pops those two.
# NOTE(review): this span appears to be a rendered diff with +/- markers and
# indentation stripped. Two versions are interleaved: a branched one
# (`if self._enable_deepep_moe:` ... `else:` calling `self.experts` with
# `hidden_states` and `router_logits`) and an unconditional one that always
# takes the dispatched path. Presumably the branched version is the removed
# one -- confirm.
def op_experts(self, state):
if self._enable_deepep_moe:
# Discards router_logits in this variant; the value is not used.
state.pop("router_logits")
state.hidden_states_experts_output = self.experts(
hidden_states=state.pop("hidden_states_experts_input"),
topk_idx=state.topk_idx_dispatched,
topk_weights=state.topk_weights_dispatched,
reorder_topk_ids=state.pop("reorder_topk_ids"),
seg_indptr=state.pop("seg_indptr"),
masked_m=state.pop("masked_m"),
expected_m=state.pop("expected_m"),
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
forward_mode=state.forward_batch.forward_mode,
)
else:
state.hidden_states_experts_output = self.experts(
hidden_states=state.pop("hidden_states_mlp_input"),
router_logits=state.pop("router_logits"),
)
# Unconditional variant of the dispatched-path call.
state.hidden_states_experts_output = self.experts(
hidden_states=state.pop("hidden_states_experts_input"),
topk_idx=state.topk_idx_dispatched,
topk_weights=state.topk_weights_dispatched,
reorder_topk_ids=state.pop("reorder_topk_ids"),
seg_indptr=state.pop("seg_indptr"),
masked_m=state.pop("masked_m"),
expected_m=state.pop("expected_m"),
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
forward_mode=state.forward_batch.forward_mode,
)
# When `self.ep_size > 1`, calls `self.deepep_dispatcher.combine_a(...)` with
# the expert output and the dispatched top-k indices/weights (all three popped
# from `state`), the batch's forward mode, and
# `state.get("tbo_subbatch_index")`. Nothing is stored back on `state` here.
# NOTE(review): this span appears to be a rendered diff with +/- markers and
# indentation stripped. There are two alternative guards (with and without
# `self._enable_deepep_moe`) and the expert output is passed twice, once
# positionally and once as `hidden_states=`. Presumably the first of each pair
# is the removed line -- confirm.
def op_combine_a(self, state):
if self._enable_deepep_moe and (self.ep_size > 1):
if self.ep_size > 1:
self.deepep_dispatcher.combine_a(
state.pop("hidden_states_experts_output"),
hidden_states=state.pop("hidden_states_experts_output"),
topk_idx=state.pop("topk_idx_dispatched"),
topk_weights=state.pop("topk_weights_dispatched"),
forward_mode=state.forward_batch.forward_mode,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
# When `self.ep_size > 1`, stores the return value of
# `self.deepep_dispatcher.combine_b(...)` in
# `state.hidden_states_after_combine`. No assignment happens otherwise.
# NOTE(review): this span appears to be a rendered diff with +/- markers and
# indentation stripped. One variant is guarded by `self._enable_deepep_moe`
# and calls `combine_b()` with no arguments; the other forwards
# `tbo_subbatch_index`. Presumably the first variant is the removed one --
# confirm.
def op_combine_b(self, state):
if self._enable_deepep_moe and (self.ep_size > 1):
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
if self.ep_size > 1:
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
# Produces `state.hidden_states_mlp_output`: takes the routed-expert result
# from `state`, scales it by `self.routed_scaling_factor`, and adds the
# shared-expert output when that is not None.
# NOTE(review): this span appears to be a rendered diff with +/- markers and
# indentation stripped. `final_hidden_states` is bound twice: a conditional
# expression choosing between "hidden_states_after_combine" and
# "hidden_states_experts_output" based on `self._enable_deepep_moe`, and a
# plain pop of "hidden_states_after_combine". The
# `tensor_model_parallel_all_reduce` step guarded by
# `(not self._enable_deepep_moe) and (self.tp_size > 1)` presumably belongs to
# the removed version as well -- confirm.
def op_output(self, state):
final_hidden_states = (
state.pop("hidden_states_after_combine")
if self._enable_deepep_moe
else state.pop("hidden_states_experts_output")
)
final_hidden_states = state.pop("hidden_states_after_combine")
# In-place multiply: this mutates the tensor object that was popped from state.
final_hidden_states *= self.routed_scaling_factor
# "shared_output" is always popped; it is only added when it is not None.
if (s := state.pop("shared_output")) is not None:
final_hidden_states = final_hidden_states + s
if (not self._enable_deepep_moe) and (self.tp_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
state.hidden_states_mlp_output = final_hidden_states
@@ -1482,6 +1473,7 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
tbo_subbatch_index: Optional[int] = None,
):
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
@@ -1491,6 +1483,7 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch=forward_batch,
positions=positions,
zero_allocator=zero_allocator,
tbo_subbatch_index=tbo_subbatch_index,
)
)
@@ -1523,8 +1516,24 @@ class DeepseekV2DecoderLayer(nn.Module):
state.forward_batch,
)
state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
return hidden_states, residual
output = dict(
positions=state.positions,
hidden_states=hidden_states,
residual=residual,
forward_batch=state.forward_batch,
zero_allocator=state.zero_allocator,
tbo_subbatch_index=state.tbo_subbatch_index,
)
state.clear(
expect_keys={
"positions",
"forward_batch",
"zero_allocator",
"tbo_subbatch_index",
}
)
return output
class DeepseekV2Model(nn.Module):
@@ -1539,6 +1548,7 @@ class DeepseekV2Model(nn.Module):
super().__init__()
self.padding_id = config.pad_token_id
self.vocab_size = config.vocab_size
self.first_k_dense_replace = config.first_k_dense_replace
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
@@ -1572,13 +1582,12 @@ class DeepseekV2Model(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
total_num_layers = len(self.layers)
device = input_embeds.device if input_embeds is not None else input_ids.device
zero_allocator = BumpAllocator(
# TODO for two-batch-overlap, we need a larger buffer size
buffer_size=len(self.layers) * 2,
buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
dtype=torch.float32,
device=(
input_embeds.device if input_embeds is not None else input_ids.device
),
device=device,
)
if input_embeds is None:
@@ -1587,12 +1596,30 @@ class DeepseekV2Model(nn.Module):
hidden_states = input_embeds
residual = None
for i in range(len(self.layers)):
normal_num_layers = (
self.first_k_dense_replace
if forward_batch.can_run_tbo
else total_num_layers
)
for i in range(normal_num_layers):
with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual, zero_allocator
)
if normal_num_layers != total_num_layers:
hidden_states, residual = model_forward_maybe_tbo(
layers=self.layers[normal_num_layers:],
enable_tbo=True,
positions=positions,
forward_batch=forward_batch,
hidden_states=hidden_states,
residual=residual,
zero_allocator=zero_allocator,
)
if not forward_batch.forward_mode.is_idle():
if residual is None:
hidden_states = self.norm(hidden_states)
@@ -1674,7 +1701,6 @@ class DeepseekV2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(