初始化项目,由ModelHub XC社区提供模型
Model: GenerTeam/GENERator-v2-eukaryote-1.2b-base Source: Original Platform
This commit is contained in:
163
tokenizer.py
Normal file
163
tokenizer.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import itertools
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Tuple
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
class DNAKmerTokenizer(PreTrainedTokenizer):
|
||||
def __init__(self, k, **kwargs):
|
||||
self.k = k
|
||||
self.special_tokens = [
|
||||
"<oov>",
|
||||
"<s>",
|
||||
"</s>",
|
||||
"<pad>",
|
||||
"<mask>",
|
||||
"<bog>",
|
||||
"<eog>",
|
||||
"<bok>",
|
||||
"<eok>",
|
||||
"<+>",
|
||||
"<->",
|
||||
"<cds>",
|
||||
"<pseudo>",
|
||||
"<tRNA>",
|
||||
"<rRNA>",
|
||||
"<ncRNA>",
|
||||
"<miscRNA>",
|
||||
"<mam>",
|
||||
"<vrt>",
|
||||
"<inv>",
|
||||
"<pln>",
|
||||
"<fng>",
|
||||
"<prt>",
|
||||
"<arc>",
|
||||
"<bct>",
|
||||
"<mit>",
|
||||
"<plt>",
|
||||
"<plm>",
|
||||
"<vir>",
|
||||
"<sp0>",
|
||||
"<sp1>",
|
||||
"<sp2>",
|
||||
]
|
||||
self.kmers = [
|
||||
"".join(kmer) for kmer in itertools.product("ATCG", repeat=self.k)
|
||||
]
|
||||
self.vocab = {
|
||||
token: i for i, token in enumerate(self.special_tokens + self.kmers)
|
||||
}
|
||||
self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
|
||||
self.special_token_pattern = re.compile(
|
||||
"|".join(re.escape(token) for token in self.special_tokens)
|
||||
)
|
||||
self.dna_pattern = re.compile(f"[A-Z]{{{self.k}}}|[A-Z]+")
|
||||
kwargs.setdefault("bos_token", "<s>")
|
||||
kwargs.setdefault("eos_token", "</s>")
|
||||
kwargs.setdefault("unk_token", "<oov>")
|
||||
kwargs.setdefault("pad_token", "<pad>")
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.vocab)
|
||||
|
||||
def get_vocab(self):
|
||||
return dict(self.vocab)
|
||||
|
||||
def _tokenize(self, text, **kwargs) -> List[str]:
|
||||
tokens = []
|
||||
pos = 0
|
||||
while pos < len(text):
|
||||
special_match = self.special_token_pattern.match(text, pos)
|
||||
if special_match:
|
||||
tokens.append(special_match.group())
|
||||
pos = special_match.end()
|
||||
else:
|
||||
dna_match = self.dna_pattern.match(text, pos)
|
||||
if dna_match:
|
||||
dna_seq = dna_match.group()
|
||||
tokens.append(dna_seq)
|
||||
pos = dna_match.end()
|
||||
else:
|
||||
tokens.append(text[pos])
|
||||
pos += 1
|
||||
return tokens
|
||||
|
||||
def _convert_token_to_id(self, token: str) -> int:
|
||||
return self.vocab.get(token, self.vocab["<oov>"])
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> str:
|
||||
return self.ids_to_tokens.get(index, "<oov>")
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
return "".join(tokens)
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
if token_ids_1 is None:
|
||||
return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
|
||||
return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
|
||||
):
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0, token_ids_1, already_has_special_tokens=True
|
||||
)
|
||||
if token_ids_1 is None:
|
||||
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||
|
||||
def prepare_for_model(self, *args, **kwargs):
|
||||
encoding = super().prepare_for_model(*args, **kwargs)
|
||||
if "token_type_ids" in encoding:
|
||||
del encoding["token_type_ids"]
|
||||
return encoding
|
||||
|
||||
def save_vocabulary(
|
||||
self, save_directory: str, filename_prefix: Optional[str] = None
|
||||
) -> Tuple[str]:
|
||||
import os
|
||||
|
||||
vocab_file = os.path.join(
|
||||
save_directory,
|
||||
(filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
|
||||
)
|
||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||
writer.write(token + "\n")
|
||||
return (vocab_file,)
|
||||
|
||||
def save_pretrained(self, save_directory: str, **kwargs):
|
||||
vocab_files = super().save_pretrained(save_directory, **kwargs)
|
||||
tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")
|
||||
|
||||
# 读取现有的配置或创建新的
|
||||
if os.path.exists(tokenizer_config_path):
|
||||
with open(tokenizer_config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
config = {}
|
||||
|
||||
# 添加auto_map配置
|
||||
config.update({
|
||||
"auto_map": {
|
||||
"AutoTokenizer": [
|
||||
"tokenizer.DNAKmerTokenizer",
|
||||
None
|
||||
]
|
||||
},
|
||||
})
|
||||
|
||||
# 添加kmer配置
|
||||
config.update({
|
||||
"k": self.k
|
||||
})
|
||||
|
||||
# 保存配置
|
||||
with open(tokenizer_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return vocab_files
|
||||
Reference in New Issue
Block a user