Files
2026-02-04 17:22:39 +08:00

141 lines
5.5 KiB
Python

import torch
def make_context(
    tokenizer,
    query,
    history,
    system,
    max_input_length,
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    '''
    Build the raw prompt string and its token ids for one chat context.
    args:
        tokenizer: model tokenizer; must expose encode() and, for the
            chatml format, im_start_id / im_end_id attributes
        query: current user query text
        history: list of (query, response) pairs from earlier turns, or None
        system: system prompt text
        max_input_length: maximum length of the returned id list; ids are
            truncated from the front so the most recent context survives
        max_window_size: token budget for system prompt plus history turns
        chat_format: chat format, only accept chatml and raw
    returns:
        (raw_text, token_ids) tuple
    raises:
        NotImplementedError: for any chat_format other than chatml / raw
    '''
    if history is None:
        history = []

    if chat_format == "chatml":
        im_start, im_end = "<|im_start|>", "<|im_end|>"
        start_ids = [tokenizer.im_start_id]
        end_ids = [tokenizer.im_end_id]
        newline_ids = tokenizer.encode("\n")

        def _encode_role(role, content):
            '''
            Return (text, ids) for one role-tagged message body.
            '''
            text = f"{role}\n{content}"
            ids = (tokenizer.encode(role, allowed_special=set())
                   + newline_ids
                   + tokenizer.encode(content, allowed_special=set()))
            return text, ids

        sys_text, sys_ids_part = _encode_role("system", system)
        sys_ids = start_ids + sys_ids_part + end_ids

        raw_text = ""
        context_tokens = []
        # Walk the history newest-first, prepending each turn until the
        # window budget (system prompt + accumulated turns) would overflow.
        for past_query, past_response in reversed(history):
            q_text, q_ids_part = _encode_role("user", past_query)
            r_text, r_ids_part = _encode_role("assistant", past_response)
            turn_ids = (newline_ids + start_ids + q_ids_part + end_ids
                        + newline_ids + start_ids + r_ids_part + end_ids)
            turn_text = (f"\n{im_start}{q_text}{im_end}"
                         f"\n{im_start}{r_text}{im_end}")
            if len(sys_ids) + len(turn_ids) + len(context_tokens) >= max_window_size:
                break
            context_tokens = turn_ids + context_tokens
            raw_text = turn_text + raw_text

        # System prompt goes first, then the kept history, then the current
        # query and an open assistant tag for generation to continue from.
        context_tokens = sys_ids + context_tokens
        raw_text = f"{im_start}{sys_text}{im_end}" + raw_text
        context_tokens += (newline_ids + start_ids
                           + _encode_role("user", query)[1] + end_ids
                           + newline_ids + start_ids
                           + tokenizer.encode("assistant") + newline_ids)
        raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
    elif chat_format == "raw":
        raw_text = query
        context_tokens = tokenizer.encode(raw_text)
    else:
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")

    # truncate to max_input_length, truncate from the front
    return raw_text, context_tokens[-max_input_length:]
def prepare_inputs(batch_input_texts,
                   tokenizer,
                   model_name,
                   model_version,
                   test_token_num,
                   eval_task='summarize',
                   add_special_tokens=True):
    '''
    tokenize batch input texts into tokenized ids.
    args:
        batch_input_texts: batch input text, also named batched prompt
        tokenizer: model tokenizer
        model_name: model name, used to select a model-specific tokenization path
        model_version: model version (e.g. 'chatglm2', 'qwen', 'qwen2')
        test_token_num: maximum number of tokens kept per prompt; each
            encoded prompt is truncated to this length
        eval_task: eval task; 'summarize' appends a ' TL;DR: ' cue to each prompt
        add_special_tokens: whether to add_special_tokens, default True
            (only used on the generic HuggingFace encode path)
    returns:
        list of 1-D torch tensors of token ids, one per input text
    '''
    append_str = ' TL;DR: ' if eval_task == 'summarize' else ''
    batch_input_ids = []
    for input_text in batch_input_texts:
        curr_text = input_text + append_str
        curr_text = curr_text.strip().replace(" n't", "n't")
        # The below lines are used to be compatible with the original code
        if 'GLM' in model_name and model_version in ['chatglm2', 'chatglm3']:
            # ChatGLM tokenizers don't support truncation kwargs here;
            # encode first, then slice to the token budget.
            input_ids = tokenizer.encode(curr_text, return_tensors='pt').squeeze(0)
            input_ids = input_ids[:test_token_num]
        elif 'qwen' in model_name.lower() and model_version == 'qwen':
            # use make_content to generate prompt
            system_prompt = "You are a useful assistant, please directly output the corresponding " + \
                            "summary according to the article entered by the user."
            _, input_id_list = make_context(
                tokenizer=tokenizer,
                query=curr_text,
                history=[],
                system=system_prompt,
                max_input_length=test_token_num,
            )
            input_ids = torch.tensor(input_id_list)
        else:
            if 'qwen' in model_name.lower() and 'qwen2' in model_version:
                # Qwen2+ uses the standard HF chat template instead of make_context.
                messages = [{
                    "role": "system",
                    "content": "You are a helpful assistant."
                }, {
                    "role": "user",
                    "content": curr_text
                }]
                curr_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True)
            input_ids = tokenizer.encode(curr_text,
                                         return_tensors='pt',
                                         add_special_tokens=add_special_tokens,
                                         truncation=True,
                                         max_length=test_token_num).squeeze(0)
        batch_input_ids.append(input_ids)
    return batch_input_ids