add qwen2 eagle model (#2863)

This commit is contained in:
Lzhang-hub
2025-01-13 21:29:33 +08:00
committed by GitHub
parent d855653bd4
commit 6ec75e626d
2 changed files with 142 additions and 0 deletions

View File

@@ -362,5 +362,16 @@ class Qwen2ForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
EntryClass = Qwen2ForCausalLM