[fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (#6265)
This commit is contained in:
@@ -92,6 +92,7 @@ from sglang.srt.utils import (
|
||||
BumpAllocator,
|
||||
DeepEPMode,
|
||||
add_prefix,
|
||||
bind_or_assign,
|
||||
get_bool_env_var,
|
||||
get_int_env_var,
|
||||
is_cuda,
|
||||
@@ -1713,14 +1714,23 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
|
||||
def post_load_weights(self, is_nextn=False):
|
||||
def post_load_weights(self, is_nextn=False, weight_names=None):
|
||||
|
||||
# Perform post-processing after loading weights
|
||||
layer_ids = (
|
||||
range(self.config.num_hidden_layers)
|
||||
if not is_nextn
|
||||
else [self.config.num_hidden_layers]
|
||||
)
|
||||
if is_nextn:
|
||||
layer_ids = [self.config.num_hidden_layers]
|
||||
else:
|
||||
if weight_names is None:
|
||||
layer_ids = range(self.config.num_hidden_layers)
|
||||
else:
|
||||
layer_ids = set()
|
||||
for name in weight_names:
|
||||
if "kv_b_proj" in name:
|
||||
layer_id = int(name.split(".")[2])
|
||||
# filter the nextn layer.
|
||||
if layer_id != self.config.num_hidden_layers:
|
||||
layer_ids.add(layer_id)
|
||||
|
||||
for layer_id in layer_ids:
|
||||
self_attn = (
|
||||
self.model.layers[layer_id].self_attn
|
||||
@@ -1830,13 +1840,19 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||
if not use_deep_gemm_bmm:
|
||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
||||
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
|
||||
self_attn.w_kc = bind_or_assign(
|
||||
self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
||||
)
|
||||
self_attn.w_vc = bind_or_assign(
|
||||
self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
|
||||
)
|
||||
if (
|
||||
hasattr(self_attn.kv_b_proj, "weight_scale")
|
||||
and self_attn.w_scale is None
|
||||
):
|
||||
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
||||
self_attn.w_scale = bind_or_assign(
|
||||
self_attn.w_scale, self_attn.kv_b_proj.weight_scale
|
||||
)
|
||||
if _is_hip:
|
||||
self_attn.w_scale *= 2.0
|
||||
else:
|
||||
@@ -1845,10 +1861,16 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
ws_kc, ws_vc = block_scale.unflatten(
|
||||
0, (-1, (num_tiles_k + num_tiles_n))
|
||||
).split([num_tiles_k, num_tiles_n], dim=1)
|
||||
self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
|
||||
self_attn.w_scale_v = ws_vc.contiguous()
|
||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
|
||||
self_attn.w_vc = w_vc.contiguous()
|
||||
self_attn.w_scale_k = bind_or_assign(
|
||||
self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
|
||||
)
|
||||
self_attn.w_scale_v = bind_or_assign(
|
||||
self_attn.w_scale_v, ws_vc.contiguous()
|
||||
)
|
||||
self_attn.w_kc = bind_or_assign(
|
||||
self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
|
||||
)
|
||||
self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
|
||||
self_attn.use_deep_gemm_bmm = True
|
||||
|
||||
# TODO support nextn later
|
||||
@@ -1958,7 +1980,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
weight_names = []
|
||||
for name, loaded_weight in weights:
|
||||
weight_names.append(name)
|
||||
|
||||
if not is_nextn:
|
||||
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||
@@ -2075,7 +2100,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
self.post_load_weights(is_nextn=is_nextn)
|
||||
self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
|
||||
|
||||
def get_embed_and_head(self):
|
||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||
|
||||
@@ -2217,3 +2217,11 @@ def read_system_prompt_from_file(model_name: str) -> str:
|
||||
except Exception:
|
||||
# If anything fails, return empty string
|
||||
return ""
|
||||
|
||||
|
||||
def bind_or_assign(target, source):
    """Bind `source` into an existing `target` tensor, or adopt `source` outright.

    When `target` already exists (e.g. weights were loaded once before and the
    buffer must keep its identity across weight updates), copy `source` into it
    in place and return the same `target` object. When `target` is None (first
    load), simply return `source` as the new binding.

    Args:
        target: Existing tensor-like object with an in-place ``copy_`` method,
            or None if nothing has been bound yet.
        source: New tensor-like value to bind or copy from.

    Returns:
        `target` (after an in-place copy) when it was not None; otherwise `source`.
    """
    # Guard clause: nothing bound yet — adopt the new value directly.
    if target is None:
        return source
    # Preserve object identity so external references stay valid across updates.
    target.copy_(source)
    return target
|
||||
|
||||
Reference in New Issue
Block a user