Fix model forward grad (#628)

This commit is contained in:
Liangsheng Yin
2024-07-15 22:09:09 -07:00
committed by GitHub
parent 41d1f67704
commit c9ee3d3559
14 changed files with 14 additions and 0 deletions

View File

@@ -360,6 +360,7 @@ class ChatGLMForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -368,6 +368,7 @@ class DbrxForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -601,6 +601,7 @@ class Grok1ModelForCausalLM(nn.Module):
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -275,6 +275,7 @@ class LlamaForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -31,6 +31,7 @@ class LlamaForClassification(nn.Module):
)
self.eos_token_id = config.eos_token_id
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -95,6 +95,7 @@ class LlavaLlamaForCausalLM(nn.Module):
return image_features
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,

View File

@@ -106,6 +106,7 @@ class LlavaVidForCausalLM(nn.Module):
return image_features
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,

View File

@@ -283,6 +283,7 @@ class MiniCPMForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -460,6 +460,7 @@ class MixtralForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -322,6 +322,7 @@ class QuantMixtralForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -237,6 +237,7 @@ class QWenLMHeadModel(nn.Module):
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -261,6 +261,7 @@ class Qwen2ForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -355,6 +355,7 @@ class Qwen2MoeForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -235,6 +235,7 @@ class StableLmForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,