#!/usr/bin/env python3
# coding=utf-8
# Copyright (c) Ant Group. All rights reserved.
import itertools
from typing import Any, Dict, List, Optional, Union
import torch
from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import AddedToken, BatchEncoding
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)


def is_system(msg):
    return msg['role'].lower() == 'system'


def is_user(msg):
    return msg['role'].lower() in ['human', 'user']


def is_assistant(msg):
    return msg['role'].lower() == 'assistant'


def _convert_to_conversation(query, system=None):
    conversation = []
    if system:
        conversation.append({"role": "SYSTEM", "content": system})
    if isinstance(query, str):
        conversation.append({"role": "HUMAN", "content": query})
    elif isinstance(query, List):
        conversation.extend(query)
    elif isinstance(query, Dict):
        if "messages" in query:
            conversation.extend(query["messages"])
            if "system_message" in query and len(conversation) > 0 and not is_system(conversation[0]):
                conversation.insert(0, {"role": "SYSTEM", "content": query["system_message"]})
        else:
            conversation.append(query)
    return conversation
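
# Illustrative note (not part of the original file): _convert_to_conversation accepts a plain
# string, a list of message dicts, or a dict with "messages"/"system_message" keys, e.g.
#   _convert_to_conversation("Hi", system="Be brief")
#   -> [{"role": "SYSTEM", "content": "Be brief"}, {"role": "HUMAN", "content": "Hi"}]
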
class BailingTokenizer(PreTrainedTokenizerFast):
    is_bailing_tokenizer = True
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = None

    # add gmask_token
    SPECIAL_TOKENS_ATTRIBUTES = [
        "bos_token",
        "eos_token",
        "unk_token",
        "sep_token",
        "pad_token",
        "cls_token",
        "mask_token",
        "gmask_token",
        "additional_special_tokens",
    ]
    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        clean_up_tokenization_spaces=False,
        bos_token="<|startoftext|>",
        eos_token="<|endoftext|>",
        cls_token="[CLS]",
        pad_token="<|endoftext|>",
        gmask_token="[gMASK]",
        add_bos_token=False,
        add_eos_token=False,
        **kwargs,
    ):
        self.add_bos_token = add_bos_token
        self._gmask_token = (
            AddedToken(gmask_token, lstrip=False, rstrip=False, normalized=False)
            if isinstance(gmask_token, str)
            else gmask_token
        )
        self._sop_token = (
            AddedToken(bos_token, lstrip=False, rstrip=False, normalized=False)
            if isinstance(bos_token, str)
            else bos_token
        )
        self._eop_token = (
            AddedToken(eos_token, lstrip=False, rstrip=False, normalized=False)
            if isinstance(eos_token, str)
            else eos_token
        )
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            bos_token=bos_token,
            eos_token=eos_token,
            cls_token=cls_token,
            pad_token=pad_token,
            gmask_token=gmask_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            **kwargs,
        )
        self.check_special_tokens()
    def check_special_tokens(self):
        '''
        The required special tokens (eos, bos, cls, gmask) must be initialized;
        verify that none of them is None.
        '''
        for name, special_token in zip(
            ['eos', 'bos', 'cls', 'gmask'],
            [self.eos_token, self.bos_token, self.cls_token, self.gmask_token],
        ):
            assert special_token is not None, f'should init special token [{name}] in tokenizer_config.json'
    @property
    def gmask_token(self) -> Optional[str]:
        if self._gmask_token is None:
            if self.verbose:
                logger.error("Using gmask_token, but it is not set yet.")
            return None
        return str(self._gmask_token)

    @gmask_token.setter
    def gmask_token(self, value):
        if not isinstance(value, (str, AddedToken)) and value is not None:
            raise ValueError("Cannot set a non-string value as the gmask token")
        self._gmask_token = value

    @property
    def gmask_token_id(self) -> Optional[int]:
        if self._gmask_token is None:
            return None
        return self.convert_tokens_to_ids(self.gmask_token)

    @property
    def sop_token(self) -> Optional[str]:
        if self._sop_token is None:
            if self.verbose:
                logger.error("Using sop_token, but it is not set yet.")
            return None
        return str(self._sop_token)

    @sop_token.setter
    def sop_token(self, value):
        if not isinstance(value, (str, AddedToken)) and value is not None:
            raise ValueError("Cannot set a non-string value as the sop token")
        self._sop_token = value

    @property
    def sop_token_id(self) -> Optional[int]:
        if self._sop_token is None:
            return None
        return self.convert_tokens_to_ids(self.sop_token)

    @property
    def eop_token(self) -> Optional[str]:
        if self._eop_token is None:
            if self.verbose:
                logger.error("Using eop_token, but it is not set yet.")
            return None
        return str(self._eop_token)

    @eop_token.setter
    def eop_token(self, value):
        if not isinstance(value, (str, AddedToken)) and value is not None:
            raise ValueError("Cannot set a non-string value as the eop token")
        self._eop_token = value

    @property
    def eop_token_id(self) -> Optional[int]:
        if self._eop_token is None:
            return None
        return self.convert_tokens_to_ids(self.eop_token)

    @property
    def vocab_size(self):
        return len(self.get_vocab())
    def _chat_from_json(self, chat, chat_format="antglm_chat", system=None):
        msgs = chat if "messages" not in chat else chat["messages"]
        _msgs = []
        sys_msg = None
        for msg in msgs:
            if is_system(msg):
                sys_msg = msg['content']
            else:
                _msgs.append(msg)
        chat = {"messages": _msgs}
        system = system or sys_msg
        if system:
            chat['system_message'] = system
        from .chat_format import Chat

        return Chat.from_json(chat, name=chat_format)
    def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
        tools: Optional[List[Dict]] = None,
        documents: Optional[List[Dict[str, str]]] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = False,
        system: str = None,  # only used for legacy chatml
        tokenize=False,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        return_assistant_tokens_mask: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        if hasattr(self, "chat_template") and self.chat_template:
            if isinstance(conversation, Dict) and "messages" in conversation:
                conversation = conversation["messages"]
            # use transformers built-in method
            return super().apply_chat_template(
                conversation=conversation,
                tools=tools,
                documents=documents,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                tokenize=tokenize,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                return_dict=return_dict,
                return_assistant_tokens_mask=return_assistant_tokens_mask,
                tokenizer_kwargs=tokenizer_kwargs,
            )
        # The legacy (non-chat_template) path below will no longer be supported in the future.
        logger.warning("Please set chat_template in tokenizer_config.json!")
        chat_format = kwargs.get('chat_format', 'antglm_chat')
        is_batched = False
        if isinstance(conversation, List) and (
            isinstance(conversation[0], (list, tuple)) or "messages" in conversation[0]
        ):
            conversations = conversation
            is_batched = True
        if not is_batched:
            conversations = [conversation]
        rendered = []
        for chat in conversations:
            rendered_chat = self._chat_from_json(chat, chat_format=chat_format, system=system).prompt_str
            rendered.append(rendered_chat)
        if not is_batched:
            rendered = rendered[0]
        if tokenize:
            out = self(
                rendered,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                add_special_tokens=False,
                return_tensors=return_tensors,
            )
            if return_dict:
                return out
            else:
                return out["input_ids"]
        else:
            return rendered
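    # Illustrative usage sketch (not part of the original file): assuming the tokenizer was
    # loaded with trust_remote_code=True and a chat_template is defined in
    # tokenizer_config.json, a conversation can be rendered roughly as follows.
    #
    #   messages = [
    #       {"role": "SYSTEM", "content": "You are a helpful assistant."},
    #       {"role": "HUMAN", "content": "What is 2 + 2?"},
    #   ]
    #   prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    #   input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]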
    def _build_position_ids(
        self,
        mask_pos: int,
        bos_pos: int,
        max_output_length: int,
        rotary_type: Optional[str] = "none",
        **kwargs,
    ) -> List[List[int]]:
        window_size = kwargs.get("window_size", 1024) - 1
        block_position_ids = [0] * bos_pos
        # The mask position is used below to build the position ids of the output tokens.
        if "1d" in rotary_type:
            position_ids = list(range(bos_pos)) + list(range(mask_pos + 1, mask_pos + max_output_length + 2))
            block_position_ids = block_position_ids + list(range(1, max_output_length + 2))
        elif "2d" in rotary_type:
            # A bos_id will be appended to input_ids later.
            position_ids = list(range(bos_pos))
            position_ids = position_ids + [mask_pos] * (1 + max_output_length)
            block_position_ids = block_position_ids + list(range(1, max_output_length + 2))
        else:
            # build position ids
            position_ids = []
            repeat_times = bos_pos // window_size
            for _ in range(repeat_times):
                position_ids += list(range(window_size))
            position_ids += list(range(bos_pos - window_size * repeat_times))
            # need to consider the additional bos_id appended after input_ids
            mask_pos = position_ids[-1]
            position_ids += [mask_pos] * (max_output_length + 1)
            block_repeat_times = max_output_length // (window_size - 1)
            additional_block_position_ids = []
            for _ in range(block_repeat_times):
                additional_block_position_ids += list(range(1, window_size))
            additional_block_position_ids += list(
                range(1, max_output_length + 2 - (window_size - 1) * block_repeat_times)
            )
            block_position_ids = block_position_ids + additional_block_position_ids
        position_ids = [position_ids, block_position_ids]
        return position_ids
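    # Worked example (illustrative, not from the original file): with mask_pos=2, bos_pos=3,
    # max_output_length=2 and rotary_type="2d", the method returns
    #   position_ids       = [0, 1, 2, 2, 2, 2]
    #   block_position_ids = [0, 0, 0, 1, 2, 3]
    # i.e. prompt tokens keep increasing positions, while generated tokens reuse the mask
    # position and advance along the block dimension.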
    def _build_inputs_for_generation(
        self,
        input_ids: List[int],
        max_input_length=None,
        left_truncate=True,
        max_output_length=1024,
        rotary_type="none",
        unidirectional_attention: bool = True,
        attention_dtype=None,
        **kwargs,
    ):
        if max_input_length and len(input_ids) > max_input_length:
            if left_truncate:
                input_ids = input_ids[-max_input_length:]
            else:
                input_ids = input_ids[:max_input_length]
        is_left_padding = input_ids[0] == self.eos_token_id
        if not unidirectional_attention:
            if input_ids[0] != self.cls_token_id:
                input_ids = [self.cls_token_id] + input_ids
            if self.gmask_token_id not in set(input_ids):
                input_ids = input_ids + [self.gmask_token_id]
            mask_pos = input_ids.index(self.gmask_token_id)
            sep = len(input_ids)
        else:
            if self.add_bos_token:
                input_ids = input_ids + [self.bos_token_id]
                if self.eos_token_id in input_ids:
                    mask_pos = input_ids.index(self.eos_token_id) - 1
                else:
                    mask_pos = len(input_ids) - 1
                sep = len(input_ids) - 1
            else:
                sep = len(input_ids)
                if self.eos_token_id in input_ids:
                    if is_left_padding:
                        ori_input_ids = input_ids
                        input_ids = input_ids[::-1]
                    mask_pos = input_ids.index(self.eos_token_id) - 1
                    mask_pos = max(0, mask_pos)  # for empty sequence
                    if is_left_padding:
                        input_ids = ori_input_ids
                        mask_pos = sep - 1 - mask_pos  # the first non-eos token
                else:
                    mask_pos = len(input_ids) - 1
        position_ids = self._build_position_ids(mask_pos, sep, max_output_length, rotary_type, **kwargs)
        if is_left_padding:
            position_ids[0] = [max(0, i - mask_pos) for i in range(len(position_ids[0]))]
        # A bos_id will be appended to input_ids later.
        total_length = sep + max_output_length
        if self.add_bos_token:
            total_length += 1

        def build_mask_matrix(seq_length, sep, mask_pos, unidirectional_attention):
            # Use a bool dtype for long sequences to save GPU memory.
            if unidirectional_attention:
                attention_mask = torch.ones([seq_length, seq_length], dtype=attention_dtype)
                attention_mask = torch.tril(attention_mask)
                if is_left_padding:
                    attention_mask[:, :mask_pos] = 0
                else:
                    attention_mask[:, mask_pos + 1 : sep] = 0
            else:
                attention_mask = torch.zeros([seq_length, seq_length], dtype=attention_dtype)
                attention_mask[:, : mask_pos + 1] = 1
                for i in range(sep, total_length):
                    attention_mask[i, sep : i + 1] = 1
            return attention_mask

        if self.add_bos_token:
            attention_mask = build_mask_matrix(total_length, sep + 1, mask_pos, unidirectional_attention)
        else:
            attention_mask = build_mask_matrix(total_length, sep, mask_pos, unidirectional_attention)
        attention_mask = torch.unsqueeze(attention_mask, dim=0)
        attention_mask = torch.unsqueeze(attention_mask, dim=1)
        if attention_dtype is None:
            attention_mask = attention_mask.long()
        inputs = {
            "input_ids": torch.Tensor([input_ids]).long(),
            "position_ids": torch.Tensor([position_ids]).long(),
            "attention_mask": attention_mask,
        }
        return BatchEncoding(inputs)
    def build_inputs_for_generation(
        self,
        input_ids: Union[List[int], List[List[int]], torch.Tensor],
        max_input_length=None,
        left_truncate=True,
        max_output_length=1024,
        rotary_type="1d",
        unidirectional_attention=True,
        attention_dtype=None,
        **kwargs,
    ):
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        if isinstance(input_ids[0], list):
            input_ids_list = []
            position_ids_list = []
            attention_mask_list = []
            for _input_ids in input_ids:
                inputs = self._build_inputs_for_generation(
                    _input_ids,
                    max_input_length=max_input_length,
                    left_truncate=left_truncate,
                    max_output_length=max_output_length,
                    rotary_type=rotary_type,
                    unidirectional_attention=unidirectional_attention,
                    attention_dtype=attention_dtype,
                    **kwargs,
                )
                input_ids_list.append(inputs['input_ids'])
                position_ids_list.append(inputs['position_ids'])
                attention_mask_list.append(inputs["attention_mask"])
            max_ids_length = max([input.size(1) for input in input_ids_list])
            for i in range(len(input_ids)):
                cur_ids_length = input_ids_list[i].size(1)
                if cur_ids_length < max_ids_length:
                    # pad input ids
                    pad_input_ids = input_ids_list[i].new_zeros((1, max_ids_length - cur_ids_length))
                    input_ids_list[i] = torch.cat([pad_input_ids, input_ids_list[i]], dim=-1)
                    # pad position ids with left pad
                    # 0, 1, 2, 3, 4 ... -> 0, ..., 0, 1, 2, 3, 4, ...
                    pad_position_ids = input_ids_list[i].new_zeros((1, 2, max_ids_length - cur_ids_length))
                    position_ids_list[i] = torch.cat([pad_position_ids, position_ids_list[i]], dim=-1)
                    # pad generation attention mask with left and bottom pad
                    new_attention_mask = input_ids_list[i].new_zeros(
                        1,
                        1,
                        max_ids_length + max_output_length,
                        max_ids_length + max_output_length,
                    )
                    new_attention_mask[
                        :,
                        :,
                        max_ids_length - cur_ids_length :,
                        max_ids_length - cur_ids_length :,
                    ] = attention_mask_list[i]
                    attention_mask_list[i] = new_attention_mask.contiguous()
            input_ids_list = torch.cat(input_ids_list, dim=0)
            position_ids_list = torch.cat(position_ids_list, dim=0)
            attention_mask_list = torch.cat(attention_mask_list, dim=0)
            inputs = {
                "input_ids": input_ids_list,
                "position_ids": position_ids_list,
                "attention_mask": attention_mask_list,
            }
            return BatchEncoding(inputs)
        else:
            return self._build_inputs_for_generation(
                input_ids,
                max_input_length=max_input_length,
                left_truncate=left_truncate,
                max_output_length=max_output_length,
                rotary_type=rotary_type,
                unidirectional_attention=unidirectional_attention,
                attention_dtype=attention_dtype,
                **kwargs,
            )
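    # Illustrative usage sketch (assumption, not from the original file): for GLM-style
    # generation, the ids produced by the tokenizer can be expanded into position ids and
    # an attention mask roughly as follows.
    #
    #   ids = tokenizer("Hello")["input_ids"]
    #   batch = tokenizer.build_inputs_for_generation(ids, max_output_length=32, rotary_type="1d")
    #   # batch["input_ids"], batch["position_ids"] and batch["attention_mask"] are torch tensors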
    def _build_inputs_for_train(
        self,
        inputs: Union[str, List[str]],
        outputs: Union[str, List[str]],
        new_conversation_offset: List[int] = None,
        max_length: int = 2048,
        rotary_type: str = "1d",
        left_truncate: bool = True,
        unidirectional_attention: bool = True,
        isolation_position_ids: bool = False,
        padding: bool = True,
        use_fa2: bool = True,
        use_packed: bool = True,
        use_baichuan_packed: bool = False,
        skip_truncated_turn: bool = False,
        return_attention_mask: bool = True,
    ):
        r"""
        Build tensor inputs for model training. If `inputs` and `outputs` are lists, they will be packed.

        Args:
            inputs (str, List[str], List[Dict], List[List[Dict]]): the input prompts.
            outputs (str, List[str]): the output responses.
            max_length (int, Optional): the maximum length of the final input ids for training. Default: 2048
            rotary_type (str, Optional): the rotary type of the position embedding. Default: 1d
            left_truncate (bool, Optional): whether to truncate the inputs from the left. Default: True
            use_fa2 (bool, Optional): whether to build the attention mask for flash attention 2.
            new_conversation_offset (List[int], Optional): indices at which a brand-new conversation starts.
                For example, [0, 1] means inputs[0]/outputs[0] form one conversation and
                inputs[1]/outputs[1] form another.
        """
        if use_packed and use_baichuan_packed and unidirectional_attention:
            return self._build_baichuan_inputs_for_train(
                inputs,
                outputs,
                new_conversation_offset,
                max_length,
                rotary_type,
                left_truncate,
                skip_truncated_turn,
                use_fa2,
                padding,
            )
        if isinstance(inputs, str):
            inputs = [inputs]
        if isinstance(outputs, str):
            outputs = [outputs]
        assert len(inputs) == len(outputs)
        input_ids = [self(item)['input_ids'] for item in inputs]
        output_ids = [self(item)['input_ids'] for item in outputs]
        packed_input_ids = []
        packed_output_ids = []
        if new_conversation_offset is None:
            new_conversation_offset = list(range(0, len(inputs)))
        assert 0 in new_conversation_offset, f"0 is missing, please check new_conversation_offset: {new_conversation_offset}"
        current_len = 0
        for idx, (input, output) in enumerate(zip(input_ids, output_ids)):
            num_special_tokens = 0
            if not unidirectional_attention:
                if idx in new_conversation_offset:
                    # cls and gmask
                    num_special_tokens += 2
                else:
                    # only gmask
                    num_special_tokens += 1
            else:
                # sop and eos
                if self.add_bos_token:
                    num_special_tokens += 2
                else:
                    num_special_tokens += 1
            # truncate
            if len(input) + len(output) + current_len > max_length - num_special_tokens:
                if not use_packed or (use_fa2 and unidirectional_attention):
                    attention_mask = torch.tensor(0)
                elif use_fa2:
                    attention_mask = -1 * torch.ones([2, max_length])
                else:
                    attention_mask = torch.tril(torch.ones([max_length, max_length]))
                # Return an empty sample that does not contribute to training.
                default_return = {
                    'input_ids': (torch.ones(max_length) * self.eos_token_id).long(),
                    'position_ids': torch.zeros(2, max_length).long(),
                    'attention_mask': (attention_mask.long()),
                    'labels': (torch.ones(max_length) * -100).long(),
                }
                # If truncated turns should be skipped, return the empty sample or stop packing.
                if skip_truncated_turn:
                    if current_len == 0:
                        return default_return
                    else:
                        break
                left_len = max_length - num_special_tokens - current_len
                # When truncating, only the prompt is truncated.
                if left_len - len(output) > 0:
                    if left_truncate:
                        input = input[-(left_len - len(output)) :]
                    else:
                        input = input[: left_len - len(output)]
                else:
                    # If the response alone exceeds left_len, return or stop directly.
                    if current_len == 0:
                        return default_return
                    else:
                        break
            if unidirectional_attention:
                packed_input_ids.append(list(input))
            else:
                # A new conversation reserves two special tokens (cls and gmask).
                if num_special_tokens == 2:
                    packed_input_ids.append([self.cls_token_id] + list(input) + [self.gmask_token_id])
                else:
                    packed_input_ids.append(list(input) + [self.gmask_token_id])
            packed_output_ids.append(list(output) + [self.eos_token_id])
            current_len += len(input) + len(output) + num_special_tokens
        assert current_len <= max_length
        if use_packed:
            # Packed mode.
            def build_mask_matrix(seq_length, sep):
                # https://github.com/pytorch/pytorch/issues/101932, fix triu/tril bf16 support
                m = torch.ones((1, seq_length, seq_length))
                mask = torch.arange(1, m.shape[-1] + 1).reshape(1, -1, 1).to(m.device)
                ids = torch.arange(1, m.shape[-1] + 1).reshape(1, 1, -1).expand(1, m.shape[-1], -1).to(m.device)
                m = (ids <= mask).type_as(m)
                m[0, :, : int(sep)] = 1
                m = m.squeeze(0)
                return m

            tokens = []
            attention_mask_list = []
            input_length_list = []
            position_id_list = []
            block_position_id_list = []
            for input, output in zip(packed_input_ids, packed_output_ids):
                if self.add_bos_token:
                    data = input + [self.sop_token_id] + output
                    mask_pos = len(input) - 1
                else:
                    data = input + output
                    mask_pos = len(input) - 2
                if return_attention_mask:
                    if unidirectional_attention:
                        attention_mask = build_mask_matrix(len(data), 0)
                    else:
                        attention_mask = build_mask_matrix(len(data), len(input))
                    attention_mask = attention_mask.squeeze((0, 1))
                    attention_mask_list.append(attention_mask)
                input_length_list.append(len(input))
                tokens += data
                sop_pos = mask_pos + 1
                position_ids, block_position_ids = self._build_position_ids(
                    mask_pos=mask_pos, bos_pos=sop_pos, max_output_length=len(output), rotary_type=rotary_type
                )
                position_id_list.append(position_ids)
                block_position_id_list.append(block_position_ids)
            labels = []
            for i in range(len(packed_input_ids)):
                if self.add_bos_token:
                    labels += [-100] * len(packed_input_ids[i]) + packed_output_ids[i] + [-100]
                else:
                    labels += [-100] * (len(packed_input_ids[i]) - 1) + packed_output_ids[i] + [-100]
            total_len = 0
            if use_fa2:
                pack_attention_mask = -1 * torch.ones([2, max_length])
            else:
                pack_attention_mask = torch.tril(torch.ones([max_length, max_length]))
            pack_position_ids = []
            pack_block_position_ids = []
            total_len = 0
            max_index = 0
            for i in range(len(position_id_list)):
                if use_fa2:
                    pack_attention_mask[0][i] = total_len
                    pack_attention_mask[1][i] = total_len + input_length_list[i]
                else:
                    # Use the mask built for this sample.
                    attention_mask = attention_mask_list[i]
                    pack_attention_mask[
                        total_len : total_len + attention_mask.shape[0],
                        total_len : total_len + attention_mask.shape[0],
                    ] = attention_mask
                position_ids = [pid + max_index for pid in position_id_list[i]]
                block_position_ids = block_position_id_list[i]
                pack_position_ids.extend(position_ids)
                pack_block_position_ids.extend(block_position_ids)
                if not isolation_position_ids:
                    max_index = pack_position_ids[-1] + 1
                total_len += len(position_id_list[i])
            position_ids = [pack_position_ids, pack_block_position_ids]
        else:
            # Single-input mode.
            # With true multi-turn data, one sample may contain several conversation turns,
            # so we need the end position of the first conversation.
            if len(new_conversation_offset) > 1:
                end_idx = new_conversation_offset[1]
            else:
                end_idx = 1
            input, output = list(itertools.chain(*packed_input_ids[:end_idx])), list(
                itertools.chain(*packed_output_ids[:end_idx])
            )
            if self.add_bos_token:
                tokens = input + [self.sop_token_id] + output
            else:
                tokens = input + output
            if self.add_bos_token:
                labels = [-100] * len(input) + output + [-100]
                position_ids = self._build_position_ids(
                    mask_pos=len(input) - 1, bos_pos=len(input), max_output_length=len(output), rotary_type=rotary_type
                )
            else:
                labels = [-100] * (len(input) - 1) + output + [-100]
                position_ids = self._build_position_ids(
                    mask_pos=len(input) - 2,
                    bos_pos=len(input) - 1,
                    max_output_length=len(output),
                    rotary_type=rotary_type,
                )
            attention_mask = len(input)
        assert current_len == len(tokens)
        # Pad to the maximum length.
        if max_length > 0 and len(tokens) < max_length and padding:
            pad_length = max_length - len(tokens)
            tokens += [self.pad_token_id] * pad_length
            labels.extend([-100] * pad_length)
            position_ids[0] += [0] * pad_length
            position_ids[1] += [0] * pad_length
            if use_packed:
                if use_fa2:
                    new_attention_mask = -1 * torch.ones([2, max_length])
                    new_attention_mask[:, :current_len] = pack_attention_mask
                else:
                    new_attention_mask = torch.tril(torch.ones([max_length, max_length]))
                    new_attention_mask[:current_len, :current_len] = pack_attention_mask
                pack_attention_mask = new_attention_mask.contiguous()
        assert len(tokens) == len(labels)
        if max_length > 0 and padding:
            assert len(tokens) == max_length
        if use_fa2 and unidirectional_attention:
            # pack_attention_mask = torch.zeros([1], dtype=torch.long)
            pack_attention_mask = torch.tensor(0)
        if use_packed:
            if not use_fa2:
                attention_mask = pack_attention_mask.unsqueeze(0).long()
            else:
                attention_mask = pack_attention_mask
        else:
            attention_mask = torch.tensor(attention_mask).long()
        return {
            'input_ids': torch.tensor(tokens).long(),
            'position_ids': torch.tensor(position_ids).long(),
            'attention_mask': attention_mask,
            'labels': torch.tensor(labels).long(),
        }
    def _build_baichuan_inputs_for_train(
        self,
        inputs: Union[str, List[str]],
        outputs: Union[str, List[str]],
        new_conversation_offset: List[int] = None,
        max_length: int = 2048,
        rotary_type: str = "1d",
        left_truncate: bool = True,
        skip_truncated_turn: bool = True,
        use_fa2: bool = True,
        padding: bool = True,
    ):
        '''
        input:  <role> HUMAN </role> u1 <role> ASSISTANT </role> a11 a12 <role> HUMAN </role> u2 <role> ASSISTANT </role> a21 a22 <|endoftext|> <role> HUMAN </role> u1 <role> ASSISTANT </role> a11 a12 <role> HUMAN </role> u2 <role> ASSISTANT </role> a21 a22 <|endoftext|>
        output: x x x x x x a11 a12 <|endoftext|> x x x x x x a21 a22 <|endoftext|> x x x x x x x a11 a12 <|endoftext|> x x x x x x a21 a22 <|endoftext|> x
        Only applicable to training unidirectional models on true multi-turn + packed data;
        use_true_multiturn must be enabled.
        '''
        if isinstance(inputs, str):
            inputs = [inputs]
        if isinstance(outputs, str):
            outputs = [outputs]
        assert len(inputs) == len(outputs)
        input_ids = [self(item)['input_ids'] for item in inputs]
        output_ids = [self(item)['input_ids'] for item in outputs]
        packed_input_ids = []
        packed_output_ids = []
        if new_conversation_offset is None:
            new_conversation_offset = list(range(0, len(inputs)))
        assert 0 in new_conversation_offset, f"0 is missing, please check new_conversation_offset: {new_conversation_offset}"
        current_len = 0
        for idx, (input, output) in enumerate(zip(input_ids, output_ids)):
            num_special_tokens = 0
            if idx != 0 and idx in new_conversation_offset:
                # Prepend an eos to input_ids; only the 0th sample is not prefixed.
                num_special_tokens += 1
            # truncate
            if len(input) + len(output) + current_len > max_length - num_special_tokens:
                if use_fa2:
                    attention_mask = torch.tensor(0)
                else:
                    attention_mask = torch.tril(torch.ones([max_length, max_length]))
                # Return an empty sample that does not contribute to training.
                default_return = {
                    'input_ids': (torch.ones(max_length) * self.eos_token_id).long(),
                    'position_ids': torch.zeros(2, max_length).long(),
                    'attention_mask': (attention_mask.long()),
                    'labels': (torch.ones(max_length) * -100).long(),
                }
                # If truncated turns should be skipped, return the empty sample or stop packing.
                if skip_truncated_turn:
                    if current_len == 0:
                        return default_return
                    else:
                        break
                left_len = max_length - num_special_tokens - current_len
                # When truncating, only the prompt is truncated.
                if left_len - len(output) > 0:
                    if left_truncate:
                        input = input[-(left_len - len(output)) :]
                    else:
                        input = input[: left_len - len(output)]
                else:
                    # If the response alone exceeds left_len, return or stop directly.
                    if current_len == 0:
                        return default_return
                    else:
                        break
            # Concatenate the input_ids here.
            if num_special_tokens == 1:
                packed_input_ids.append([self.eos_token_id] + list(input))
            else:
                packed_input_ids.append(list(input))
            packed_output_ids.append(list(output))
            current_len += len(input) + len(output) + num_special_tokens
        assert current_len <= max_length
        def build_mask_matrix(seq_length, sep):
            # https://github.com/pytorch/pytorch/issues/101932, fix triu/tril bf16 support
            m = torch.ones((1, seq_length, seq_length))
            mask = torch.arange(1, m.shape[-1] + 1).reshape(1, -1, 1).to(m.device)
            ids = torch.arange(1, m.shape[-1] + 1).reshape(1, 1, -1).expand(1, m.shape[-1], -1).to(m.device)
            m = (ids <= mask).type_as(m)
            m[0, :, : int(sep)] = 1
            m = m.squeeze(0)
            return m

        tokens = []
        attention_mask_list = []
        position_id_list = []
        block_position_id_list = []
        token_lens = []
        for input, output in zip(packed_input_ids, packed_output_ids):
            data = input + output
            if not use_fa2:
                attention_mask = build_mask_matrix(len(data), 0)
                attention_mask_list.append(attention_mask)
            tokens += data
            token_lens.append(len(data))
            position_ids, block_position_ids = self._build_position_ids(
                mask_pos=len(input) - 2, bos_pos=len(input) - 1, max_output_length=len(output), rotary_type=rotary_type
            )
            position_id_list.append(position_ids)
            block_position_id_list.append(block_position_ids)
        labels = []
        for i in range(len(packed_input_ids)):
            labels += [-100] * (len(packed_input_ids[i]) - 1) + packed_output_ids[i] + [self.eos_token_id]
        total_len = 0
        if use_fa2:
            pack_attention_mask = torch.Tensor([[0], [1]])
        else:
            pack_attention_mask = torch.tril(torch.ones([max_length, max_length]))
        pack_position_ids = []
        pack_block_position_ids = []
        total_len = 0
        max_index = 0
        for i in range(len(token_lens)):
            if not use_fa2:
                attention_mask = attention_mask_list[i]
                pack_attention_mask[
                    total_len : total_len + attention_mask.shape[0], total_len : total_len + attention_mask.shape[0]
                ] = attention_mask
            position_ids = [pid + max_index for pid in position_id_list[i]]
            block_position_ids = block_position_id_list[i]
            pack_position_ids.extend(position_ids)
            pack_block_position_ids.extend(block_position_ids)
            max_index = pack_position_ids[-1] + 1
            total_len += token_lens[i]
        position_ids = [pack_position_ids, pack_block_position_ids]
        if max_length > 0 and len(tokens) < max_length and padding:
            pad_length = max_length - len(tokens)
            tokens += [self.pad_token_id] * pad_length
            labels.extend([-100] * pad_length)
            position_ids[0] += [0] * pad_length
            position_ids[1] += [0] * pad_length
        assert len(tokens) == len(labels)
        if not use_fa2:
            attention_mask = pack_attention_mask.unsqueeze(0).long()
        else:
            attention_mask = torch.tensor(0)
        return {
            'input_ids': torch.tensor(tokens).long(),
            'position_ids': torch.tensor(position_ids).long(),
            'attention_mask': attention_mask,
            'labels': torch.tensor(labels).long(),
        }
    def build_inputs_for_train(
        self,
        data: Union[Dict, List[Dict]],
        new_conversation_offset: List[int] = None,
        chat_format="antglm_chat",
        is_chat_format=True,  # when plain strings are passed, indicates whether they are already rendered in chat format
        use_true_multiturn=False,
        max_length: int = 2048,
        rotary_type: str = "1d",
        left_truncate: bool = True,
        unidirectional_attention: bool = True,
        isolation_position_ids: bool = False,
        padding: bool = True,
        use_fa2: bool = True,
        use_packed: bool = True,
        use_baichuan_packed: bool = False,
        skip_truncated_turn: bool = False,
        return_attention_mask: bool = True,
    ):
        r"""
        Build tensor inputs for model training. If inputs and outputs are lists, they will be packed.

        Args:
            inputs (str, List[str], List[Dict], List[List[Dict]]): the input prompts.
            outputs (str, List[str]): the output responses.
            new_conversation_offset (List[int]): the offset index of each new conversation turn.
            is_chat_format (bool): whether the input is already in chatml format.
            max_length (int, Optional): the maximum length of the final input ids for training. Default: 2048
            rotary_type (str, Optional): the rotary type of the position embedding. Default: 1d
            left_truncate (bool, Optional): whether to truncate the inputs from the left. Default: True
            use_fa2 (bool, Optional): whether to build the attention mask for flash attention 2.
        """
        if isinstance(data, List):
            # chatml list
            _inputs = []
            _outputs = []
            new_conversation_offset = []
            for _input in data:
                if use_true_multiturn:
                    chat = self._chat_from_json(_input, chat_format=chat_format)
                    chat_data = chat.prompt_pack
                    new_conversation_offset.append(len(_inputs))
                    _inputs.extend(chat_data['input'])
                    _outputs.extend(chat_data['output'])
                else:
                    _conversation = _convert_to_conversation(_input)
                    assert is_assistant(_conversation[-1])
                    _inputs.append(
                        self.apply_chat_template(_conversation[:-1], tokenize=False, add_generation_prompt=True)
                    )
                    _outputs.append(_conversation[-1]['content'])
            return self._build_inputs_for_train(
                inputs=_inputs,
                outputs=_outputs,
                new_conversation_offset=new_conversation_offset,
                max_length=max_length,
                rotary_type=rotary_type,
                left_truncate=left_truncate,
                unidirectional_attention=unidirectional_attention,
                isolation_position_ids=isolation_position_ids,
                padding=padding,
                use_fa2=use_fa2,
                use_packed=use_packed,
                use_baichuan_packed=use_baichuan_packed,
                skip_truncated_turn=skip_truncated_turn,
                return_attention_mask=return_attention_mask,
            )
        elif isinstance(data, Dict):
            if 'messages' in data:
                # chatml format
                if use_true_multiturn:
                    chat = self._chat_from_json(data, chat_format=chat_format)
                    chat_data = chat.prompt_pack
                else:
                    _conversation = _convert_to_conversation(data)
                    assert is_assistant(_conversation[-1])
                    chat_data = {
                        "input": self.apply_chat_template(
                            _conversation[:-1], tokenize=False, add_generation_prompt=True
                        ),
                        "output": _conversation[-1]['content'],
                    }
                return self._build_inputs_for_train(
                    inputs=chat_data['input'],
                    outputs=chat_data['output'],
                    max_length=max_length,
                    rotary_type=rotary_type,
                    left_truncate=left_truncate,
                    unidirectional_attention=unidirectional_attention,
                    isolation_position_ids=isolation_position_ids,
                    padding=padding,
                    use_fa2=use_fa2,
                    use_packed=use_packed,
                    use_baichuan_packed=use_baichuan_packed,
                    skip_truncated_turn=skip_truncated_turn,
                    return_attention_mask=return_attention_mask,
                )
            else:
                inputs = data['input']
                outputs = data['output']
                if isinstance(inputs, str):
                    inputs = [inputs]
                if isinstance(outputs, str):
                    outputs = [outputs]
                if not is_chat_format and chat_format:
                    inputs = [
                        self.apply_chat_template(
                            [{"role": "HUMAN", "content": item}], tokenize=False, chat_format=chat_format
                        )
                        for item in inputs
                    ]
                return self._build_inputs_for_train(
                    inputs=inputs,
                    outputs=outputs,
                    new_conversation_offset=new_conversation_offset,
                    max_length=max_length,
                    rotary_type=rotary_type,
                    left_truncate=left_truncate,
                    unidirectional_attention=unidirectional_attention,
                    isolation_position_ids=isolation_position_ids,
                    padding=padding,
                    use_fa2=use_fa2,
                    use_packed=use_packed,
                    use_baichuan_packed=use_baichuan_packed,
                    skip_truncated_turn=skip_truncated_turn,
                    return_attention_mask=return_attention_mask,
                )