Files
HuatuoGPT2-7B/generation_utils.py
ModelHub XC 473fd0a2b6 初始化项目,由ModelHub XC社区提供模型
Model: FreedomIntelligence/HuatuoGPT2-7B
Source: Original Platform
2026-05-04 17:12:00 +08:00

132 lines
4.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List
from queue import Queue
import torch
# def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
# def _parse_messages(messages, split_role="user"):
# system, rounds = "", []
# round = []
# for i, message in enumerate(messages):
# if message["role"] == "system":
# assert i == 0
# system = message["content"]
# continue
# if message["role"] == split_role and round:
# rounds.append(round)
# round = []
# round.append(message)
# if round:
# rounds.append(round)
# return system, rounds
# max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
# max_input_tokens = model.config.model_max_length - max_new_tokens
# system, rounds = _parse_messages(messages, split_role="user")
# system_tokens = tokenizer.encode(system)
# max_history_tokens = max_input_tokens - len(system_tokens)
# history_tokens = []
# for round in rounds[::-1]:
# round_tokens = []
# for message in round:
# if message["role"] == "user":
# round_tokens.append(model.generation_config.user_token_id)
# else:
# round_tokens.append(model.generation_config.assistant_token_id)
# round_tokens.extend(tokenizer.encode(message["content"]))
# if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
# history_tokens = round_tokens + history_tokens # concat left
# if len(history_tokens) < max_history_tokens:
# continue
# break
# input_tokens = system_tokens + history_tokens
# if messages[-1]["role"] != "assistant":
# input_tokens.append(model.generation_config.assistant_token_id)
# input_tokens = input_tokens[-max_input_tokens:] # truncate left
# return torch.LongTensor([input_tokens]).to(model.device)
# Active implementation for HuatuoGPT2: serializes turns with plain-text role
# markers ("<问>"/"<答>") instead of the special-token-id approach kept
# commented out above for reference.
def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
    """Build the input token ids for a HuatuoGPT2 chat generation call.

    Serializes the chat history into HuatuoGPT2's plain-text role format —
    "<问>" marks a user turn, "<答>" an assistant turn, turns are separated
    by a newline — dropping the oldest rounds first so that the prompt plus
    the generation budget fits within ``model.config.model_max_length``.

    Args:
        model: must expose ``generation_config.max_new_tokens``,
            ``config.model_max_length`` and ``device``.
        tokenizer: must expose ``encode(str) -> list[int]``.
        messages: chat turns as dicts with "role" and "content" keys.
            Any role other than "user" is serialized as an assistant turn;
            "system" messages are not given special handling here.
        max_new_tokens: generation budget; 0 (the default) falls back to
            ``model.generation_config.max_new_tokens``.

    Returns:
        ``torch.LongTensor`` of shape (1, seq_len), placed on ``model.device``.
    """
    def _parse_messages(msgs, split_role="user"):
        # Group consecutive messages into "rounds": a new round starts at
        # each `split_role` message. The returned system string is always
        # empty — HuatuoGPT2 does not use a separate system segment.
        system, rounds = "", []
        current = []
        for message in msgs:
            if message["role"] == split_role and current:
                rounds.append(current)
                current = []
            current.append(message)
        if current:
            rounds.append(current)
        return system, rounds

    max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
    max_input_tokens = model.config.model_max_length - max_new_tokens
    _, rounds = _parse_messages(messages, split_role="user")
    max_history_tokens = max_input_tokens
    roles = ('<问>','<答>')
    sep = '\n'

    # Walk rounds newest-first, prepending each kept round, until the budget
    # is exhausted. The newest round is always kept even if it alone exceeds
    # the budget; the final left-truncation below enforces the hard limit.
    history_tokens = []
    for rnd in rounds[::-1]:
        round_tokens = []
        for message in rnd:
            role_tag = roles[0] if message["role"] == "user" else roles[1]
            round_tokens.extend(tokenizer.encode(role_tag + message["content"] + sep))
        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
            history_tokens = round_tokens + history_tokens  # concat left
            if len(history_tokens) < max_history_tokens:
                continue
        break

    input_tokens = list(history_tokens)
    if messages[-1]["role"] != "assistant":
        # Open an assistant turn so the model generates the reply next.
        input_tokens.extend(tokenizer.encode(roles[1]))
    input_tokens = input_tokens[-max_input_tokens:]  # truncate left
    return torch.LongTensor([input_tokens]).to(model.device)
class TextIterStreamer:
    """Streaming text iterator for HuggingFace ``generate(streamer=...)``.

    Token ids arrive through ``put``; after each chunk the full accumulated
    sequence is re-decoded and the cumulative text is queued. Consumers
    iterate the instance to receive each cumulative snapshot; ``end`` pushes
    a ``None`` sentinel that terminates iteration.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        # The very first call carries the prompt tokens; drop them when asked.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        # Accept either a batched (1, n) tensor or a flat (n,) tensor.
        if len(value.shape) > 1:
            value = value[0]
        self.tokens.extend(value.tolist())
        decoded = self.tokenizer.decode(
            self.tokens, skip_special_tokens=self.skip_special_tokens)
        self.text_queue.put(decoded)

    def end(self):
        # Sentinel telling the consumer that generation has finished.
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.text_queue.get()
        if item is None:
            raise StopIteration()
        return item