初始化项目,由ModelHub XC社区提供模型
Model: GenerTeam/GENERator-v2-eukaryote-3b-base Source: Original Platform
This commit is contained in:
35
.gitattributes
vendored
Normal file
35
.gitattributes
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
183
README.md
Normal file
183
README.md
Normal file
@@ -0,0 +1,183 @@
|
||||
---
|
||||
library_name: transformers
|
||||
license: mit
|
||||
pipeline_tag: text-generation
|
||||
tags:
|
||||
- biology
|
||||
- genomics
|
||||
- long-context
|
||||
---
|
||||
|
||||
# GENERator-v2-eukaryote-3b-base model
|
||||
|
||||
## **Important Notice**
|
||||
If you are using **GENERator** for sequence generation, please ensure that the length of each input sequence is a multiple of **6**. This can be achieved by either:
|
||||
1. Padding the sequence on the left with `'A'` (**left padding**);
|
||||
2. Truncating the sequence from the left (**left truncation**).
|
||||
|
||||
This requirement arises because **GENERator** employs a 6-mer tokenizer. If the input sequence length is not a multiple of **6**, the tokenizer will append an `'<oov>'` (out-of-vocabulary) token to the end of the token sequence. This can result in uninformative subsequent generations, such as repeated `'AAAAAA'`.
|
||||
|
||||
We apologize for any inconvenience this may cause and recommend adhering to the above guidelines to ensure accurate and meaningful generation results.
|
||||
|
||||
|
||||
## Abouts
|
||||
In this repository, we present GENERator-v2, a generative genomic foundation with enhanced performance in eukaryotic domain. More technical details are provided in the GENERator-v2 [technical report](https://www.biorxiv.org/content/10.64898/2026.01.27.702015v1).
|
||||
|
||||
Python scripts for downstream analysis are available on Github: [https://github.com/GenerTeam/GENERator](https://github.com/GenerTeam/GENERator).
|
||||
|
||||
|
||||
## How to use
|
||||
### Simple example1: generation
|
||||
|
||||
```python
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
# Load the tokenizer and model.
|
||||
tokenizer = AutoTokenizer.from_pretrained("GenerTeam/GENERator-v2-eukaryote-3b-base", trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained("GenerTeam/GENERator-v2-eukaryote-3b-base")
|
||||
config = model.config
|
||||
|
||||
max_length = config.max_position_embeddings
|
||||
|
||||
# Define input sequences.
|
||||
sequences = [
|
||||
"ATGAGGTGGCAAGAAATGGGCTAC",
|
||||
"GAATTCCATGAGGCTATAGAATAATCTAAGAGAAAT"
|
||||
]
|
||||
|
||||
def left_padding(sequence, padding_char='A', multiple=6):
|
||||
remainder = len(sequence) % multiple
|
||||
if remainder != 0:
|
||||
padding_length = multiple - remainder
|
||||
return padding_char * padding_length + sequence
|
||||
return sequence
|
||||
|
||||
def left_truncation(sequence, multiple=6):
|
||||
remainder = len(sequence) % multiple
|
||||
if remainder != 0:
|
||||
return sequence[remainder:]
|
||||
return sequence
|
||||
|
||||
# Apply left_padding to all sequences
|
||||
# padded_sequences = [left_padding(seq) for seq in sequences]
|
||||
|
||||
# Apply left_truncation to all sequences
|
||||
truncated_sequences = [left_truncation(seq) for seq in sequences]
|
||||
|
||||
# Process the sequences
|
||||
sequences = [tokenizer.bos_token + sequence for sequence in truncated_sequences]
|
||||
|
||||
# Tokenize the sequences
|
||||
tokenizer.padding_side = "left"
|
||||
inputs = tokenizer(
|
||||
sequences,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length
|
||||
)
|
||||
|
||||
# Generate the sequences
|
||||
with torch.inference_mode():
|
||||
outputs = model.generate(**inputs, max_new_tokens=32, temperature=0.00001, top_k=1)
|
||||
|
||||
# Decode the generated sequences
|
||||
decoded_sequences = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
# Print the decoded sequences
|
||||
print(decoded_sequences)
|
||||
|
||||
# It is expected to observe non-sense decoded sequences (e.g., 'AAAAAA')
|
||||
# The input sequences are too short to provide sufficient context.
|
||||
```
|
||||
|
||||
### Simple example2: embedding
|
||||
|
||||
```python
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
# Load the tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained("GENERator-v2-eukaryote-3b-base", trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained("GENERator-v2-eukaryote-3b-base")
|
||||
|
||||
# Get model configuration
|
||||
config = model.config
|
||||
max_length = config.max_position_embeddings
|
||||
|
||||
# Define input sequences
|
||||
sequences = [
|
||||
"ATGAGGTGGCAAGAAATGGGCTAC",
|
||||
"GAATTCCATGAGGCTATAGAATAATCTAAGAGAAAT"
|
||||
]
|
||||
|
||||
# Truncate each sequence to the nearest multiple of 6
|
||||
processed_sequences = [tokenizer.bos_token + seq[:len(seq)//6*6] for seq in sequences]
|
||||
|
||||
# Tokenization
|
||||
tokenizer.padding_side = "right"
|
||||
inputs = tokenizer(
|
||||
processed_sequences,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length
|
||||
)
|
||||
|
||||
# Model Inference
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs, output_hidden_states=True)
|
||||
|
||||
hidden_states = outputs.hidden_states[-1]
|
||||
attention_mask = inputs["attention_mask"]
|
||||
|
||||
# Option 1: Last token (EOS) embedding
|
||||
last_token_indices = attention_mask.sum(dim=1) - 1
|
||||
eos_embeddings = hidden_states[torch.arange(hidden_states.size(0)), last_token_indices, :]
|
||||
|
||||
# Option 2: Mean pooling over all tokens
|
||||
expanded_mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(torch.float32)
|
||||
sum_embeddings = torch.sum(hidden_states * expanded_mask, dim=1)
|
||||
mean_embeddings = sum_embeddings / expanded_mask.sum(dim=1)
|
||||
|
||||
# Output
|
||||
print("EOS (Last Token) Embeddings:", eos_embeddings)
|
||||
print("Mean Pooling Embeddings:", mean_embeddings)
|
||||
|
||||
# ============================================================================
|
||||
# Additional notes:
|
||||
# - The preprocessing step ensures sequences are multiples of 6 for 6-mer tokenizer
|
||||
# - For causal LM, the last token embedding (EOS) is commonly used
|
||||
# - Mean pooling considers all tokens including BOS and content tokens
|
||||
# - The choice depends on your downstream task requirements
|
||||
# - Both methods handle variable sequence lengths via attention mask
|
||||
# ============================================================================
|
||||
|
||||
```
|
||||
|
||||
## Citation
|
||||
```
|
||||
@article {li2026generator2,
|
||||
author = {Li, Qiuyi and Zhan, Zhihao and Feng, Shikun and Zhu, Yiheng and He, Yuan and Wu, Wei and Shi, Zhenghang and Wang, Shengjie and Hu, Zongyong and Yang, Zhao and Li, Jiaoyang and Tang, Jian and Liu, Haiguang and Qin, Tao},
|
||||
title = {GENERator-v2: Reconciling Coarse Tokenization with Single-Nucleotide Resolution in Genomic Language Modeling},
|
||||
elocation-id = {2026.01.27.702015},
|
||||
year = {2026},
|
||||
doi = {10.64898/2026.01.27.702015},
|
||||
publisher = {Cold Spring Harbor Laboratory},
|
||||
URL = {https://www.biorxiv.org/content/early/2026/05/04/2026.01.27.702015},
|
||||
journal = {bioRxiv}
|
||||
}
|
||||
|
||||
@article{wu2025generator,
|
||||
title={GENERator: a long-context generative genomic foundation model},
|
||||
author={Wu, Wei and Li, Qiuyi and Li, Mingyang and Fu, Kun and Feng, Fuli and Ye, Jieping and Xiong, Hui and Wang, Zheng},
|
||||
journal={arXiv preprint arXiv:2502.07272},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
```
|
||||
31
config.json
Normal file
31
config.json
Normal file
@@ -0,0 +1,31 @@
|
||||
{
|
||||
"architectures": [
|
||||
"GENERatorForCausalLM"
|
||||
],
|
||||
"attention_bias": false,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 8448,
|
||||
"max_position_embeddings": 16384,
|
||||
"mlp_bias": false,
|
||||
"model_type": "llama",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 30,
|
||||
"num_key_value_heads": 4,
|
||||
"pretraining_tp": 1,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_scaling": null,
|
||||
"rope_theta": 500000.0,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.44.0",
|
||||
"use_cache": true,
|
||||
"vocab_size": 4128,
|
||||
"auto_map": {
|
||||
"AutoModelForCausalLM": "modeling_generator.GENERatorForCausalLM"
|
||||
}
|
||||
}
|
||||
6
generation_config.json
Normal file
6
generation_config.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"_from_model_config": true,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"transformers_version": "4.44.0"
|
||||
}
|
||||
3
model-00001-of-00003.safetensors
Normal file
3
model-00001-of-00003.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9ee2725691b536b4b783971e1f4edd7c85c7b55ee6274941054c64ca979b6ebc
|
||||
size 4996117216
|
||||
3
model-00002-of-00003.safetensors
Normal file
3
model-00002-of-00003.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9bf8118e434b5877d14a777d26d588f76d0d4fc0c18f5afa4f3eac4fad7f292b
|
||||
size 4964291160
|
||||
3
model-00003-of-00003.safetensors
Normal file
3
model-00003-of-00003.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:654dee9a787a61aa762738556059d896841ee62d85abcdf9d99ed8db98536c4a
|
||||
size 2032674024
|
||||
280
model.safetensors.index.json
Normal file
280
model.safetensors.index.json
Normal file
@@ -0,0 +1,280 @@
|
||||
{
|
||||
"metadata": {
|
||||
"total_size": 11993051136
|
||||
},
|
||||
"weight_map": {
|
||||
"lm_head.weight": "model-00003-of-00003.safetensors",
|
||||
"model.embed_tokens.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.12.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.24.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.24.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.24.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.24.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||
"model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||
"model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||
"model.norm.weight": "model-00003-of-00003.safetensors"
|
||||
}
|
||||
}
|
||||
236
modeling_generator.py
Normal file
236
modeling_generator.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
GENERator with bp-level generation and scoring.
|
||||
|
||||
generate_bp() plugs into the standard HF generate() pipeline via a
|
||||
LogitsProcessor — no internal methods are overridden, so it is compatible
|
||||
with any transformers version.
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import LlamaForCausalLM, LogitsProcessor, LogitsProcessorList
|
||||
from typing import Union
|
||||
|
||||
BASE_TO_IDX = {"A": 0, "T": 1, "C": 2, "G": 3, "N": -1}
|
||||
IDX_TO_BASE = {0: "A", 1: "T", 2: "C", 3: "G", -1: "N"}
|
||||
|
||||
|
||||
class _BPLogitsProcessor(LogitsProcessor):
|
||||
"""Forces token selection to use per-base marginal probabilities.
|
||||
|
||||
Runs LAST in the logits-processor chain so that temperature / top-k /
|
||||
top-p etc. influence the marginal distributions before base selection.
|
||||
"""
|
||||
|
||||
def __init__(self, kmer_ids, bp_base_index, flat_idx_to_token_id, bp_powers, k, do_sample):
|
||||
self.kmer_ids = kmer_ids
|
||||
self.bp_base_index = bp_base_index
|
||||
self.flat_idx_to_token_id = flat_idx_to_token_id
|
||||
self.bp_powers = bp_powers
|
||||
self.k = k
|
||||
self.do_sample = do_sample
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
B = scores.shape[0]
|
||||
kmer_probs = F.softmax(scores[:, self.kmer_ids].float(), dim=-1) # [B, num_kmers]
|
||||
|
||||
# Marginalise to per-base probabilities [B, k, 4]
|
||||
bp_probs = torch.zeros(B, self.k, 4, device=scores.device, dtype=kmer_probs.dtype)
|
||||
for pos in range(self.k):
|
||||
idx = self.bp_base_index[pos] # [num_kmers] in {0,1,2,3}
|
||||
for nt in range(4):
|
||||
bp_probs[:, pos, nt] = kmer_probs[:, idx == nt].sum(dim=-1)
|
||||
|
||||
if self.do_sample:
|
||||
base_indices = torch.multinomial(bp_probs.view(-1, 4), 1).view(B, self.k)
|
||||
else:
|
||||
base_indices = bp_probs.argmax(dim=-1) # [B, k]
|
||||
|
||||
flat_idx = (base_indices * self.bp_powers).sum(dim=-1) # [B]
|
||||
selected = self.flat_idx_to_token_id[flat_idx] # [B]
|
||||
|
||||
# One-hot: both argmax and multinomial land on the bp-selected token
|
||||
new_scores = torch.full_like(scores, float("-inf"))
|
||||
new_scores.scatter_(1, selected.unsqueeze(1), 0.0)
|
||||
return new_scores
|
||||
|
||||
|
||||
class GENERatorForCausalLM(LlamaForCausalLM):
|
||||
"""LlamaForCausalLM with bp-level autoregressive generation.
|
||||
|
||||
Inherits all standard functionality (forward, generate, etc.)
|
||||
and adds generate_bp() for base-pair independent generation.
|
||||
|
||||
The tokenizer is automatically set up when loading the model with from_pretrained().
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
"""Load model and automatically setup tokenizer if available."""
|
||||
model = super().from_pretrained(*args, **kwargs)
|
||||
|
||||
model_path = args[0] if len(args) > 0 else kwargs.get('pretrained_model_name_or_path')
|
||||
|
||||
if model_path:
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model.setup_tokenizer(tokenizer)
|
||||
print(f"Tokenizer automatically loaded and configured for bp-level scoring")
|
||||
except Exception as e:
|
||||
print(f"Could not auto-load tokenizer: {e}")
|
||||
print(f" Call model.setup_tokenizer(tokenizer) manually if needed")
|
||||
|
||||
return model
|
||||
|
||||
def setup_tokenizer(self, tokenizer):
|
||||
"""Cache tokenizer and precompute lookup tables for bp-level operations."""
|
||||
self.tokenizer = tokenizer
|
||||
k = tokenizer.k
|
||||
self.k = k
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Build ordered kmer list from the tokenizer's DNA vocab
|
||||
kmer_items = sorted(
|
||||
[
|
||||
(kmer, tid)
|
||||
for kmer, tid in tokenizer.vocab.items()
|
||||
if len(kmer) == k and all(b in "ATCG" for b in kmer)
|
||||
],
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
kmers = [item[0] for item in kmer_items]
|
||||
kmer_ids = [item[1] for item in kmer_items]
|
||||
num_kmers = len(kmer_ids)
|
||||
|
||||
kmer_ids_tensor = torch.tensor(kmer_ids, dtype=torch.long, device=device)
|
||||
self.register_buffer("_kmer_ids", kmer_ids_tensor, persistent=False)
|
||||
|
||||
# bp_base_index[pos, j] = base index (0-3) of kmer j at position pos
|
||||
bp_base_index = torch.zeros(k, num_kmers, dtype=torch.long)
|
||||
for j, kmer in enumerate(kmers):
|
||||
for pos, base in enumerate(kmer):
|
||||
bp_base_index[pos, j] = BASE_TO_IDX[base]
|
||||
self.register_buffer("_bp_base_index", bp_base_index.to(device), persistent=False)
|
||||
|
||||
bp_powers = torch.tensor(
|
||||
[4 ** i for i in range(k - 1, -1, -1)], dtype=torch.long, device=device
|
||||
)
|
||||
self.register_buffer("_bp_powers", bp_powers, persistent=False)
|
||||
|
||||
# flat kmer index -> token id (flat index = sum base_idx[i] * 4^(k-1-i))
|
||||
flat_to_tid = torch.zeros(num_kmers, dtype=torch.long, device=device)
|
||||
for j, (kmer, tid) in enumerate(kmer_items):
|
||||
flat_idx = sum(BASE_TO_IDX[c] * (4 ** (k - 1 - i)) for i, c in enumerate(kmer))
|
||||
flat_to_tid[flat_idx] = tid
|
||||
self.register_buffer("_flat_idx_to_token_id", flat_to_tid, persistent=False)
|
||||
|
||||
def compute_bp_probs(self, logits):
|
||||
"""Compute per-base marginal probabilities from token logits.
|
||||
|
||||
Args:
|
||||
logits: [B, V] or [B, L, V]
|
||||
Returns:
|
||||
bp_probs: [B, k, 4] or [B, L, k, 4]
|
||||
"""
|
||||
squeeze = logits.dim() == 2
|
||||
if squeeze:
|
||||
logits = logits.unsqueeze(1)
|
||||
|
||||
kmer_logits = logits[:, :, self._kmer_ids]
|
||||
kmer_probs = F.softmax(kmer_logits.float(), dim=-1)
|
||||
B, L, _ = kmer_probs.shape
|
||||
bp_probs = torch.zeros(B, L, self.k, 4, device=logits.device, dtype=kmer_probs.dtype)
|
||||
for pos in range(self.k):
|
||||
idx = self._bp_base_index[pos]
|
||||
for nt in range(4):
|
||||
bp_probs[:, :, pos, nt] = kmer_probs[:, :, idx == nt].sum(dim=-1)
|
||||
|
||||
return bp_probs.squeeze(1) if squeeze else bp_probs
|
||||
|
||||
def generate(self, inputs=None, generation_config=None, **kwargs):
|
||||
"""Like generate(), but each token is selected base-by-base from marginal distributions.
|
||||
|
||||
Temperature, top_k, top_p, repetition_penalty etc. all apply as usual —
|
||||
they run before the bp processor and shift the marginal distributions.
|
||||
Output shape and type are identical to generate().
|
||||
"""
|
||||
assert hasattr(self, "_bp_base_index"), "Call setup_tokenizer(tokenizer) first"
|
||||
|
||||
gc = generation_config or self.generation_config
|
||||
do_sample = kwargs.get("do_sample", getattr(gc, "do_sample", False))
|
||||
|
||||
bp_proc = _BPLogitsProcessor(
|
||||
kmer_ids=self._kmer_ids,
|
||||
bp_base_index=self._bp_base_index,
|
||||
flat_idx_to_token_id=self._flat_idx_to_token_id,
|
||||
bp_powers=self._bp_powers,
|
||||
k=self.k,
|
||||
do_sample=do_sample,
|
||||
)
|
||||
existing = list(kwargs.pop("logits_processor", None) or [])
|
||||
kwargs["logits_processor"] = LogitsProcessorList(existing + [bp_proc])
|
||||
|
||||
return super().generate(inputs=inputs, generation_config=generation_config, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def score_sequence(self, sequences: Union[str, list]):
|
||||
"""Score DNA sequence(s) at base resolution.
|
||||
|
||||
Returns per-base probability distributions and the probability of the
|
||||
actual base at each position, given all preceding context.
|
||||
|
||||
Args:
|
||||
sequences: single DNA string or list of DNA strings (ACGT only)
|
||||
|
||||
Returns:
|
||||
(bp_probs, actual_probs) for a single sequence, or
|
||||
(list of bp_probs, list of actual_probs) for a batch.
|
||||
bp_probs[i]: [seq_len_i, 4] — P(base | context) at each position
|
||||
actual_probs[i]: [seq_len_i] — P(actual base | context)
|
||||
"""
|
||||
assert hasattr(self, "tokenizer"), "Call setup_tokenizer(tokenizer) first"
|
||||
|
||||
is_single = isinstance(sequences, str)
|
||||
if is_single:
|
||||
sequences = [sequences]
|
||||
|
||||
original_lens = [len(s) for s in sequences]
|
||||
|
||||
# Right-pad to multiple of k with 'A' (matches tokenizer convention)
|
||||
padded = []
|
||||
for s in sequences:
|
||||
r = len(s) % self.k
|
||||
padded.append(s + "A" * (self.k - r) if r else s)
|
||||
|
||||
# Prepend BOS manually (training format)
|
||||
tagged = ["<s>" + s for s in padded]
|
||||
|
||||
inputs = self.tokenizer(
|
||||
tagged, return_tensors="pt", padding=True, add_special_tokens=False
|
||||
)
|
||||
input_ids = inputs["input_ids"].to(self.device)
|
||||
attention_mask = inputs["attention_mask"].to(self.device)
|
||||
|
||||
logits = self(input_ids, attention_mask=attention_mask, return_dict=True).logits
|
||||
bp_probs_all = self.compute_bp_probs(logits) # [B, L, k, 4]
|
||||
|
||||
bp_results, actual_results = [], []
|
||||
for i, (seq, orig_len, pad_seq) in enumerate(zip(sequences, original_lens, padded)):
|
||||
num_tokens = len(pad_seq) // self.k
|
||||
# logits[t] predicts token t+1; logits[0] (from <s>) predicts token 1
|
||||
seq_bp = bp_probs_all[i, :num_tokens] # [num_tokens, k, 4]
|
||||
seq_bp = seq_bp.reshape(-1, 4)[:orig_len] # [orig_len, 4]
|
||||
actual = self._extract_actual_probs(seq_bp, seq)
|
||||
bp_results.append(seq_bp)
|
||||
actual_results.append(actual)
|
||||
|
||||
if is_single:
|
||||
return bp_results[0], actual_results[0]
|
||||
return bp_results, actual_results
|
||||
|
||||
def _extract_actual_probs(self, bp_probs: torch.Tensor, sequence: str) -> torch.Tensor:
|
||||
actual = torch.zeros(len(sequence), device=bp_probs.device, dtype=bp_probs.dtype)
|
||||
for i, base in enumerate(sequence):
|
||||
actual[i] = bp_probs[i].max() if base == "N" else bp_probs[i, BASE_TO_IDX[base]]
|
||||
return actual
|
||||
30
special_tokens_map.json
Normal file
30
special_tokens_map.json
Normal file
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<oov>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
163
tokenizer.py
Normal file
163
tokenizer.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import itertools
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Tuple
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
class DNAKmerTokenizer(PreTrainedTokenizer):
|
||||
def __init__(self, k, **kwargs):
|
||||
self.k = k
|
||||
self.special_tokens = [
|
||||
"<oov>",
|
||||
"<s>",
|
||||
"</s>",
|
||||
"<pad>",
|
||||
"<mask>",
|
||||
"<bog>",
|
||||
"<eog>",
|
||||
"<bok>",
|
||||
"<eok>",
|
||||
"<+>",
|
||||
"<->",
|
||||
"<cds>",
|
||||
"<pseudo>",
|
||||
"<tRNA>",
|
||||
"<rRNA>",
|
||||
"<ncRNA>",
|
||||
"<miscRNA>",
|
||||
"<mam>",
|
||||
"<vrt>",
|
||||
"<inv>",
|
||||
"<pln>",
|
||||
"<fng>",
|
||||
"<prt>",
|
||||
"<arc>",
|
||||
"<bct>",
|
||||
"<mit>",
|
||||
"<plt>",
|
||||
"<plm>",
|
||||
"<vir>",
|
||||
"<sp0>",
|
||||
"<sp1>",
|
||||
"<sp2>",
|
||||
]
|
||||
self.kmers = [
|
||||
"".join(kmer) for kmer in itertools.product("ATCG", repeat=self.k)
|
||||
]
|
||||
self.vocab = {
|
||||
token: i for i, token in enumerate(self.special_tokens + self.kmers)
|
||||
}
|
||||
self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
|
||||
self.special_token_pattern = re.compile(
|
||||
"|".join(re.escape(token) for token in self.special_tokens)
|
||||
)
|
||||
self.dna_pattern = re.compile(f"[A-Z]{{{self.k}}}|[A-Z]+")
|
||||
kwargs.setdefault("bos_token", "<s>")
|
||||
kwargs.setdefault("eos_token", "</s>")
|
||||
kwargs.setdefault("unk_token", "<oov>")
|
||||
kwargs.setdefault("pad_token", "<pad>")
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.vocab)
|
||||
|
||||
def get_vocab(self):
|
||||
return dict(self.vocab)
|
||||
|
||||
def _tokenize(self, text, **kwargs) -> List[str]:
|
||||
tokens = []
|
||||
pos = 0
|
||||
while pos < len(text):
|
||||
special_match = self.special_token_pattern.match(text, pos)
|
||||
if special_match:
|
||||
tokens.append(special_match.group())
|
||||
pos = special_match.end()
|
||||
else:
|
||||
dna_match = self.dna_pattern.match(text, pos)
|
||||
if dna_match:
|
||||
dna_seq = dna_match.group()
|
||||
tokens.append(dna_seq)
|
||||
pos = dna_match.end()
|
||||
else:
|
||||
tokens.append(text[pos])
|
||||
pos += 1
|
||||
return tokens
|
||||
|
||||
def _convert_token_to_id(self, token: str) -> int:
|
||||
return self.vocab.get(token, self.vocab["<oov>"])
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> str:
|
||||
return self.ids_to_tokens.get(index, "<oov>")
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
return "".join(tokens)
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
if token_ids_1 is None:
|
||||
return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
|
||||
return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
|
||||
):
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0, token_ids_1, already_has_special_tokens=True
|
||||
)
|
||||
if token_ids_1 is None:
|
||||
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||
|
||||
def prepare_for_model(self, *args, **kwargs):
|
||||
encoding = super().prepare_for_model(*args, **kwargs)
|
||||
if "token_type_ids" in encoding:
|
||||
del encoding["token_type_ids"]
|
||||
return encoding
|
||||
|
||||
def save_vocabulary(
|
||||
self, save_directory: str, filename_prefix: Optional[str] = None
|
||||
) -> Tuple[str]:
|
||||
import os
|
||||
|
||||
vocab_file = os.path.join(
|
||||
save_directory,
|
||||
(filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
|
||||
)
|
||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||
writer.write(token + "\n")
|
||||
return (vocab_file,)
|
||||
|
||||
def save_pretrained(self, save_directory: str, **kwargs):
|
||||
vocab_files = super().save_pretrained(save_directory, **kwargs)
|
||||
tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")
|
||||
|
||||
# 读取现有的配置或创建新的
|
||||
if os.path.exists(tokenizer_config_path):
|
||||
with open(tokenizer_config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
config = {}
|
||||
|
||||
# 添加auto_map配置
|
||||
config.update({
|
||||
"auto_map": {
|
||||
"AutoTokenizer": [
|
||||
"tokenizer.DNAKmerTokenizer",
|
||||
None
|
||||
]
|
||||
},
|
||||
})
|
||||
|
||||
# 添加kmer配置
|
||||
config.update({
|
||||
"k": self.k
|
||||
})
|
||||
|
||||
# 保存配置
|
||||
with open(tokenizer_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return vocab_files
|
||||
60
tokenizer_config.json
Normal file
60
tokenizer_config.json
Normal file
@@ -0,0 +1,60 @@
|
||||
{
|
||||
"add_bos_token": true,
|
||||
"add_eos_token": false,
|
||||
"add_prefix_space": true,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<oov>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"3": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"auto_map": {
|
||||
"AutoTokenizer": [
|
||||
"tokenizer.DNAKmerTokenizer",
|
||||
null
|
||||
]
|
||||
},
|
||||
"bos_token": "<s>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"eos_token": "</s>",
|
||||
"extra_special_tokens": {},
|
||||
"kmer": 6,
|
||||
"legacy": true,
|
||||
"model_max_length": 1000000000000000019884624838656,
|
||||
"pad_token": "<pad>",
|
||||
"sp_model_kwargs": {},
|
||||
"spaces_between_special_tokens": false,
|
||||
"tokenizer_class": "DNAKmerTokenizer",
|
||||
"unk_token": "<oov>",
|
||||
"use_default_system_prompt": false,
|
||||
"use_fast": false,
|
||||
"k": 6
|
||||
}
|
||||
Reference in New Issue
Block a user