forked from EngineX-Cambricon/enginex-mlu370-vllm
141 lines
5.5 KiB
Python
141 lines
5.5 KiB
Python
import torch
|
|
|
|
|
|
def make_context(
    tokenizer,
    query,
    history,
    system,
    max_input_length,
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    '''
    build the prompt text and the tokenized ids for one query
    args:
        tokenizer: model tokenizer; the chatml format reads its im_start_id and
            im_end_id attributes and calls encode(text, allowed_special=set())
        query: current text context
        history: history text context, an iterable of (query, response) pairs,
            None is treated as no history
        system: system prompt
        max_input_length: max input length of tokenized id, only the last
            max_input_length ids are kept, a value <= 0 keeps no ids
        max_window_size: token budget used to decide how many history turns
            (newest first) are placed in front of the current query
        chat_format: chat format, only accept chatml and raw
    returns:
        (raw_text, context_tokens), raw_text is the untruncated prompt text and
        context_tokens is the prompt ids truncated to max_input_length
    raises:
        NotImplementedError: chat_format is neither chatml nor raw
    '''
    if history is None:
        history = []

    if chat_format == "chatml":
        im_start, im_end = "<|im_start|>", "<|im_end|>"
        im_start_tokens = [tokenizer.im_start_id]
        im_end_tokens = [tokenizer.im_end_id]
        nl_tokens = tokenizer.encode("\n")

        def _tokenize_str(role, content):
            '''
            tokenize one message body (role, newline, content),
            return its text and its token ids
            '''
            role_tokens = tokenizer.encode(role, allowed_special=set())
            content_tokens = tokenizer.encode(content, allowed_special=set())
            return f"{role}\n{content}", role_tokens + nl_tokens + content_tokens

        system_text, system_tokens_part = _tokenize_str("system", system)
        system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
        raw_text = ""
        context_tokens = []

        # Walk the history from the newest turn to the oldest one and stop at
        # the first turn that no longer fits into the window.
        for turn_query, turn_response in reversed(history):
            query_text, query_tokens_part = _tokenize_str("user", turn_query)
            query_tokens = im_start_tokens + query_tokens_part + im_end_tokens

            response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
            response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

            next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
            prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"

            current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
            if current_context_size >= max_window_size:
                break
            # Older turns are prepended so the prompt stays in chronological order.
            context_tokens = next_context_tokens + context_tokens
            raw_text = prev_chat + raw_text

        context_tokens = system_tokens + context_tokens
        raw_text = f"{im_start}{system_text}{im_end}" + raw_text

        # Append the current user turn and open the assistant turn so that
        # generation continues right after "assistant\n".
        context_tokens += (nl_tokens + im_start_tokens + _tokenize_str("user", query)[1] + im_end_tokens +
                           nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens)
        raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

    elif chat_format == "raw":
        raw_text = query
        context_tokens = tokenizer.encode(raw_text)
    else:
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")

    # truncate to max_input_length, dropping ids from the front.
    # A plain [-max_input_length:] slice would return everything for 0
    # (because -0 == 0), so a non-positive limit is handled explicitly.
    if max_input_length <= 0:
        return raw_text, context_tokens[:0]
    return raw_text, context_tokens[-max_input_length:]
|
|
|
|
|
|
def prepare_inputs(batch_input_texts,
                   tokenizer,
                   model_name,
                   model_version,
                   test_token_num,
                   eval_task='summarize',
                   add_special_tokens=True):
    '''
    tokenize batch input texts into tokenized id.
    args:
        batch_input_texts: batch input text, also named batched prompt
        tokenizer: model tokenizer
        model_name: model name, selects the GLM / qwen specific handling
        model_version: model version, selects the GLM / qwen specific handling
        test_token_num: max number of token ids kept for each prompt
        eval_task: eval task, ' TL;DR: ' is appended to each text when it is 'summarize'
        add_special_tokens: whether to add_special_tokens, default True
            (only forwarded to tokenizer.encode in the generic path)
    returns:
        list with the input ids of each text, in the order of batch_input_texts
    '''
    append_str = ' TL;DR: ' if eval_task == 'summarize' else ''
    # Fixed system prompt used by the first-generation qwen chat format.
    system_prompt = ("You are a useful assistant, please directly output the corresponding "
                     "summary according to the article entered by the user.")

    batch_input_ids = []
    for input_text in batch_input_texts:
        curr_text = input_text + append_str
        curr_text = curr_text.strip().replace(" n't", "n't")

        # The below lines are used to be compatible with the original code
        if 'GLM' in model_name and model_version in ['chatglm2', 'chatglm3']:
            input_ids = tokenizer.encode(curr_text, return_tensors='pt').squeeze(0)
            # this path keeps the first test_token_num ids
            input_ids = input_ids[:test_token_num]
        elif 'qwen' in model_name.lower() and model_version == 'qwen':
            # use make_context to generate prompt; it keeps the last
            # test_token_num ids of the chatml prompt
            _, input_id_list = make_context(
                tokenizer=tokenizer,
                query=curr_text,
                history=[],
                system=system_prompt,
                max_input_length=test_token_num,
            )
            input_ids = torch.tensor(input_id_list)
        else:
            if 'qwen' in model_name.lower() and 'qwen2' in model_version:
                # qwen2 prompts are rendered through the tokenizer's chat
                # template before being encoded like any other text
                messages = [{
                    "role": "system",
                    "content": "You are a helpful assistant."
                }, {
                    "role": "user",
                    "content": curr_text
                }]
                curr_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True)

            input_ids = tokenizer.encode(curr_text,
                                         return_tensors='pt',
                                         add_special_tokens=add_special_tokens,
                                         truncation=True,
                                         max_length=test_token_num).squeeze(0)

        batch_input_ids.append(input_ids)
    return batch_input_ids
|