Support overlapping two batches (#4068)
This commit is contained in:
@@ -83,8 +83,10 @@ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchI
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.operations import execute_operations
|
||||
from sglang.srt.operations_strategy import compute_layer_operations
|
||||
from sglang.srt.two_batch_overlap import (
|
||||
MaybeTboDeepEPDispatcher,
|
||||
model_forward_maybe_tbo,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
BumpAllocator,
|
||||
DeepEPMode,
|
||||
@@ -226,6 +228,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_shared_experts = config.n_shared_experts
|
||||
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||
self.config = config
|
||||
self.layer_id = layer_id
|
||||
|
||||
if self.tp_size > config.n_routed_experts:
|
||||
@@ -300,7 +303,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
else None
|
||||
)
|
||||
|
||||
self.deepep_dispatcher = DeepEPDispatcher(
|
||||
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
||||
group=parallel_state.get_tp_group().device_group,
|
||||
router_topk=self.top_k,
|
||||
permute_fusion=True,
|
||||
@@ -309,13 +312,11 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_size=config.hidden_size,
|
||||
params_dtype=config.torch_dtype,
|
||||
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
||||
async_finish=True, # TODO
|
||||
async_finish=True,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _enable_deepep_moe(self):
|
||||
return global_server_args_dict["enable_deepep_moe"]
|
||||
self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
|
||||
|
||||
def get_moe_weights(self):
|
||||
return [
|
||||
@@ -423,7 +424,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
return None
|
||||
|
||||
def op_gate(self, state):
|
||||
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
|
||||
if is_non_idle_and_non_empty(
|
||||
state.forward_batch.forward_mode, state.hidden_states_mlp_input
|
||||
):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
@@ -432,115 +433,105 @@ class DeepseekV2MoE(nn.Module):
|
||||
state.router_logits = None
|
||||
|
||||
def op_shared_experts(self, state):
|
||||
if (self.n_share_experts_fusion == 0) and (
|
||||
(not self._enable_deepep_moe)
|
||||
or is_non_idle_and_non_empty(
|
||||
state.forward_batch.forward_mode, state.hidden_states_mlp_input
|
||||
)
|
||||
hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
|
||||
if (self.n_share_experts_fusion == 0) and is_non_idle_and_non_empty(
|
||||
state.forward_batch.forward_mode, hidden_states_mlp_input
|
||||
):
|
||||
state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
|
||||
state.shared_output = self.shared_experts(hidden_states_mlp_input)
|
||||
else:
|
||||
state.shared_output = None
|
||||
|
||||
def op_select_experts(self, state):
|
||||
router_logits = state.router_logits
|
||||
router_logits = state.pop("router_logits")
|
||||
hidden_states = state.hidden_states_mlp_input
|
||||
|
||||
if self._enable_deepep_moe:
|
||||
if router_logits is not None:
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
state.topk_idx_local = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
)
|
||||
state.topk_weights_local = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
if router_logits is not None:
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
state.topk_idx_local = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
)
|
||||
state.topk_weights_local = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
|
||||
def op_dispatch_a(self, state):
|
||||
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||
if self.ep_size > 1:
|
||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||
self.deepep_dispatcher.dispatch_a(
|
||||
hidden_states=state.pop("hidden_states_mlp_input"),
|
||||
hidden_states=state.hidden_states_mlp_input,
|
||||
topk_idx=state.pop("topk_idx_local"),
|
||||
topk_weights=state.pop("topk_weights_local"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_dispatch_b(self, state):
|
||||
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||
(
|
||||
state.hidden_states_experts_input,
|
||||
state.topk_idx_dispatched,
|
||||
state.topk_weights_dispatched,
|
||||
state.reorder_topk_ids,
|
||||
state.num_recv_tokens_per_expert,
|
||||
state.seg_indptr,
|
||||
state.masked_m,
|
||||
state.expected_m,
|
||||
) = self.deepep_dispatcher.dispatch_b()
|
||||
if self.ep_size > 1:
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
(
|
||||
state.hidden_states_experts_input,
|
||||
state.topk_idx_dispatched,
|
||||
state.topk_weights_dispatched,
|
||||
state.reorder_topk_ids,
|
||||
state.num_recv_tokens_per_expert,
|
||||
state.seg_indptr,
|
||||
state.masked_m,
|
||||
state.expected_m,
|
||||
) = self.deepep_dispatcher.dispatch_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_experts(self, state):
|
||||
if self._enable_deepep_moe:
|
||||
state.pop("router_logits")
|
||||
state.hidden_states_experts_output = self.experts(
|
||||
hidden_states=state.pop("hidden_states_experts_input"),
|
||||
topk_idx=state.topk_idx_dispatched,
|
||||
topk_weights=state.topk_weights_dispatched,
|
||||
reorder_topk_ids=state.pop("reorder_topk_ids"),
|
||||
seg_indptr=state.pop("seg_indptr"),
|
||||
masked_m=state.pop("masked_m"),
|
||||
expected_m=state.pop("expected_m"),
|
||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
)
|
||||
else:
|
||||
state.hidden_states_experts_output = self.experts(
|
||||
hidden_states=state.pop("hidden_states_mlp_input"),
|
||||
router_logits=state.pop("router_logits"),
|
||||
)
|
||||
state.hidden_states_experts_output = self.experts(
|
||||
hidden_states=state.pop("hidden_states_experts_input"),
|
||||
topk_idx=state.topk_idx_dispatched,
|
||||
topk_weights=state.topk_weights_dispatched,
|
||||
reorder_topk_ids=state.pop("reorder_topk_ids"),
|
||||
seg_indptr=state.pop("seg_indptr"),
|
||||
masked_m=state.pop("masked_m"),
|
||||
expected_m=state.pop("expected_m"),
|
||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
)
|
||||
|
||||
def op_combine_a(self, state):
|
||||
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||
if self.ep_size > 1:
|
||||
self.deepep_dispatcher.combine_a(
|
||||
state.pop("hidden_states_experts_output"),
|
||||
hidden_states=state.pop("hidden_states_experts_output"),
|
||||
topk_idx=state.pop("topk_idx_dispatched"),
|
||||
topk_weights=state.pop("topk_weights_dispatched"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_combine_b(self, state):
|
||||
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
|
||||
if self.ep_size > 1:
|
||||
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_output(self, state):
|
||||
final_hidden_states = (
|
||||
state.pop("hidden_states_after_combine")
|
||||
if self._enable_deepep_moe
|
||||
else state.pop("hidden_states_experts_output")
|
||||
)
|
||||
|
||||
final_hidden_states = state.pop("hidden_states_after_combine")
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
|
||||
if (s := state.pop("shared_output")) is not None:
|
||||
final_hidden_states = final_hidden_states + s
|
||||
|
||||
if (not self._enable_deepep_moe) and (self.tp_size > 1):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
state.hidden_states_mlp_output = final_hidden_states
|
||||
|
||||
|
||||
@@ -1482,6 +1473,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
zero_allocator: BumpAllocator,
|
||||
tbo_subbatch_index: Optional[int] = None,
|
||||
):
|
||||
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
|
||||
self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
|
||||
@@ -1491,6 +1483,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch=forward_batch,
|
||||
positions=positions,
|
||||
zero_allocator=zero_allocator,
|
||||
tbo_subbatch_index=tbo_subbatch_index,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1523,8 +1516,24 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
state.forward_batch,
|
||||
)
|
||||
|
||||
state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
|
||||
return hidden_states, residual
|
||||
output = dict(
|
||||
positions=state.positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
forward_batch=state.forward_batch,
|
||||
zero_allocator=state.zero_allocator,
|
||||
tbo_subbatch_index=state.tbo_subbatch_index,
|
||||
)
|
||||
|
||||
state.clear(
|
||||
expect_keys={
|
||||
"positions",
|
||||
"forward_batch",
|
||||
"zero_allocator",
|
||||
"tbo_subbatch_index",
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class DeepseekV2Model(nn.Module):
|
||||
@@ -1539,6 +1548,7 @@ class DeepseekV2Model(nn.Module):
|
||||
super().__init__()
|
||||
self.padding_id = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.first_k_dense_replace = config.first_k_dense_replace
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
@@ -1572,13 +1582,12 @@ class DeepseekV2Model(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
total_num_layers = len(self.layers)
|
||||
device = input_embeds.device if input_embeds is not None else input_ids.device
|
||||
zero_allocator = BumpAllocator(
|
||||
# TODO for two-batch-overlap, we need a larger buffer size
|
||||
buffer_size=len(self.layers) * 2,
|
||||
buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
|
||||
dtype=torch.float32,
|
||||
device=(
|
||||
input_embeds.device if input_embeds is not None else input_ids.device
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
|
||||
if input_embeds is None:
|
||||
@@ -1587,12 +1596,30 @@ class DeepseekV2Model(nn.Module):
|
||||
hidden_states = input_embeds
|
||||
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
|
||||
normal_num_layers = (
|
||||
self.first_k_dense_replace
|
||||
if forward_batch.can_run_tbo
|
||||
else total_num_layers
|
||||
)
|
||||
for i in range(normal_num_layers):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||
)
|
||||
|
||||
if normal_num_layers != total_num_layers:
|
||||
hidden_states, residual = model_forward_maybe_tbo(
|
||||
layers=self.layers[normal_num_layers:],
|
||||
enable_tbo=True,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
zero_allocator=zero_allocator,
|
||||
)
|
||||
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
if residual is None:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@@ -1674,7 +1701,6 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
|
||||
return self.logits_processor(
|
||||
|
||||
Reference in New Issue
Block a user