Clean up wrapper in flashinfer backend (#2638)

This commit is contained in:
Lianmin Zheng
2024-12-29 00:45:57 -08:00
committed by GitHub
parent fd34f2da35
commit 3815b23ccb
12 changed files with 197 additions and 94 deletions

View File

@@ -516,6 +516,17 @@ class LlamaForCausalLM(nn.Module):
)
return None
def get_embed_and_head(self):
    """Return the input-embedding and LM-head weight tensors as a pair."""
    embed_weight = self.model.embed_tokens.weight
    head_weight = self.lm_head.weight
    return embed_weight, head_weight
def set_embed_and_head(self, embed, head):
    """Rebind the input-embedding and LM-head weights to externally provided tensors.

    The existing ``weight`` attributes are deleted before the new tensors are
    attached, so the replacements fully take their place rather than shadowing
    them.
    """
    # Drop the old weight attributes first so nothing keeps them reachable.
    del self.model.embed_tokens.weight, self.lm_head.weight
    # Attach the caller-supplied tensors in their place.
    self.model.embed_tokens.weight = embed
    self.lm_head.weight = head
    # Flush the CUDA caching allocator so memory freed by the deleted weights
    # is actually released, and wait for in-flight device work to finish.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
class Phi3ForCausalLM(LlamaForCausalLM):
    # Reuses the Llama implementation wholesale — presumably the Phi-3
    # architecture is Llama-compatible at this level; verify against the
    # model loader if weight/config names differ.
    """Phi-3 causal LM; inherits everything from LlamaForCausalLM unchanged."""
    pass