Auto set draft model path for MTP (#5793)
This commit is contained in:
@@ -1502,11 +1502,20 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
|
||||
def post_load_weights(self):
|
||||
def post_load_weights(self, is_nextn=False):
|
||||
|
||||
# Perform post-processing after loading weights
|
||||
for layer_id in range(self.config.num_hidden_layers):
|
||||
self_attn = self.model.layers[layer_id].self_attn
|
||||
layer_ids = (
|
||||
range(self.config.num_hidden_layers)
|
||||
if not is_nextn
|
||||
else [self.config.num_hidden_layers]
|
||||
)
|
||||
for layer_id in layer_ids:
|
||||
self_attn = (
|
||||
self.model.layers[layer_id].self_attn
|
||||
if not is_nextn
|
||||
else self.model.decoder.self_attn
|
||||
)
|
||||
if hasattr(self_attn.kv_b_proj, "qweight"):
|
||||
# AWQ compatible
|
||||
if _is_cuda:
|
||||
@@ -1612,7 +1621,20 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
self_attn.w_vc = w_vc.contiguous()
|
||||
self_attn.use_deep_gemm_bmm = True
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
||||
if is_nextn:
|
||||
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||
assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
|
||||
# compatible with old design
|
||||
nextn_layer_id = (
|
||||
0
|
||||
if self.config.num_hidden_layers == 1
|
||||
else self.config.num_hidden_layers
|
||||
)
|
||||
else:
|
||||
raise ValueError("num_nextn_predict_layers is not in the config")
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
@@ -1640,12 +1662,19 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
"up_proj.weight_scale_inv",
|
||||
]
|
||||
names_to_remove = []
|
||||
for moe_layer in tqdm(
|
||||
|
||||
moe_layers = (
|
||||
range(
|
||||
self.config.first_k_dense_replace,
|
||||
self.config.num_hidden_layers,
|
||||
self.config.moe_layer_freq,
|
||||
),
|
||||
)
|
||||
if not is_nextn
|
||||
else [nextn_layer_id]
|
||||
)
|
||||
|
||||
for moe_layer in tqdm(
|
||||
moe_layers,
|
||||
desc=f"Cloning {self.n_share_experts_fusion} "
|
||||
"replicas of the shared expert into MoE",
|
||||
):
|
||||
@@ -1686,18 +1715,46 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
)
|
||||
cached_a_proj = {} if fuse_qkv_a_proj else None
|
||||
|
||||
if is_nextn:
|
||||
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
|
||||
nextn_spec_weight_names = [
|
||||
"shared_head.norm",
|
||||
"eh_proj",
|
||||
"enorm",
|
||||
"hnorm",
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
# TODO(HandH1998): Modify it when nextn is supported.
|
||||
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||
if num_nextn_layers > 0 and name.startswith("model.layers"):
|
||||
name_list = name.split(".")
|
||||
if (
|
||||
len(name_list) >= 3
|
||||
and int(name_list[2]) >= self.config.num_hidden_layers
|
||||
):
|
||||
continue
|
||||
if not is_nextn:
|
||||
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||
if num_nextn_layers > 0 and name.startswith("model.layers"):
|
||||
name_list = name.split(".")
|
||||
if (
|
||||
len(name_list) >= 3
|
||||
and int(name_list[2]) >= self.config.num_hidden_layers
|
||||
):
|
||||
continue
|
||||
else:
|
||||
if not name.startswith(nextn_layer_prefix):
|
||||
continue
|
||||
|
||||
# Use shared head and embed weights from target model
|
||||
if "shared_head.head" in name or "embed_tokens" in name:
|
||||
continue
|
||||
|
||||
is_decoder = True
|
||||
# For nextn specific weights
|
||||
for weight_name in nextn_spec_weight_names:
|
||||
if weight_name in name:
|
||||
name = name.replace(nextn_layer_prefix, "model")
|
||||
is_decoder = False
|
||||
break
|
||||
# For decoder layer weights
|
||||
if is_decoder:
|
||||
name = name.replace(nextn_layer_prefix, "model.decoder")
|
||||
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
@@ -1786,7 +1843,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
self.post_load_weights()
|
||||
self.post_load_weights(is_nextn=is_nextn)
|
||||
|
||||
def get_embed_and_head(self):
|
||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||
|
||||
Reference in New Issue
Block a user