Files
collective-v0.1-chinese-rol…/modeling_llama3.py
ModelHub XC 3dbe8b0e86 初始化项目,由ModelHub XC社区提供模型
Model: Collective-Ai/collective-v0.1-chinese-roleplay-8b
Source: Original Platform
2026-05-25 06:38:17 +08:00

67 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
import torch
from transformers.generation.logits_process import LogitsProcessor
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import AutoTokenizer
import re
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    """Lower each token's logit in proportion to how often the token already occurred.

    Occurrences are counted over ``penalty_dialog`` plus, for every batch row,
    the tokens of ``input_ids`` from position ``input_length`` onward. Each
    logit is then reduced by ``penalty * count``.

    Args:
        penalty: Strictly positive float subtracted once per occurrence.
        penalty_dialog: 1-D tensor of token ids that count as already used
            for every row of the batch.
        input_length: Number of leading positions of ``input_ids`` that are
            excluded from the count.

    Raises:
        ValueError: If ``penalty`` is not a strictly positive float.
    """

    def __init__(self, penalty: float, penalty_dialog: torch.LongTensor, input_length: int):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
        self.penalty = penalty
        self.input_length = input_length
        self.penalty_dialog = penalty_dialog

    @staticmethod
    def _count_tokens(token_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
        """Return per-token occurrence counts as a vector of exactly ``vocab_size`` entries."""
        counts = torch.bincount(token_ids, minlength=vocab_size)
        # bincount grows past `minlength` when an id >= vocab_size is present.
        # Such ids have no logit to penalise, so drop them instead of letting
        # the later subtraction fail with a shape mismatch.
        return counts[:vocab_size]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with the frequency penalty applied, cast to float32.

        Args:
            input_ids: Token ids, one row per sequence in the batch.
            scores: Logits, one row per sequence; the last dimension is the
                vocabulary size used for counting.

        Returns:
            A float32 tensor of penalised scores. ``scores`` is returned
            unchanged when the penalty is zero.
        """
        # __init__ rejects a zero penalty, so this only triggers if the
        # attribute was changed after construction.
        if self.penalty == 0.0:
            return scores

        vocab_size = scores.size(-1)
        # The history is identical for every row: count it once, on the same
        # device as the ids it will be combined with.
        history_counts = self._count_tokens(self.penalty_dialog.to(input_ids.device), vocab_size)

        penalised = []
        for row_ids, row_scores in zip(input_ids, scores):
            row_counts = history_counts + self._count_tokens(row_ids[self.input_length:], vocab_size)
            penalised.append(row_scores - self.penalty * row_counts.to(scores.device))

        if not penalised:
            # Empty batch: torch.stack would raise on an empty list.
            return scores.float()
        return torch.stack(penalised).float()
class LlamaForConditionalGeneration(LlamaForCausalLM):
    """LlamaForCausalLM whose ``generate`` can penalise tokens reused from earlier assistant turns.

    ``generate`` understands three extra keyword arguments which are removed
    before delegating to the parent implementation:

    * ``history_penalty``: per-occurrence penalty; ``0.0`` (default) disables
      the feature.
    * ``penalty_turns``: how many of the most recent assistant messages are
      counted; ``0`` counts only the tokens produced during this call.
    * ``messages``: list of dicts with ``"role"`` and ``"content"`` keys.
    """

    # Hub id of the tokenizer used to encode the penalised assistant turns.
    penalty_tokenizer_id = "Collective-Ai/collective-v0.1-chinese-roleplay-8b"

    # Punctuation blanked out of assistant turns before tokenisation.
    # NOTE(review): the original chain of .replace() calls handled "(" and ")"
    # plus two patterns that had degraded to the empty string (which inserts a
    # space between every character). Full-width parentheses and curly double
    # quotes are assumed to be the lost characters - confirm against upstream.
    _PENALTY_STRIP_TABLE = str.maketrans(dict.fromkeys("()\uff08\uff09\u201c\u201d", " "))

    def __init__(self, config):
        super().__init__(config)
        # Filled lazily by _get_penalty_tokenizer().
        self._penalty_tokenizer = None

    def _get_penalty_tokenizer(self):
        """Return the penalty tokenizer, loading it on first use and reusing it afterwards."""
        tokenizer = getattr(self, "_penalty_tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.penalty_tokenizer_id)
            self._penalty_tokenizer = tokenizer
        return tokenizer

    def generate(self, inputs=None, **kwargs):
        """Generate text, optionally applying a frequency penalty over past assistant turns.

        Args:
            inputs: Prompt tensor, forwarded to the parent ``generate`` as its
                first positional argument. May be omitted when ``input_ids``
                is given as a keyword.
            **kwargs: ``history_penalty``, ``penalty_turns`` and ``messages``
                (see the class docstring) plus anything accepted by the
                parent ``generate``.

        Returns:
            Whatever the parent ``generate`` returns.

        Raises:
            ValueError: From ``FrequencyPenaltyLogitsProcessor`` when
                ``history_penalty`` is non-zero but not a strictly positive
                float.
        """
        history_penalty = kwargs.pop("history_penalty", 0.0)
        penalty_turns = kwargs.pop("penalty_turns", 0)
        messages = kwargs.pop("messages", None) or []

        if history_penalty == 0.0 or penalty_turns < 0:
            return super().generate(inputs, **kwargs)

        # The prompt may arrive as `input_ids=` or as the first argument.
        prompt_ids = kwargs.get("input_ids")
        if prompt_ids is None:
            prompt_ids = inputs
        if prompt_ids is None:
            prompt_ids = torch.tensor([[]])
        input_length = prompt_ids.size(-1)

        assistant_turns = [
            message["content"] for message in messages if message["role"] == "assistant"
        ]
        # Keep the last `penalty_turns` assistant messages in chronological
        # order; a bare [-0:] slice would select everything, hence the guard.
        recent_turns = assistant_turns[-penalty_turns:] if penalty_turns > 0 else []
        penalty_text = " ".join(turn.translate(self._PENALTY_STRIP_TABLE) for turn in recent_turns)

        # Special tokens are not part of the dialog text, so they are not counted.
        penalty_ids = self._get_penalty_tokenizer().encode(penalty_text, add_special_tokens=False)
        penalty_token = torch.tensor(penalty_ids, dtype=torch.long, device=prompt_ids.device)

        # Append to any processors supplied by the caller instead of passing a
        # second `logits_processor` keyword, which would raise TypeError.
        caller_processors = kwargs.pop("logits_processor", None)
        logits_processor = list(caller_processors) if caller_processors is not None else []
        logits_processor.append(
            FrequencyPenaltyLogitsProcessor(
                penalty=history_penalty,
                penalty_dialog=penalty_token,
                input_length=input_length,
            )
        )
        return super().generate(inputs, logits_processor=logits_processor, **kwargs)