Expert distribution recording without overhead for EPLB (#4957)

Author: fzyzcjy
Date: 2025-05-20 11:07:43 +08:00
Committed by: GitHub
Parent: b146555749
Commit: f0653886a5
12 changed files with 1123 additions and 194 deletions

File: python/sglang/srt/models/deepseek_v2.py

@@ -77,7 +77,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -109,8 +113,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )

-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)
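
The module-level `expert_distribution_recorder = ExpertDistributionRecorder()` instance is removed; callers now go through `get_global_expert_distribution_recorder()`, so every model file shares one recorder. The accessor's implementation is not part of this hunk; a minimal sketch of the pattern, assuming lazy initialization (illustrative, not the committed code):

    from typing import Optional

    class ExpertDistributionRecorder:
        """Records which experts each token is routed to (internals elided)."""

    _global_recorder: Optional[ExpertDistributionRecorder] = None

    def get_global_expert_distribution_recorder() -> ExpertDistributionRecorder:
        # One process-wide recorder instead of a per-module instance, so
        # recording can be toggled centrally without touching model files.
        global _global_recorder
        if _global_recorder is None:
            _global_recorder = ExpertDistributionRecorder()
        return _global_recorder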
@@ -302,6 +304,7 @@ class DeepseekV2MoE(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
+        forward_mode = forward_batch.forward_mode
         if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
             forward_mode, hidden_states
         ):
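
`DeepseekV2MoE.forward` now derives `forward_mode` from the `ForwardBatch` it receives rather than being handed the mode directly. The `is_non_idle_and_non_empty` guard it calls is not defined in this diff; a plausible implementation consistent with the call site (an assumption):

    def is_non_idle_and_non_empty(forward_mode, hidden_states) -> bool:
        # Only treat the batch as real work when a forward mode is set,
        # it is not the idle mode, and there is at least one token.
        return (
            forward_mode is not None
            and not forward_mode.is_idle()
            and hidden_states.shape[0] > 0
        )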
@@ -1278,7 +1281,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

         # Fully Connected
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch)

         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
@@ -1422,11 +1425,11 @@ class DeepseekV2Model(nn.Module):
         residual = None
         for i in range(len(self.layers)):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual, zero_allocator
+                )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
                 hidden_states = self.norm(hidden_states)
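
Setting the current layer becomes a scoped operation: `with_current_layer(i)` tags everything recorded inside the block and restores state on exit even if the layer raises, which the old `set_current_layer(i)` call could not guarantee. The recorder internals are outside this diff; a sketch of the context-manager shape (the `_current_layer_idx` field is hypothetical):

    from contextlib import contextmanager
    from typing import Optional

    class ExpertDistributionRecorder:
        def __init__(self) -> None:
            self._current_layer_idx: Optional[int] = None

        @contextmanager
        def with_current_layer(self, layer_idx: int):
            # Attribute expert-routing events inside the block to this
            # layer, and always clear the tag when the block exits.
            self._current_layer_idx = layer_idx
            try:
                yield
            finally:
                self._current_layer_idx = None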
@@ -1872,6 +1875,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()

+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.n_routed_experts,
+            num_groups=config.n_group,
+        )
+

 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
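
The new classmethod lets the expert-location machinery query a model's expert topology (layer count, routed experts, expert groups) from its Hugging Face config without instantiating weights; `DeepseekV3ForCausalLM` inherits it unchanged. The fields filled in above suggest `ModelConfigForExpertLocation` is a small record type; a sketch under that assumption:

    from dataclasses import dataclass

    @dataclass
    class ModelConfigForExpertLocation:
        num_layers: int           # hidden layers that may contain MoE blocks
        num_logical_experts: int  # routed experts per MoE layer
        num_groups: int           # expert groups for group-limited routing

Usage mirrors the classmethod: `DeepseekV2ForCausalLM.get_model_config_for_expert_location(config)` returns one of these records built from `config.num_hidden_layers`, `config.n_routed_experts`, and `config.n_group`.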