replace skip_embed with input_embeds (#222)

This commit is contained in:
Geary.Z
2024-03-11 10:04:52 +08:00
committed by GitHub
parent a7ace9c88d
commit 64fe311593
4 changed files with 17 additions and 17 deletions

View File

@@ -228,12 +228,12 @@ class Qwen2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -265,9 +265,9 @@ class Qwen2ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )