Fix model forward grad (#628)

This commit is contained in:
Liangsheng Yin
2024-07-15 22:09:09 -07:00
committed by GitHub
parent 41d1f67704
commit c9ee3d3559
14 changed files with 14 additions and 0 deletions

View File

@@ -31,6 +31,7 @@ class LlamaForClassification(nn.Module):
)
self.eos_token_id = config.eos_token_id
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,