From c9ee3d3559717dd7a92616315b1f997dd6ba7acc Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 15 Jul 2024 22:09:09 -0700 Subject: [PATCH] Fix model forward grad (#628) --- python/sglang/srt/models/chatglm.py | 1 + python/sglang/srt/models/dbrx.py | 1 + python/sglang/srt/models/grok.py | 1 + python/sglang/srt/models/llama2.py | 1 + python/sglang/srt/models/llama_classification.py | 1 + python/sglang/srt/models/llava.py | 1 + python/sglang/srt/models/llavavid.py | 1 + python/sglang/srt/models/minicpm.py | 1 + python/sglang/srt/models/mixtral.py | 1 + python/sglang/srt/models/mixtral_quant.py | 1 + python/sglang/srt/models/qwen.py | 1 + python/sglang/srt/models/qwen2.py | 1 + python/sglang/srt/models/qwen2_moe.py | 1 + python/sglang/srt/models/stablelm.py | 1 + 14 files changed, 14 insertions(+) diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index e9ec3e2d2..3c6ea4521 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -360,6 +360,7 @@ class ChatGLMForCausalLM(nn.Module): self.logits_processor = LogitsProcessor(config) self.sampler = Sampler() + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index b21142d2e..a6f039217 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -368,6 +368,7 @@ class DbrxForCausalLM(nn.Module): ) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index cbf29055c..3a765b6a7 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -601,6 +601,7 @@ class Grok1ModelForCausalLM(nn.Module): # Monkey patch _prepare_weights to load pre-sharded weights setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 95ba71ee9..eca15c7cb 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -275,6 +275,7 @@ class LlamaForCausalLM(nn.Module): self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index eb9dde45c..96b1ac01e 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -31,6 +31,7 @@ class LlamaForClassification(nn.Module): ) self.eos_token_id = config.eos_token_id + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 915c9bee0..07c8c7372 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -95,6 +95,7 @@ class LlavaLlamaForCausalLM(nn.Module): return image_features + @torch.no_grad() def forward( self, input_ids: torch.LongTensor, diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 47e20583c..33e1f3f30 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -106,6 +106,7 @@ class LlavaVidForCausalLM(nn.Module): return image_features + @torch.no_grad() def forward( self, input_ids: torch.LongTensor, diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 3f16c95f9..347404b2e 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -283,6 +283,7 @@ class MiniCPMForCausalLM(nn.Module): self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index abcde6de5..19fc50162 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -460,6 +460,7 @@ class MixtralForCausalLM(nn.Module): self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index aa8f8a759..69da0a1c4 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -322,6 +322,7 @@ class QuantMixtralForCausalLM(nn.Module): self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 9c59d14fe..ba098b8c5 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -237,6 +237,7 @@ class QWenLMHeadModel(nn.Module): self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index dc50075ca..83b7d2f9c 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -261,6 +261,7 @@ class Qwen2ForCausalLM(nn.Module): self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 072002c6f..6e90babcc 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -355,6 +355,7 @@ class Qwen2MoeForCausalLM(nn.Module): self.logits_processor = LogitsProcessor(config) self.sampler = Sampler() + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 875ddd70b..4b91d2c4d 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -235,6 +235,7 @@ class StableLmForCausalLM(nn.Module): self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor,