Fix oom issues with fp8 for llama (#1454)

@@ -36,6 +36,7 @@ class LlamaForClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.torchao_config = None
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)

@@ -44,8 +45,6 @@ class LlamaForClassification(nn.Module):
         )
         self.eos_token_id = config.eos_token_id

-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -77,7 +76,7 @@ class LlamaForClassification(nn.Module):
         return logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
             if "classification_head" in name:
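Why this likely fixes the OOM: __init__ previously cached dict(self.named_parameters()) as self.param_dict, which holds references to the original weight tensors for the lifetime of the model. When fp8 quantization replaces those parameters during weight loading, the cached dict still pins the old high-precision tensors, so their memory is never released. Below is a minimal sketch of that failure mode, assuming the usual PyTorch pattern of replacing a parameter in place; float16 stands in for fp8 to keep the example runnable, and all names are illustrative, not from the commit:

    # Sketch (illustrative, not from this commit) of why a cached
    # parameter dict can leak memory during quantized weight loading.
    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)
    cached = dict(model.named_parameters())  # like the removed self.param_dict

    # Quantized loading replaces the parameter with a lower-precision copy:
    model.weight = nn.Parameter(
        model.weight.detach().half(), requires_grad=False
    )

    # The original fp32 tensor is still reachable through the cached dict,
    # so it cannot be freed; on large models this roughly doubles weight memory.
    assert cached["weight"].dtype == torch.float32

    # Rebuilding the dict inside load_weights(), as this commit does,
    # points at the current (quantized) parameters and drops stale references:
    fresh = dict(model.named_parameters())
    assert fresh["weight"].dtype == torch.float16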