Fix model forward grad (#628)

This commit is contained in:
Liangsheng Yin
2024-07-15 22:09:09 -07:00
committed by GitHub
parent 41d1f67704
commit c9ee3d3559
14 changed files with 14 additions and 0 deletions

View File

@@ -601,6 +601,7 @@ class Grok1ModelForCausalLM(nn.Module):
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,