Fix oom issues with fp8 for llama (#1454)

This commit is contained in:
Lianmin Zheng
2024-09-18 03:45:19 -07:00
committed by GitHub
parent aa2750beb3
commit 1acccb364a
8 changed files with 33 additions and 21 deletions

View File

@@ -305,8 +305,6 @@ class LlamaForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.param_dict = dict(self.named_parameters())
@torch.no_grad()
def forward(
self,
@@ -374,7 +372,7 @@ class LlamaForCausalLM(nn.Module):
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = self.param_dict
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name: