Support recording experts workload in QWen2-MoE (#4775)
This commit is contained in:
@@ -44,10 +44,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.utils import ExpertDistributionRecorder
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import add_prefix
|
from sglang.srt.utils import add_prefix
|
||||||
|
|
||||||
|
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||||
|
|
||||||
|
|
||||||
class Qwen2MoeMLP(nn.Module):
|
class Qwen2MoeMLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -366,6 +369,7 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
hidden_states = input_embeds
|
hidden_states = input_embeds
|
||||||
residual = None
|
residual = None
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
|
expert_distribution_recorder.set_current_layer(i)
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions, hidden_states, forward_batch, residual
|
positions, hidden_states, forward_batch, residual
|
||||||
|
|||||||
Reference in New Issue
Block a user