diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py
index fdf6276c1..e5c28fa12 100644
--- a/python/sglang/srt/models/llama2.py
+++ b/python/sglang/srt/models/llama2.py
@@ -227,12 +227,12 @@ class LlamaModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -264,9 +264,9 @@ class LlamaForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index c35b9c7eb..b615f5953 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -230,11 +230,11 @@ class LlavaLlamaForCausalLM(nn.Module):
                     pt += 1

             return self.language_model(
-                input_embeds, positions, input_metadata, skip_embed=True
+                input_ids, positions, input_metadata, input_embeds=input_embeds
             )
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.language_model(
-                input_ids, positions, input_metadata, skip_embed=False
+                input_ids, positions, input_metadata
             )

     def load_weights(
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index ea9e99bb0..01a830807 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -296,12 +296,12 @@ class MixtralModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -330,9 +330,9 @@ class MixtralForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index e77d4082e..26f0a5ae1 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -228,12 +228,12 @@ class Qwen2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -265,9 +265,9 @@ class Qwen2ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
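
A minimal sketch of the calling convention this diff introduces, for illustration only (TinyLM and the surrounding names are hypothetical, not part of the sglang codebase): callers that pass nothing extra get token embeddings computed internally, while a multimodal wrapper such as the LLaVA path can supply precomputed embeddings via input_embeds while still forwarding input_ids for downstream use.

import torch
import torch.nn as nn


class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        # Same dispatch as the PR: embed token ids only when no embeddings are given.
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        return hidden_states


model = TinyLM()
ids = torch.tensor([1, 2, 3])

# Decode path: plain token ids, the model embeds them itself.
out_decode = model(ids)

# Extend/prefill path: the caller passes merged embeddings (e.g. with image
# features spliced in) and still forwards input_ids alongside them.
merged = model.embed_tokens(ids).clone()  # stand-in for embeddings with image features
out_prefill = model(ids, input_embeds=merged)

assert out_prefill.shape == out_decode.shape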