diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index 1b06ce6e5..e63f745ba 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -64,7 +64,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers
+from sglang.srt.utils import (
+    LazyValue,
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    make_layers,
+)
 
 _is_cuda = is_cuda()
 _is_flashinfer_available = is_flashinfer_available()
@@ -655,6 +661,18 @@ class GptOssForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
 
+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+                for layer_id in range(self.start_layer, self.end_layer)
+                if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
     @torch.no_grad()
     def forward(
         self,
@@ -1138,12 +1156,6 @@ class GptOssForCausalLM(nn.Module):
             else:
                 logging.info("All parameters loaded successfully.")
 
-        self.routed_experts_weights_of_layer = {
-            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
-            for layer_id in range(self.start_layer, self.end_layer)
-            if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
-        }
-
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
 