Share target model embed and head weights for nextn (#4033)

This commit is contained in:
Ke Bao
2025-03-04 05:30:04 +08:00
committed by GitHub
parent 146ac8df07
commit 9fafa62db7
7 changed files with 47 additions and 45 deletions

View File

@@ -1179,6 +1179,17 @@ class DeepseekV2ForCausalLM(nn.Module):
if is_hip_:
self_attn.w_scale *= 2.0
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass