Adjust reward model's score module and pooler module order for reducing computation (#1956)
This commit is contained in:
@@ -58,43 +58,10 @@ class Gemma2ForSequenceClassification(nn.Module):
|
|||||||
), "Gemma2ForSequenceClassification is only used for embedding"
|
), "Gemma2ForSequenceClassification is only used for embedding"
|
||||||
|
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
scores = self.score(hidden_states)
|
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
|
||||||
|
scores = self.score(last_token_hidden)
|
||||||
|
|
||||||
return self.pooler(scores, forward_batch)
|
return EmbeddingPoolerOutput(scores)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
|
||||||
stacked_params_mapping = [
|
|
||||||
# (param_name, shard_name, shard_id)
|
|
||||||
("qkv_proj", "q_proj", "q"),
|
|
||||||
("qkv_proj", "k_proj", "k"),
|
|
||||||
("qkv_proj", "v_proj", "v"),
|
|
||||||
("gate_up_proj", "gate_proj", 0),
|
|
||||||
("gate_up_proj", "up_proj", 1),
|
|
||||||
]
|
|
||||||
params_dict = dict(self.named_parameters())
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
for param_name, shard_name, shard_id in stacked_params_mapping:
|
|
||||||
if shard_name not in name:
|
|
||||||
continue
|
|
||||||
name = name.replace(shard_name, param_name)
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# lm_head is not used in vllm as it is tied with embed_token.
|
|
||||||
# To prevent errors, skip loading lm_head.weight.
|
|
||||||
if "lm_head.weight" in name:
|
|
||||||
continue
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
Gemma2ForCausalLM.load_weights(self, weights)
|
Gemma2ForCausalLM.load_weights(self, weights)
|
||||||
|
|||||||
@@ -59,22 +59,13 @@ class LlamaForSequenceClassification(nn.Module):
|
|||||||
), "LlamaForSequenceClassification is only used for embedding"
|
), "LlamaForSequenceClassification is only used for embedding"
|
||||||
|
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
scores = self.score(hidden_states)
|
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
|
||||||
|
scores = self.score(last_token_hidden)
|
||||||
|
|
||||||
return self.pooler(scores, forward_batch)
|
return EmbeddingPoolerOutput(scores)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
params_dict = dict(self.named_parameters())
|
return LlamaForCausalLM.load_weights(self, weights)
|
||||||
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
if "classification_head" in name:
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
elif "lm_head" in name:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
|
class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
|
||||||
@@ -127,17 +118,7 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
|
|||||||
return EmbeddingPoolerOutput(scores)
|
return EmbeddingPoolerOutput(scores)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
params_dict = dict(self.named_parameters())
|
return super().load_weights(weights)
|
||||||
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
if "classification_head" in name:
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
elif "lm_head" in name:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
|
|
||||||
|
|
||||||
|
|
||||||
EntryClass = [
|
EntryClass = [
|
||||||
|
|||||||
Reference in New Issue
Block a user