Fix model forward grad (#628)

This commit is contained in:
Liangsheng Yin
2024-07-15 22:09:09 -07:00
committed by GitHub
parent 41d1f67704
commit c9ee3d3559
14 changed files with 14 additions and 0 deletions

View File

@@ -368,6 +368,7 @@ class DbrxForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,