diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index b4fc4d7a7..ed859a311 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -92,6 +92,7 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     add_prefix,
+    bind_or_assign,
     get_bool_env_var,
     get_int_env_var,
     is_cuda,
@@ -1713,14 +1714,23 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

-    def post_load_weights(self, is_nextn=False):
+    def post_load_weights(self, is_nextn=False, weight_names=None):
         # Perform post-processing after loading weights
-        layer_ids = (
-            range(self.config.num_hidden_layers)
-            if not is_nextn
-            else [self.config.num_hidden_layers]
-        )
+        if is_nextn:
+            layer_ids = [self.config.num_hidden_layers]
+        else:
+            if weight_names is None:
+                layer_ids = range(self.config.num_hidden_layers)
+            else:
+                layer_ids = set()
+                for name in weight_names:
+                    if "kv_b_proj" in name:
+                        layer_id = int(name.split(".")[2])
+                        # filter the nextn layer.
+                        if layer_id != self.config.num_hidden_layers:
+                            layer_ids.add(layer_id)
+
         for layer_id in layer_ids:
             self_attn = (
                 self.model.layers[layer_id].self_attn
                 if not is_nextn
@@ -1830,13 +1840,19 @@ class DeepseekV2ForCausalLM(nn.Module):
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
             if not use_deep_gemm_bmm:
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                self_attn.w_kc = bind_or_assign(
+                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                )
+                self_attn.w_vc = bind_or_assign(
+                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
+                )
                 if (
                     hasattr(self_attn.kv_b_proj, "weight_scale")
                     and self_attn.w_scale is None
                 ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    self_attn.w_scale = bind_or_assign(
+                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
+                    )
                     if _is_hip:
                         self_attn.w_scale *= 2.0
             else:
@@ -1845,10 +1861,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                 ws_kc, ws_vc = block_scale.unflatten(
                     0, (-1, (num_tiles_k + num_tiles_n))
                 ).split([num_tiles_k, num_tiles_n], dim=1)
-                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
-                self_attn.w_scale_v = ws_vc.contiguous()
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
-                self_attn.w_vc = w_vc.contiguous()
+                self_attn.w_scale_k = bind_or_assign(
+                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
+                )
+                self_attn.w_scale_v = bind_or_assign(
+                    self_attn.w_scale_v, ws_vc.contiguous()
+                )
+                self_attn.w_kc = bind_or_assign(
+                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
+                )
+                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                 self_attn.use_deep_gemm_bmm = True

        # TODO support nextn later
@@ -1958,7 +1980,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())

+        weight_names = []
         for name, loaded_weight in weights:
+            weight_names.append(name)
+
             if not is_nextn:
                 if hasattr(self.config, "num_nextn_predict_layers"):
                     num_nextn_layers = self.config.num_nextn_predict_layers
@@ -2075,7 +2100,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
             weight_loader(param, loaded_weight)

-        self.post_load_weights(is_nextn=is_nextn)
+        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 707b4c4a4..69ade4a4f 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -2217,3 +2217,11 @@ def read_system_prompt_from_file(model_name: str) -> str:
     except Exception:
         # If anything fails, return empty string
         return ""
+
+
+def bind_or_assign(target, source):
+    if target is not None:
+        target.copy_(source)
+        return target
+    else:
+        return source
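
Note (illustration, not part of the patch): bind_or_assign appears intended to let repeated weight loads update an existing tensor in place rather than rebind the attribute, so references already held elsewhere keep pointing at valid storage. A minimal, self-contained sketch of that behavior, assuming PyTorch and inlining the helper exactly as added to python/sglang/srt/utils.py; the variable names and toy shapes are illustrative only:

    import torch

    # Helper copied verbatim from the python/sglang/srt/utils.py hunk above.
    def bind_or_assign(target, source):
        if target is not None:
            target.copy_(source)
            return target
        else:
            return source

    # First load: no existing tensor yet, so the source is adopted directly.
    w_kc = bind_or_assign(None, torch.ones(2, 3))

    # Reload: the new values are copied into the existing storage in place,
    # so anything already holding a reference to w_kc observes the updated
    # weights without the attribute being rebound to a new tensor.
    ptr_before = w_kc.data_ptr()
    w_kc = bind_or_assign(w_kc, torch.zeros(2, 3))
    assert w_kc.data_ptr() == ptr_before
    assert torch.equal(w_kc, torch.zeros(2, 3))

For the weight_names path in post_load_weights: parameter names follow the form model.layers.<i>.self_attn.kv_b_proj.weight, so int(name.split(".")[2]) recovers the layer index <i>, and an index equal to num_hidden_layers is skipped because that layer is the nextn layer handled by the is_nextn branch.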