Share target model embed and head weights for nextn (#4033)
This commit is contained in:
@@ -1179,6 +1179,17 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
if is_hip_:
|
||||
self_attn.w_scale *= 2.0
|
||||
|
||||
def get_embed_and_head(self):
    """Return the token-embedding weight and the LM-head weight as a pair.

    Returns:
        tuple: ``(self.model.embed_tokens.weight, self.lm_head.weight)``.
        The attributes are returned as-is; no copy is made here.
    """
    embed_weight = self.model.embed_tokens.weight
    head_weight = self.lm_head.weight
    return embed_weight, head_weight
|
||||
|
||||
def set_embed_and_head(self, embed, head):
    """Swap this model's embedding and LM-head weights for the given ones.

    Args:
        embed: Value installed as ``self.model.embed_tokens.weight``.
        head: Value installed as ``self.lm_head.weight``.

    Note:
        Calls ``torch.cuda.synchronize()`` unconditionally.
        NOTE(review): presumably used so another model can share these
        weights instead of holding its own copy -- confirm against callers.
    """
    # Remove both existing weight attributes before attaching the new ones.
    embed_module = self.model.embed_tokens
    del embed_module.weight
    head_module = self.lm_head
    del head_module.weight

    embed_module.weight = embed
    head_module.weight = head

    # Release unused cached allocator memory, then wait for pending device
    # work to finish.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
|
||||
|
||||
|
||||
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """Subclass of ``DeepseekV2ForCausalLM`` that adds no members of its own."""
|
||||
|
||||
Reference in New Issue
Block a user