refactor: multimodal data (#4754)

This commit is contained in:
Mick
2025-04-01 00:57:51 +08:00
committed by GitHub
parent c7457191a0
commit 5cb552b1d4
36 changed files with 989 additions and 1138 deletions

View File

@@ -261,11 +261,14 @@ class Qwen2Model(nn.Module):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
if hasattr(self.config, "scale_emb"):
return self.embed_tokens(input_ids) * self.config.scale_emb
return self.get_input_embeddings()(input_ids) * self.config.scale_emb
else:
return self.embed_tokens(input_ids)
return self.get_input_embeddings()(input_ids)
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def forward(
self,
@@ -358,10 +361,10 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embedding(input_ids)
def get_input_embedding(self) -> nn.Embedding:
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
@torch.no_grad()