Increase the number of thread limitation for tp worker managers. (#567)

This commit is contained in:
Lianmin Zheng
2024-06-26 09:33:45 -07:00
committed by GitHub
parent a385ee27bd
commit 2e6e62e156
9 changed files with 148 additions and 84 deletions

View File

@@ -88,6 +88,9 @@ def get_tokenizer(
if tokenizer_name.endswith(".json"):
return TiktokenTokenizer(tokenizer_name)
if tokenizer_name.endswith(".model"):
return SentencePieceTokenizer(tokenizer_name)
"""Gets a tokenizer for the given model name via Huggingface."""
if is_multimodal_model(tokenizer_name):
processor = get_processor(
@@ -179,6 +182,7 @@ def get_processor(
class TiktokenTokenizer:
def __init__(self, tokenizer_path):
import tiktoken
from jinja2 import Template
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
@@ -216,6 +220,7 @@ class TiktokenTokenizer:
tokenizer = tiktoken.Encoding(**kwargs)
tokenizer._default_allowed_special = default_allowed_special or set()
tokenizer._default_allowed_special |= {"<|separator|>"}
def encode_patched(
self,
@@ -241,6 +246,9 @@ class TiktokenTokenizer:
self.tokenizer = tokenizer
self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
self.vocab_size = tokenizer.n_vocab
self.chat_template = Template(
"{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
)
def encode(self, x, add_special_tokens=False):
return self.tokenizer.encode(x)
@@ -255,7 +263,39 @@ class TiktokenTokenizer:
batch = [[x] for x in batch]
return self.tokenizer.decode_batch(batch)
def convert_ids_to_tokens(self, index):
return self.tokenizer.decode_single_token_bytes(index).decode(
"utf-8", errors="ignore"
def apply_chat_template(self, messages, tokenize, add_generation_prompt):
ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
return self.encode(ret) if tokenize else ret
class SentencePieceTokenizer:
def __init__(self, tokenizer_path):
import sentencepiece as spm
from jinja2 import Template
tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
# Convert to HF interface
self.tokenizer = tokenizer
self.eos_token_id = tokenizer.eos_id()
self.vocab_size = tokenizer.vocab_size()
self.chat_template = Template(
"{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
)
def encode(self, x, add_special_tokens=False):
return self.tokenizer.encode(x)
def decode(self, x):
return self.tokenizer.decode(x)
def batch_decode(
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
):
if isinstance(batch[0], int):
batch = [[x] for x in batch]
return self.tokenizer.decode(batch)
def apply_chat_template(self, messages, tokenize, add_generation_prompt):
ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
return self.encode(ret) if tokenize else ret