Refactor DeepSeek logic into atomic operations (#6326)

fzyzcjy
2025-05-20 12:05:30 +08:00
committed by GitHub
parent 17d080b7ae
commit d0443275f0
3 changed files with 343 additions and 82 deletions

@@ -51,7 +51,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -83,6 +83,8 @@ from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.operations import execute_operations
from sglang.srt.operations_strategy import compute_layer_operations
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
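
Note: the two imports added above, sglang.srt.operations and sglang.srt.operations_strategy, are the other files touched by this commit and are not shown in this excerpt. Below is a minimal sketch of the contract the new op_* methods appear to rely on, assuming a runner that threads one shared mutable state through an ordered list of operations, hands the original forward() inputs to the first operation only, and returns the last operation's result (the real runner is presumably richer, e.g. able to interleave the operations of two micro-batches):

# Minimal sketch; an assumption, not the real sglang.srt.operations module.
from typing import Any, Callable, Dict, List, Sequence


class _State:
    """Attribute-style container shared by every operation of one layer."""

    def pop(self, name: str) -> Any:
        # Popping drops the reference once the last consumer has run, so large
        # intermediate activations can be freed as early as possible.
        return self.__dict__.pop(name)

    def update(self, values: Dict[str, Any]) -> None:
        self.__dict__.update(values)

    def clear(self, expect_keys: Sequence[str]) -> None:
        # Guards against state fields leaking from one layer into the next.
        assert set(self.__dict__) == set(expect_keys), sorted(self.__dict__)
        self.__dict__.clear()


def execute_operations(inputs: Dict[str, Any], operations: List[Callable]) -> Any:
    state = _State()
    output = None
    for i, op in enumerate(operations):
        # Only the first operation receives the raw forward() inputs; all later
        # operations communicate through named fields on the shared state.
        result = op(state, **inputs) if i == 0 else op(state)
        if result is not None:
            output = result
    return output
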
@@ -304,99 +306,123 @@ class DeepseekV2MoE(nn.Module):
def _enable_deepep_moe(self):
return global_server_args_dict["enable_deepep_moe"]
def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
forward_mode = forward_batch.forward_mode
def op_gate(self, state):
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
forward_mode, hidden_states
state.forward_batch.forward_mode, state.hidden_states_mlp_input
):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
state.router_logits = self.gate(state.hidden_states_mlp_input)
else:
router_logits = None
state.router_logits = None
def op_shared_experts(self, state):
if (self.n_share_experts_fusion == 0) and (
(not self._enable_deepep_moe)
or is_non_idle_and_non_empty(forward_mode, hidden_states)
or is_non_idle_and_non_empty(
state.forward_batch.forward_mode, state.hidden_states_mlp_input
)
):
shared_output = self.shared_experts(hidden_states)
state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
else:
shared_output = None
state.shared_output = None
if self._enable_deepep_moe and (router_logits is not None):
topk_weights, topk_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=forward_batch.num_token_non_padded,
)
else:
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if self._enable_deepep_moe and (self.ep_size > 1):
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
(
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
num_recv_tokens_per_expert,
seg_indptr,
masked_m,
expected_m,
) = self.deepep_dispatcher.dispatch(
hidden_states,
topk_idx,
topk_weights,
forward_mode=forward_mode,
)
def op_select_experts(self, state):
router_logits = state.router_logits
hidden_states = state.hidden_states_mlp_input
if self._enable_deepep_moe:
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr,
masked_m=masked_m,
expected_m=expected_m,
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_mode=forward_mode,
)
else:
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
if router_logits is not None:
state.topk_weights_local, state.topk_idx_local = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
else:
state.topk_idx_local = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
state.topk_weights_local = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
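
When DeepEP is enabled but the rank is idle for this batch, op_gate leaves router_logits as None and the branch above builds zero-row placeholders so the downstream dispatch still receives well-formed tensors. A small, self-contained illustration of the shapes involved (hypothetical values, not from this diff):

import torch

top_k = 8
# Active rank: select_experts yields one row per token.
num_tokens = 4
topk_idx = torch.randint(0, 256, (num_tokens, top_k), dtype=torch.int)  # chosen expert ids
topk_weights = torch.rand(num_tokens, top_k, dtype=torch.float32)       # routing weights
# Idle rank: zero-row placeholders keep the DeepEP dispatch path well-typed.
topk_idx_local = torch.full((0, top_k), -1, dtype=torch.int)
topk_weights_local = torch.empty((0, top_k), dtype=torch.float32)
assert topk_idx_local.shape == (0, top_k) and topk_weights_local.numel() == 0
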
def op_dispatch_a(self, state):
if self._enable_deepep_moe and (self.ep_size > 1):
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
self.deepep_dispatcher.dispatch_a(
hidden_states=state.pop("hidden_states_mlp_input"),
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
forward_mode=state.forward_batch.forward_mode,
)
def op_dispatch_b(self, state):
if self._enable_deepep_moe and (self.ep_size > 1):
final_hidden_states = self.deepep_dispatcher.combine(
final_hidden_states,
topk_idx,
topk_weights,
forward_mode,
(
state.hidden_states_experts_input,
state.topk_idx_dispatched,
state.topk_weights_dispatched,
state.reorder_topk_ids,
state.num_recv_tokens_per_expert,
state.seg_indptr,
state.masked_m,
state.expected_m,
) = self.deepep_dispatcher.dispatch_b()
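
Dispatch (and, further down, combine) is now split into an "a" step that launches the DeepEP all-to-all and a "b" step that waits for it and unpacks the result, leaving room for the runner to schedule other work in between (for example, operations of another micro-batch). The split API of DeepEPDispatcher is not shown in this commit; here is a toy, runnable sketch of the pattern, with a thread pool standing in for the asynchronous communication:

from concurrent.futures import ThreadPoolExecutor


class ToyTwoPhaseDispatcher:
    # Illustration only; the real DeepEPDispatcher issues a non-blocking
    # all-to-all in the "a" phase and waits on its handle in the "b" phase.
    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._inflight = None

    def dispatch_a(self, hidden_states, topk_idx, topk_weights):
        # Launch the "communication" without blocking the caller.
        self._inflight = self._pool.submit(
            lambda: (hidden_states, topk_idx, topk_weights)
        )

    def dispatch_b(self):
        # Wait for the in-flight work and hand back the routed payload.
        return self._inflight.result()


dispatcher = ToyTwoPhaseDispatcher()
dispatcher.dispatch_a("h", "idx", "w")
# ...other operations could run here while the all-to-all is in flight...
print(dispatcher.dispatch_b())  # -> ('h', 'idx', 'w')
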
def op_experts(self, state):
if self._enable_deepep_moe:
state.pop("router_logits")
state.hidden_states_experts_output = self.experts(
hidden_states=state.pop("hidden_states_experts_input"),
topk_idx=state.topk_idx_dispatched,
topk_weights=state.topk_weights_dispatched,
reorder_topk_ids=state.pop("reorder_topk_ids"),
seg_indptr=state.pop("seg_indptr"),
masked_m=state.pop("masked_m"),
expected_m=state.pop("expected_m"),
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
forward_mode=state.forward_batch.forward_mode,
)
else:
state.hidden_states_experts_output = self.experts(
hidden_states=state.pop("hidden_states_mlp_input"),
router_logits=state.pop("router_logits"),
)
def op_combine_a(self, state):
if self._enable_deepep_moe and (self.ep_size > 1):
self.deepep_dispatcher.combine_a(
state.pop("hidden_states_experts_output"),
topk_idx=state.pop("topk_idx_dispatched"),
topk_weights=state.pop("topk_weights_dispatched"),
forward_mode=state.forward_batch.forward_mode,
)
def op_combine_b(self, state):
if self._enable_deepep_moe and (self.ep_size > 1):
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
def op_output(self, state):
final_hidden_states = (
state.pop("hidden_states_after_combine")
if self._enable_deepep_moe
else state.pop("hidden_states_experts_output")
)
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if (s := state.pop("shared_output")) is not None:
final_hidden_states = final_hidden_states + s
if (not self._enable_deepep_moe) and (self.tp_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
state.hidden_states_mlp_output = final_hidden_states
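
Read top to bottom, the MoE forward pass is now a pipeline of small operations that communicate only through named state fields, each popping the fields it is the last consumer of. A summary of the DeepEP path, transcribed from the code above:

# State fields read/popped and produced by each DeepseekV2MoE operation
# on the DeepEP path, as read from the code above.
MOE_OP_STATE_FLOW = {
    "op_gate":           "hidden_states_mlp_input -> router_logits",
    "op_shared_experts": "hidden_states_mlp_input -> shared_output",
    "op_select_experts": "router_logits -> topk_idx_local, topk_weights_local",
    "op_dispatch_a":     "pops hidden_states_mlp_input, topk_*_local; launches all-to-all",
    "op_dispatch_b":     "-> hidden_states_experts_input, topk_*_dispatched, routing metadata",
    "op_experts":        "pops experts input and metadata -> hidden_states_experts_output",
    "op_combine_a":      "pops hidden_states_experts_output, topk_*_dispatched; launches all-to-all",
    "op_combine_b":      "-> hidden_states_after_combine",
    "op_output":         "scales, adds shared_output -> hidden_states_mlp_output",
}
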
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
@@ -1197,27 +1223,77 @@ class DeepseekV2DecoderLayer(nn.Module):
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
return execute_operations(
inputs=dict(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
residual=residual,
zero_allocator=zero_allocator,
),
operations=compute_layer_operations(self),
)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
def op_comm_prepare_attn(
self,
state,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
):
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
)
state.update(
dict(
forward_batch=forward_batch,
positions=positions,
zero_allocator=zero_allocator,
)
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
def op_attn(self, state):
state.hidden_states_after_attn = self.self_attn(
positions=state.positions,
hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
forward_batch=state.forward_batch,
zero_allocator=state.zero_allocator,
)
hidden_states = self.mlp(hidden_states, forward_batch)
def op_comm_prepare_mlp(self, state):
state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
self.layer_communicator.prepare_mlp(
state.pop("hidden_states_after_attn"),
state.pop("residual_after_input_ln"),
state.forward_batch,
)
)
def op_mlp(self, state):
hidden_states = state.pop("hidden_states_mlp_input")
if not (
enable_moe_dense_fully_dp()
and (not self.is_layer_sparse)
and hidden_states.shape[0] == 0
):
state.hidden_states_mlp_output = self.mlp(
hidden_states, state.forward_batch.forward_mode
)
else:
state.hidden_states_mlp_output = hidden_states
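
op_mlp keeps the dense-MLP corner case explicit: when dense layers run fully data-parallel, a DP rank holding zero tokens skips the MLP and forwards its empty hidden_states unchanged. A restatement of the guard as a hypothetical helper (for illustration only, not part of this diff):

def should_run_dense_mlp(dense_fully_dp: bool, is_layer_sparse: bool, num_tokens: int) -> bool:
    # Mirrors the condition in op_mlp above.
    return not (dense_fully_dp and not is_layer_sparse and num_tokens == 0)


assert should_run_dense_mlp(True, False, 0) is False   # idle DP rank, dense layer: skip
assert should_run_dense_mlp(True, False, 16) is True   # tokens present: run
assert should_run_dense_mlp(True, True, 0) is True     # sparse (MoE) layer: never skipped here
assert should_run_dense_mlp(False, False, 0) is True   # feature disabled: always run
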
def op_comm_postprocess_layer(self, state):
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
state.pop("hidden_states_mlp_output"),
state.pop("residual_after_comm_pre_mlp"),
state.forward_batch,
)
state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
return hidden_states, residual
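
forward() now only packs its arguments and delegates to execute_operations with the list returned by compute_layer_operations(self). That strategy module is the third file in this commit and is not shown in this excerpt; below is a rough sketch of the ordering it could produce for one decoder layer, assuming it simply chains the layer's and its MoE's op_* methods (the real implementation may gate or reorder them, e.g. to overlap two micro-batches):

# Rough sketch only; the real sglang.srt.operations_strategy may differ.
def compute_layer_operations(layer):
    ops = [
        layer.op_comm_prepare_attn,
        layer.op_attn,
        layer.op_comm_prepare_mlp,
    ]
    if layer.is_layer_sparse:
        moe = layer.mlp  # DeepseekV2MoE
        ops += [
            moe.op_gate,
            moe.op_shared_experts,
            moe.op_select_experts,
            moe.op_dispatch_a,
            moe.op_dispatch_b,
            moe.op_experts,
            moe.op_combine_a,
            moe.op_combine_b,
            moe.op_output,
        ]
    else:
        ops.append(layer.op_mlp)
    return ops + [layer.op_comm_postprocess_layer]
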