Support qwen model and solve some problems (#75)
This commit is contained in:
@@ -108,9 +108,11 @@ def get_exception_traceback():
|
||||
def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||
from transformers import LlamaTokenizer, LlamaTokenizerFast
|
||||
|
||||
# a bug when model's vocab size > tokenizer.vocab_size
|
||||
vocab_size = tokenizer.vocab_size
|
||||
logit_bias = np.zeros(vocab_size, dtype=np.float32)
|
||||
for t_id in range(vocab_size):
|
||||
ss = tokenizer.decode(t_id).strip()
|
||||
ss = tokenizer.decode([t_id]).strip()
|
||||
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
|
||||
logit_bias[t_id] = -1e5
|
||||
# else:
|
||||
@@ -214,4 +216,4 @@ def load_image(image_file):
|
||||
else:
|
||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||
|
||||
return image
|
||||
return image
|
||||
Reference in New Issue
Block a user