Fix model forward grad (#628)

This commit is contained in:
Liangsheng Yin
2024-07-15 22:09:09 -07:00
committed by GitHub
parent 41d1f67704
commit c9ee3d3559
14 changed files with 14 additions and 0 deletions

View File

@@ -322,6 +322,7 @@ class QuantMixtralForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,