13 lines
311 B
Python
13 lines
311 B
Python
|
|
"""Convert a Flax (JAX) GPT-2 checkpoint and its tokenizer to PyTorch.

Run this script from inside the model directory. It loads the Flax weights
found there as a PyTorch ``GPT2LMHeadModel`` and saves that model back into
the same directory. It also loads the tokenizer from that directory and
re-saves it there.
"""

from transformers import AutoTokenizer, GPT2LMHeadModel


def convert_flax_to_pytorch(model_dir: str = ".") -> None:
    """Convert the Flax GPT-2 checkpoint in *model_dir* to PyTorch in place.

    Args:
        model_dir: Directory holding the Flax checkpoint and tokenizer files.
            The converted model and the re-saved tokenizer are written back
            to this same directory. Defaults to the current directory.
    """
    # from_flax=True makes transformers read the Flax weights and load them
    # into the PyTorch GPT2LMHeadModel.
    model = GPT2LMHeadModel.from_pretrained(model_dir, from_flax=True)
    model.save_pretrained(model_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)


if __name__ == "__main__":
    # Running the file as a script keeps the original behaviour: convert the
    # checkpoint in the current working directory.
    convert_flax_to_pytorch()
|
||
|
|
|