Expert distribution recording without overhead for EPLB (#4957)
This commit is contained in:
@@ -77,7 +77,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
ExpertDistributionRecorder,
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
@@ -109,8 +113,6 @@ if _is_hip:
|
||||
decode_attention_fwd_grouped_rope,
|
||||
)
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -302,6 +304,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
||||
) -> torch.Tensor:
|
||||
forward_mode = forward_batch.forward_mode
|
||||
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
|
||||
forward_mode, hidden_states
|
||||
):
|
||||
@@ -1278,7 +1281,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states, forward_batch)
|
||||
|
||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||
# Scatter
|
||||
@@ -1422,11 +1425,11 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
expert_distribution_recorder.set_current_layer(i)
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||
)
|
||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||
)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
if residual is None:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@@ -1872,6 +1875,14 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@classmethod
|
||||
def get_model_config_for_expert_location(cls, config):
|
||||
return ModelConfigForExpertLocation(
|
||||
num_layers=config.num_hidden_layers,
|
||||
num_logical_experts=config.n_routed_experts,
|
||||
num_groups=config.n_group,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user