replace skip_embed with input_embeds (#222)
This commit is contained in:
@@ -230,11 +230,11 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
pt += 1
|
||||
|
||||
return self.language_model(
|
||||
input_embeds, positions, input_metadata, skip_embed=True
|
||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||
)
|
||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata, skip_embed=False
|
||||
input_ids, positions, input_metadata
|
||||
)
|
||||
|
||||
def load_weights(
|
||||
|
||||
Reference in New Issue
Block a user