[Fix] Fix llava on multi images (#1247)

This commit is contained in:
Lianmin Zheng
2024-08-28 06:33:05 -07:00
committed by GitHub
parent b1a540ec42
commit bf53bf5142
22 changed files with 272 additions and 488 deletions

View File

@@ -273,9 +273,9 @@ class Grok1Model(nn.Module):
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
hidden_states.mul_(self.config.embedding_multiplier_scale)
else:
hidden_states = input_embeds
hidden_states.mul_(self.config.embedding_multiplier_scale)
for i in range(len(self.layers)):
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
@@ -284,7 +284,7 @@ class Grok1Model(nn.Module):
return hidden_states
class Grok1ModelForCausalLM(nn.Module):
class Grok1ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -415,4 +415,10 @@ def _prepare_presharded_weights(
return hf_folder, hf_weights_files, use_safetensors
EntryClass = Grok1ModelForCausalLM
class Grok1ModelForCausalLM(Grok1ForCausalLM):
"""Alias of Grok1ForCausalLM, kept under the old class name for backward compatibility."""
pass
# Expose both the new class name and the legacy alias.
# NOTE(review): presumably read by the model loader to map architecture names to classes — confirm.
EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]