From 64fe311593edee917a28506be8723127d4e938c9 Mon Sep 17 00:00:00 2001 From: "Geary.Z" <92413813+TideDra@users.noreply.github.com> Date: Mon, 11 Mar 2024 10:04:52 +0800 Subject: [PATCH] replace skip_embed with input_embeds (#222) --- python/sglang/srt/models/llama2.py | 10 +++++----- python/sglang/srt/models/llava.py | 4 ++-- python/sglang/srt/models/mixtral.py | 10 +++++----- python/sglang/srt/models/qwen2.py | 10 +++++----- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index fdf6276c1..e5c28fa12 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -227,12 +227,12 @@ class LlamaModel(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - skip_embed: bool = False, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: - if not skip_embed: + if input_embeds is None: hidden_states = self.embed_tokens(input_ids) else: - hidden_states = input_ids + hidden_states = input_embeds residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -264,9 +264,9 @@ class LlamaForCausalLM(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - skip_embed: bool = False, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index c35b9c7eb..b615f5953 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -230,11 +230,11 @@ class LlavaLlamaForCausalLM(nn.Module): pt += 1 return self.language_model( - input_embeds, positions, input_metadata, skip_embed=True + input_ids, positions, input_metadata, input_embeds=input_embeds ) elif input_metadata.forward_mode == ForwardMode.DECODE: return self.language_model( - input_ids, positions, input_metadata, skip_embed=False + input_ids, positions, input_metadata ) def load_weights( diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index ea9e99bb0..01a830807 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -296,12 +296,12 @@ class MixtralModel(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - skip_embed: bool = False, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: - if not skip_embed: + if input_embeds is None: hidden_states = self.embed_tokens(input_ids) else: - hidden_states = input_ids + hidden_states = input_embeds residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -330,9 +330,9 @@ class MixtralForCausalLM(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - skip_embed: bool = False, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index e77d4082e..26f0a5ae1 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -228,12 +228,12 @@ class Qwen2Model(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - skip_embed: bool = False, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: - if not skip_embed: + if input_embeds is None: hidden_states = self.embed_tokens(input_ids) else: - hidden_states = input_ids + hidden_states = input_embeds residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -265,9 +265,9 @@ class Qwen2ForCausalLM(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - skip_embed: bool = False, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata )