replace skip_embed with input_embeds (#222)
This commit is contained in:
@@ -227,12 +227,12 @@ class LlamaModel(nn.Module):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
skip_embed: bool = False,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not skip_embed:
|
if input_embeds is None:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
else:
|
else:
|
||||||
hidden_states = input_ids
|
hidden_states = input_embeds
|
||||||
residual = None
|
residual = None
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
@@ -264,9 +264,9 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
skip_embed: bool = False,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -230,11 +230,11 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
pt += 1
|
pt += 1
|
||||||
|
|
||||||
return self.language_model(
|
return self.language_model(
|
||||||
input_embeds, positions, input_metadata, skip_embed=True
|
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||||
)
|
)
|
||||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
return self.language_model(
|
return self.language_model(
|
||||||
input_ids, positions, input_metadata, skip_embed=False
|
input_ids, positions, input_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(
|
def load_weights(
|
||||||
|
|||||||
@@ -296,12 +296,12 @@ class MixtralModel(nn.Module):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
skip_embed: bool = False,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not skip_embed:
|
if input_embeds is None:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
else:
|
else:
|
||||||
hidden_states = input_ids
|
hidden_states = input_embeds
|
||||||
residual = None
|
residual = None
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
@@ -330,9 +330,9 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
skip_embed: bool = False,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -228,12 +228,12 @@ class Qwen2Model(nn.Module):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
skip_embed: bool = False,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not skip_embed:
|
if input_embeds is None:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
else:
|
else:
|
||||||
hidden_states = input_ids
|
hidden_states = input_embeds
|
||||||
residual = None
|
residual = None
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
@@ -265,9 +265,9 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
skip_embed: bool = False,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user