Support qwen model and solve some problems (#75)

This commit is contained in:
Arcmoon
2024-01-23 12:14:51 +08:00
committed by GitHub
parent e08bca2840
commit 63e97e5e4c
7 changed files with 274 additions and 4 deletions

View File

@@ -108,9 +108,11 @@ def get_exception_traceback():
def get_int_token_logit_bias(tokenizer, vocab_size):
from transformers import LlamaTokenizer, LlamaTokenizerFast
# a bug when model's vocab size > tokenizer.vocab_size
vocab_size = tokenizer.vocab_size
logit_bias = np.zeros(vocab_size, dtype=np.float32)
for t_id in range(vocab_size):
ss = tokenizer.decode(t_id).strip()
ss = tokenizer.decode([t_id]).strip()
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
logit_bias[t_id] = -1e5
# else:
@@ -214,4 +216,4 @@ def load_image(image_file):
else:
image = Image.open(BytesIO(base64.b64decode(image_file)))
return image
return image