[Fix] Fix llava on multi images (#1247)
This commit is contained in:
@@ -273,9 +273,9 @@ class Grok1Model(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
hidden_states.mul_(self.config.embedding_multiplier_scale)
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
hidden_states.mul_(self.config.embedding_multiplier_scale)
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
|
||||
@@ -284,7 +284,7 @@ class Grok1Model(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Grok1ModelForCausalLM(nn.Module):
|
||||
class Grok1ForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
@@ -415,4 +415,10 @@ def _prepare_presharded_weights(
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
|
||||
EntryClass = Grok1ModelForCausalLM
|
||||
class Grok1ModelForCausalLM(Grok1ForCausalLM):
|
||||
"""An alias for backward-compatbility."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
|
||||
|
||||
Reference in New Issue
Block a user