初始化项目,由ModelHub XC社区提供模型
Model: Collective-Ai/collective-v0.1-chinese-roleplay-8b Source: Original Platform
This commit is contained in:
66
modeling_llama3.py
Normal file
66
modeling_llama3.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# coding=utf-8
|
||||
import torch
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
import re
|
||||
|
||||
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    """Lower each token's score in proportion to how often it has already occurred.

    The occurrence count for a token is the number of times it appears in
    ``penalty_dialog`` plus the number of times it appears in the part of
    ``input_ids`` that lies after the first ``input_length`` positions.
    """

    def __init__(self, penalty: float, penalty_dialog: torch.LongTensor, input_length: int):
        """Store the penalty settings.

        Args:
            penalty: Amount subtracted from a token's score per occurrence.
            penalty_dialog: 1-D tensor of token ids that always count as
                already-seen tokens.
            input_length: Number of leading positions of every ``input_ids``
                row that are skipped when counting occurrences.

        Raises:
            ValueError: If ``penalty`` is not a float or is not strictly positive.
        """
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.input_length = input_length
        self.penalty_dialog = penalty_dialog

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with ``penalty * count`` subtracted per token.

        Args:
            input_ids: One row of token ids per sequence
                (presumably ``(batch, seq_len)`` — confirm against the caller).
            scores: One row of vocabulary scores per sequence; the last
                dimension is treated as the vocabulary size.

        Returns:
            The penalized scores, stacked per row and cast to ``float32``.
        """
        # The constructor rejects 0.0, so this only triggers if ``penalty``
        # was reassigned after construction.
        if self.penalty == 0.0:
            return scores

        vocab_size = scores.size(-1)
        # The history may have been encoded on another device than the one
        # generation runs on; torch.cat needs both operands on one device.
        history = self.penalty_dialog.to(input_ids.device)

        penalized = []
        for row_ids, row_scores in zip(input_ids, scores):
            counted = torch.cat((history, row_ids[self.input_length:]), dim=-1)
            # bincount grows past ``minlength`` when an id >= vocab_size is
            # present; trim so the subtraction below always lines up.
            counts = torch.bincount(counted, minlength=vocab_size)[:vocab_size]
            penalized.append(row_scores - self.penalty * counts.to(scores.device))

        return torch.stack(penalized).float()
|
||||
|
||||
|
||||
|
||||
class LlamaForConditionalGeneration(LlamaForCausalLM):
    """Llama causal LM whose ``generate`` can penalize tokens of earlier assistant turns.

    When a non-zero ``history_penalty`` is passed to :meth:`generate`, a
    ``FrequencyPenaltyLogitsProcessor`` is added that counts the tokens of the
    last ``penalty_turns`` assistant messages together with the tokens
    generated so far.
    """

    # Hub id of the tokenizer used to encode the penalized history.
    _PENALTY_TOKENIZER_ID = "Collective-Ai/collective-v0.1-chinese-roleplay-8b"

    # Parentheses are replaced by spaces before the history is tokenized.
    # NOTE(review): the original chained the two ASCII replacements twice, so
    # the second pair was a no-op; the full-width forms (U+FF08 / U+FF09) are
    # handled here on the assumption that they were the intended targets.
    _PAREN_TO_SPACE = str.maketrans({"(": " ", ")": " ", "\uff08": " ", "\uff09": " "})

    def __init__(self, config):
        super().__init__(config)

    def _get_penalty_tokenizer(self):
        """Return the history tokenizer, loading it only on first use."""
        tokenizer = getattr(self, "_penalty_tokenizer", None)
        if tokenizer is None:
            # from_pretrained hits the disk cache or the network, so do it once
            # per model instance instead of once per generate() call.
            tokenizer = AutoTokenizer.from_pretrained(self._PENALTY_TOKENIZER_ID)
            self._penalty_tokenizer = tokenizer
        return tokenizer

    def generate(self, *args, **kwargs):
        """Generate text, optionally penalizing tokens used in earlier assistant turns.

        Args:
            *args: Positional arguments forwarded unchanged to the parent
                ``generate``; the first one, if present, is used as the prompt
                when ``input_ids`` is not given by keyword.
            history_penalty (float, keyword): Per-occurrence penalty. ``0.0``
                (the default) disables the history penalty.
            penalty_turns (int, keyword): Number of most recent assistant
                messages whose tokens are penalized. A negative value disables
                the history penalty. Defaults to ``0``.
            messages (list[dict], keyword): Chat history; entries whose
                ``'role'`` is ``'assistant'`` supply the ``'content'`` to
                penalize. Defaults to an empty list.
            **kwargs: Everything else is forwarded to the parent ``generate``.

        Returns:
            Whatever the parent ``generate`` returns.
        """
        history_penalty = kwargs.pop("history_penalty", 0.0)
        penalty_turns = kwargs.pop("penalty_turns", 0)
        messages = kwargs.pop("messages", [])

        if history_penalty == 0.0 or penalty_turns < 0:
            return super().generate(*args, **kwargs)

        # Locate the prompt: keyword ``input_ids`` first, then the first
        # positional argument or keyword ``inputs``.
        prompt_ids = kwargs.get("input_ids")
        if prompt_ids is None:
            prompt_ids = args[0] if args else kwargs.get("inputs")
        if prompt_ids is None:
            # No prompt tensor available: treat the prompt length as zero.
            prompt_ids = torch.tensor([[]])
        input_length = prompt_ids.size(-1)

        assistant_turns = [m['content'] for m in messages if m['role'] == 'assistant']
        # ``[-0:]`` would select everything, hence the explicit guard.
        recent_turns = assistant_turns[-penalty_turns:] if penalty_turns > 0 else []
        history_text = ' '.join(turn.translate(self._PAREN_TO_SPACE) for turn in recent_turns)

        # NOTE(review): encode() runs with its default arguments; if the
        # tokenizer adds special tokens by default they are counted as well.
        history_ids = self._get_penalty_tokenizer().encode(history_text)
        penalty_token = torch.LongTensor(history_ids).to(prompt_ids.device)

        # Merge with processors supplied by the caller; passing the keyword
        # twice would raise ``TypeError``.
        logits_processor = list(kwargs.pop("logits_processor", None) or [])
        logits_processor.append(
            FrequencyPenaltyLogitsProcessor(
                penalty=history_penalty,
                penalty_dialog=penalty_token,
                input_length=input_length,
            )
        )
        return super().generate(*args, logits_processor=logits_processor, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user