[fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (#6265)
This commit is contained in:
@@ -92,6 +92,7 @@ from sglang.srt.utils import (
|
|||||||
BumpAllocator,
|
BumpAllocator,
|
||||||
DeepEPMode,
|
DeepEPMode,
|
||||||
add_prefix,
|
add_prefix,
|
||||||
|
bind_or_assign,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_int_env_var,
|
get_int_env_var,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
@@ -1713,14 +1714,23 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
input_ids, hidden_states, self.lm_head, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def post_load_weights(self, is_nextn=False):
|
def post_load_weights(self, is_nextn=False, weight_names=None):
|
||||||
|
|
||||||
# Perform post-processing after loading weights
|
# Perform post-processing after loading weights
|
||||||
layer_ids = (
|
if is_nextn:
|
||||||
range(self.config.num_hidden_layers)
|
layer_ids = [self.config.num_hidden_layers]
|
||||||
if not is_nextn
|
else:
|
||||||
else [self.config.num_hidden_layers]
|
if weight_names is None:
|
||||||
)
|
layer_ids = range(self.config.num_hidden_layers)
|
||||||
|
else:
|
||||||
|
layer_ids = set()
|
||||||
|
for name in weight_names:
|
||||||
|
if "kv_b_proj" in name:
|
||||||
|
layer_id = int(name.split(".")[2])
|
||||||
|
# filter the nextn layer.
|
||||||
|
if layer_id != self.config.num_hidden_layers:
|
||||||
|
layer_ids.add(layer_id)
|
||||||
|
|
||||||
for layer_id in layer_ids:
|
for layer_id in layer_ids:
|
||||||
self_attn = (
|
self_attn = (
|
||||||
self.model.layers[layer_id].self_attn
|
self.model.layers[layer_id].self_attn
|
||||||
@@ -1830,13 +1840,19 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||||
if not use_deep_gemm_bmm:
|
if not use_deep_gemm_bmm:
|
||||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
self_attn.w_kc = bind_or_assign(
|
||||||
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
|
self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
||||||
|
)
|
||||||
|
self_attn.w_vc = bind_or_assign(
|
||||||
|
self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
hasattr(self_attn.kv_b_proj, "weight_scale")
|
hasattr(self_attn.kv_b_proj, "weight_scale")
|
||||||
and self_attn.w_scale is None
|
and self_attn.w_scale is None
|
||||||
):
|
):
|
||||||
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
self_attn.w_scale = bind_or_assign(
|
||||||
|
self_attn.w_scale, self_attn.kv_b_proj.weight_scale
|
||||||
|
)
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
self_attn.w_scale *= 2.0
|
self_attn.w_scale *= 2.0
|
||||||
else:
|
else:
|
||||||
@@ -1845,10 +1861,16 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
ws_kc, ws_vc = block_scale.unflatten(
|
ws_kc, ws_vc = block_scale.unflatten(
|
||||||
0, (-1, (num_tiles_k + num_tiles_n))
|
0, (-1, (num_tiles_k + num_tiles_n))
|
||||||
).split([num_tiles_k, num_tiles_n], dim=1)
|
).split([num_tiles_k, num_tiles_n], dim=1)
|
||||||
self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
|
self_attn.w_scale_k = bind_or_assign(
|
||||||
self_attn.w_scale_v = ws_vc.contiguous()
|
self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
|
||||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
|
)
|
||||||
self_attn.w_vc = w_vc.contiguous()
|
self_attn.w_scale_v = bind_or_assign(
|
||||||
|
self_attn.w_scale_v, ws_vc.contiguous()
|
||||||
|
)
|
||||||
|
self_attn.w_kc = bind_or_assign(
|
||||||
|
self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
|
||||||
|
)
|
||||||
|
self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
|
||||||
self_attn.use_deep_gemm_bmm = True
|
self_attn.use_deep_gemm_bmm = True
|
||||||
|
|
||||||
# TODO support nextn later
|
# TODO support nextn later
|
||||||
@@ -1958,7 +1980,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
weight_names = []
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
|
weight_names.append(name)
|
||||||
|
|
||||||
if not is_nextn:
|
if not is_nextn:
|
||||||
if hasattr(self.config, "num_nextn_predict_layers"):
|
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||||
num_nextn_layers = self.config.num_nextn_predict_layers
|
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||||
@@ -2075,7 +2100,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
self.post_load_weights(is_nextn=is_nextn)
|
self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
|
||||||
|
|
||||||
def get_embed_and_head(self):
|
def get_embed_and_head(self):
|
||||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||||
|
|||||||
@@ -2217,3 +2217,11 @@ def read_system_prompt_from_file(model_name: str) -> str:
|
|||||||
except Exception:
|
except Exception:
|
||||||
# If anything fails, return empty string
|
# If anything fails, return empty string
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def bind_or_assign(target, source):
|
||||||
|
if target is not None:
|
||||||
|
target.copy_(source)
|
||||||
|
return target
|
||||||
|
else:
|
||||||
|
return source
|
||||||
|
|||||||
Reference in New Issue
Block a user