replace skip_embed with input_embeds (#222)

This commit is contained in:
Geary.Z
2024-03-11 10:04:52 +08:00
committed by GitHub
parent a7ace9c88d
commit 64fe311593
4 changed files with 17 additions and 17 deletions

View File

@@ -230,11 +230,11 @@ class LlavaLlamaForCausalLM(nn.Module):
pt += 1
return self.language_model(
input_embeds, positions, input_metadata, skip_embed=True
input_ids, positions, input_metadata, input_embeds=input_embeds
)
elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(
input_ids, positions, input_metadata, skip_embed=False
input_ids, positions, input_metadata
)
def load_weights(