Fix model forward grad (#628)
This commit is contained in:
@@ -360,6 +360,7 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -368,6 +368,7 @@ class DbrxForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -601,6 +601,7 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
# Monkey patch _prepare_weights to load pre-sharded weights
|
# Monkey patch _prepare_weights to load pre-sharded weights
|
||||||
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -275,6 +275,7 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class LlamaForClassification(nn.Module):
|
|||||||
)
|
)
|
||||||
self.eos_token_id = config.eos_token_id
|
self.eos_token_id = config.eos_token_id
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -95,6 +95,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
|
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
|
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
|
|||||||
@@ -283,6 +283,7 @@ class MiniCPMForCausalLM(nn.Module):
|
|||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -460,6 +460,7 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -322,6 +322,7 @@ class QuantMixtralForCausalLM(nn.Module):
|
|||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -237,6 +237,7 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -261,6 +261,7 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -355,6 +355,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -235,6 +235,7 @@ class StableLmForCausalLM(nn.Module):
|
|||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user