From f1e9bbaff504975378e39a37c80158c7e8f2610f Mon Sep 17 00:00:00 2001 From: JiLi Date: Sat, 30 Aug 2025 02:19:26 +0800 Subject: [PATCH] feat: Add flexible validation for partial weight updates (#9663) Co-authored-by: RichardW Co-authored-by: Zhuorany Co-authored-by: Stefan He Co-authored-by: Yineng Zhang Co-authored-by: Night <32424487+PrinsYin@users.noreply.github.com> Co-authored-by: zhaochenyang20 Co-authored-by: Liangsheng Yin --- python/sglang/srt/models/gpt_oss.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 35c42d26e..27b49f4ec 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -1029,10 +1029,6 @@ class GptOssForCausalLM(nn.Module): ) params_dict = dict(self.named_parameters()) - params_checker = {k: False for k, v in params_dict.items()} - - for other_loaded_param_name in other_loaded_param_names: - params_checker[other_loaded_param_name] = True for name, loaded_weight in weights: loaded_weight = _WeightCreator.maybe_materialize(loaded_weight) @@ -1069,7 +1065,6 @@ class GptOssForCausalLM(nn.Module): param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) - params_checker[name] = True break else: for mapping in expert_params_mapping: @@ -1092,7 +1087,6 @@ class GptOssForCausalLM(nn.Module): name, shard_id=shard_id, ) - params_checker[name] = True break else: if name.endswith(".bias") and name not in params_dict: @@ -1111,17 +1105,9 @@ class GptOssForCausalLM(nn.Module): param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) - params_checker[name] = True else: logger.warning(f"Parameter {name} not found in params_dict") - not_loaded_params = [k for k, v in params_checker.items() if not v] - if tp_rank == 0: - if len(not_loaded_params) > 0: - raise Exception(f"Not all parameters loaded: {not_loaded_params}") - else: - logging.info("All 
parameters loaded successfully.") - def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight