Share target model embed and head weights for nextn (#4033)
This commit is contained in:
@@ -116,14 +116,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
self.model = DeepseekModelNextN(config, quant_config)
|
||||
|
||||
if global_server_args_dict["enable_dp_attention"]:
|
||||
self.model.shared_head.head = ReplicatedLinear(
|
||||
self.lm_head = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
||||
else:
|
||||
self.model.shared_head.head = ParallelLMHead(
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
@@ -139,7 +139,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.model.shared_head.head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
@@ -168,10 +168,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
|
||||
nextn_layer_prefix = "model.layers.0"
|
||||
nextn_spec_weight_names = [
|
||||
"shared_head.head",
|
||||
"shared_head.norm",
|
||||
"eh_proj",
|
||||
"embed_tokens",
|
||||
"enorm",
|
||||
"hnorm",
|
||||
]
|
||||
@@ -180,17 +178,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
for name, loaded_weight in weights:
|
||||
if not name.startswith(nextn_layer_prefix):
|
||||
continue
|
||||
else:
|
||||
is_decoder = True
|
||||
# For nextn specific weights
|
||||
for weight_name in nextn_spec_weight_names:
|
||||
if weight_name in name:
|
||||
name = name.replace(nextn_layer_prefix, "model")
|
||||
is_decoder = False
|
||||
break
|
||||
# For decoder layer weights
|
||||
if is_decoder:
|
||||
name = name.replace(nextn_layer_prefix, "model.decoder")
|
||||
|
||||
# Use shared head and embed weights from target model
|
||||
if "shared_head.head" in name or "embed_tokens" in name:
|
||||
continue
|
||||
|
||||
is_decoder = True
|
||||
# For nextn specific weights
|
||||
for weight_name in nextn_spec_weight_names:
|
||||
if weight_name in name:
|
||||
name = name.replace(nextn_layer_prefix, "model")
|
||||
is_decoder = False
|
||||
break
|
||||
# For decoder layer weights
|
||||
if is_decoder:
|
||||
name = name.replace(nextn_layer_prefix, "model.decoder")
|
||||
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
@@ -1179,6 +1179,17 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
if is_hip_:
|
||||
self_attn.w_scale *= 2.0
|
||||
|
||||
# Return a 2-tuple of this model's weights, in this order:
#   (self.model.embed_tokens.weight, self.lm_head.weight)
# The attributes themselves are returned; nothing is copied or modified here.
# NOTE(review): presumably the counterpart of set_embed_and_head, so a
# draft (nextn) model can reuse the target model's weights - confirm
# against the caller.
def get_embed_and_head(self):
|
||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||
|
||||
# Replace this model's embedding and LM-head weights with the given ones.
#   embed: object installed as self.model.embed_tokens.weight
#   head:  object installed as self.lm_head.weight
# Returns None. The arguments are assigned as-is; no copy is made here.
# NOTE(review): presumably called with the values returned by the target
# model's get_embed_and_head() - confirm against the caller.
def set_embed_and_head(self, embed, head):
|
||||
# Remove the existing weight attributes first, then rebind them below.
del self.model.embed_tokens.weight
|
||||
del self.lm_head.weight
|
||||
self.model.embed_tokens.weight = embed
|
||||
self.lm_head.weight = head
|
||||
# Ask PyTorch's CUDA caching allocator to release unused cached memory.
# NOTE(review): presumably this is what returns the memory of the weights
# deleted above - it only can if nothing else still references them.
torch.cuda.empty_cache()
|
||||
# Block until outstanding work on the current CUDA device has finished.
# NOTE(review): both CUDA calls are unconditional - confirm this path is
# never reached on a host without a CUDA/ROCm device.
torch.cuda.synchronize()
|
||||
|
||||
|
||||
# Subclass with an empty body (`pass`): it adds no attributes or methods of
# its own, so all behavior comes from DeepseekV2ForCausalLM.
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user