Adjust reward model's score module and pooler module order to reduce computation (#1956)

This commit is contained in:
aqweteddy
2024-11-08 16:10:54 +08:00
committed by GitHub
parent 8dc84da084
commit 4ade15dd32
2 changed files with 8 additions and 60 deletions

View File

@@ -59,22 +59,13 @@ class LlamaForSequenceClassification(nn.Module):
), "LlamaForSequenceClassification is only used for embedding"
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
scores = self.score(hidden_states)
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
scores = self.score(last_token_hidden)
return self.pooler(scores, forward_batch)
return EmbeddingPoolerOutput(scores)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "classification_head" in name:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
elif "lm_head" in name:
continue
else:
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
return LlamaForCausalLM.load_weights(self, weights)
class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
@@ -127,17 +118,7 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
return EmbeddingPoolerOutput(scores)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "classification_head" in name:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
elif "lm_head" in name:
continue
else:
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
return super().load_weights(weights)
EntryClass = [