Add skip_tokenizer_init args. (#959)
Co-authored-by: lzhang <zhanglei@modelbest.cn>
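
In effect, this adds a token-in/token-out mode: launch the server with --skip-tokenizer-init, send pre-tokenized input_ids instead of text, and get generated token ids back. A rough usage sketch follows; the port, model path, and single-request response shape are assumptions inferred from the warmup request and the BatchTokenIDOut handling in this diff, not part of the commit:

# Hypothetical client for a server started with something like:
#   python -m sglang.launch_server --model-path <model> --skip-tokenizer-init
import requests

json_data = {
    # Pre-tokenized prompt; no "text" field, mirroring the warmup request below.
    "input_ids": [10, 11, 12],
    "sampling_params": {"temperature": 0, "max_new_tokens": 8},
}
res = requests.post("http://127.0.0.1:30000/generate", json=json_data, timeout=600)
out = res.json()
# With the tokenizer skipped, the response carries token ids plus metadata
# (keys taken from the out_dict built for BatchTokenIDOut below).
print(out["token_ids"], out["meta_info"])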
@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
 class FSMCache(BaseToolCache):
-    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        enable=True,
+        skip_tokenizer_init=False,
+    ):
         super().__init__(enable=enable)
 
-        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+        if (
+            skip_tokenizer_init
+            or tokenizer_path.endswith(".json")
+            or tokenizer_path.endswith(".model")
+        ):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return

@@ -59,11 +59,14 @@ class DetokenizerManager:
         self.send_to_tokenizer = context.socket(zmq.PUSH)
         self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
 
-        self.tokenizer = get_tokenizer(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = None
+        else:
+            self.tokenizer = get_tokenizer(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+            )
 
         self.decode_status = {}

@@ -85,6 +88,11 @@ class DetokenizerManager:
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
 
+            if self.tokenizer is None:
+                # Send BatchTokenIDOut if no tokenizer init'ed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):

@@ -195,6 +195,8 @@ class Req:
         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
     def get_next_inc_detokenization(self):
+        if self.tokenizer is None:
+            return False, ""
         read_ids, read_offset = self.init_incremental_detokenize()
         surr_ids = read_ids[:read_offset]

@@ -225,16 +227,11 @@ class Req:
             return
 
         last_token_id = self.output_ids[-1]
-        if (
-            last_token_id == self.tokenizer.eos_token_id
-            and not self.sampling_params.ignore_eos
-        ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(
-                matched=self.tokenizer.eos_token_id
-            )
-            return
-
-        if last_token_id in self.sampling_params.stop_token_ids:
+        if self.tokenizer is None:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        else:
+            matched_eos = last_token_id == self.tokenizer.eos_token_id
+        if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return

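The rewritten check above folds the EOS and stop-token paths into one matched_eos flag. Isolated into a standalone sketch (simplified names, not the Req class itself):

# Standalone sketch of the new finish condition: without a tokenizer, any
# configured stop token counts as a match; with one, only eos_token_id does.
def is_finished(last_token_id, tokenizer, stop_token_ids, ignore_eos):
    if tokenizer is None:
        matched_eos = last_token_id in stop_token_ids
    else:
        matched_eos = last_token_id == tokenizer.eos_token_id
    return matched_eos and not ignore_eos

assert is_finished(2, None, {2, 32000}, ignore_eos=False)
assert not is_finished(2, None, {2}, ignore_eos=True)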
@@ -95,25 +95,28 @@ class TokenizerManager:
         else:
             self.context_len = get_context_length(self.hf_config)
 
-        if is_multimodal_model(self.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"
-            self.executor = concurrent.futures.ProcessPoolExecutor(
-                initializer=init_global_processor,
-                mp_context=mp.get_context("fork"),
-                initargs=(server_args,),
-            )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(self.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+                os.environ["TOKENIZERS_PARALLELISM"] = "false"
+                self.executor = concurrent.futures.ProcessPoolExecutor(
+                    initializer=init_global_processor,
+                    mp_context=mp.get_context("fork"),
+                    initargs=(server_args,),
+                )
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
 
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

@@ -171,6 +174,7 @@ class TokenizerManager:
             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
             if obj.input_ids is None:
+                assert self.tokenizer is not None
                 input_ids = self.tokenizer.encode(input_text)
             else:
                 input_ids = obj.input_ids if not_use_index else obj.input_ids[index]

@@ -207,7 +211,20 @@ class TokenizerManager:
             else:
                 input_text = obj.text
                 rid = obj.rid[0]
-                input_ids = self.tokenizer.encode(input_text)
+                if self.tokenizer is not None:
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    assert obj.input_ids is not None
+                    input_ids = obj.input_ids
+                    if isinstance(obj.input_ids, list) and isinstance(
+                        obj.input_ids[0], list
+                    ):
+                        # when obj["input_ids"] is List[List[int]]
+                        input_ids = obj.input_ids[index]
+                        rid = obj.rid[index]
+                    else:
+                        input_ids = obj.input_ids
+                        rid = obj.rid[0]
         else:
             input_text = None
             if isinstance(obj.input_ids, list) and isinstance(

@@ -420,7 +437,7 @@ class TokenizerManager:
         # Log requests
         if self.server_args.log_requests and state.finished:
             if obj.text is None:
-                in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
+                in_obj = {"input_ids": obj.input_ids}
             else:
                 in_obj = {"text": obj.text}
             logger.info(f"in={in_obj}, out={out}")

@@ -488,11 +505,12 @@ class TokenizerManager:
 
     async def handle_loop(self):
         while True:
-            recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
+            recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
                 await self.recv_from_detokenizer.recv_pyobj()
             )
-            assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))
+            assert isinstance(
+                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
+            ), f"Unexpected obj received: {type(recv_obj)}"
 
             for i, rid in enumerate(recv_obj.rids):
                 state = self.rid_to_state.get(rid, None)
                 if state is None:

@@ -504,6 +522,15 @@ class TokenizerManager:
                     "text": recv_obj.output_strs[i],
                     "meta_info": recv_obj.meta_info[i],
                 }
+            elif isinstance(recv_obj, BatchTokenIDOut):
+                read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
+                out_dict = {
+                    "token_ids": recv_obj.decode_ids[
+                        read_start : recv_obj.read_offsets[i]
+                    ],
+                    "meta_info": recv_obj.meta_info[i],
+                }
+
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {

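Because BatchTokenIDOut now reaches the client as raw token ids, detokenization becomes the caller's responsibility. A hedged sketch using Hugging Face transformers; the model name is a placeholder and the out dict just mirrors the keys built above:

# Hypothetical client-side detokenization for a --skip-tokenizer-init server.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder
out = {"token_ids": [450, 7483, 4272], "meta_info": {}}  # shape as in out_dict above
print(tokenizer.decode(out["token_ids"], skip_special_tokens=True))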
@@ -549,6 +576,7 @@ class TokenizerManager:
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
 
+        assert self.tokenizer is not None
         token_ids = [tid for _, tid in token_logprobs]
         token_texts = self.tokenizer.batch_decode(token_ids)
         return [

@@ -100,20 +100,22 @@ class ModelTpServer:
             nccl_port=nccl_port,
             server_args=server_args,
         )
 
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(server_args.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
             16384

@@ -182,13 +184,15 @@ class ModelTpServer:
         self.last_stats_tic = time.time()
 
         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+            )
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation

@@ -466,7 +470,11 @@ class ModelTpServer:
 
             next_token_ids = next_token_ids.tolist()
         else:
-            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+            if self.tokenizer is None:
+                for i, req in enumerate(batch.reqs):
+                    next_token_ids.extend(req.sampling_params.stop_token_ids)
+            else:
+                next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
         # Check finish conditions
         pt = 0

@@ -111,13 +111,19 @@ class SamplingParams:
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            self.stop_str_max_len = 0
+            if self.stop_token_ids is None:
+                self.stop_str_max_len = 0
+            else:
+                self.stop_str_max_len = 1
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]
 
             stop_str_max_len = 0
             for stop_str in self.stop_strs:
-                stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
-                stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                if tokenizer is not None:
+                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                else:
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len

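In the tokenizer-less branch above, stop-string length falls back to character count; since a token spans at least one character, this over-estimates (never under-estimates) the token length. A tiny illustration:

# Character-count fallback for stop_str_max_len when tokenizer is None.
stop_strs = ["</s>", "\n\nUser:"]
stop_str_max_len = 0
for stop_str in stop_strs:
    stop_str_max_len = max(stop_str_max_len, len(stop_str))
print(stop_str_max_len)  # 7: "\n\nUser:" is seven characters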
@@ -420,17 +420,22 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        json_data["text"] = "The capital city of France is"
 
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
                 url + request_name,
-                json={
-                    "text": "The capital city of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": max_new_tokens,
-                    },
-                },
+                json=json_data,
                 headers=headers,
                 timeout=600,
             )

@@ -27,6 +27,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
     trust_remote_code: bool = True

@@ -151,6 +152,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,

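The flag wiring in isolation, as a minimal argparse round-trip (a standalone sketch, not the sglang entry point):

# Standalone check that the store_true flag defaults to False and flips on.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--skip-tokenizer-init",
    action="store_true",
    help="If set, skip init tokenizer and pass input_ids in generate request",
)
assert parser.parse_args([]).skip_tokenizer_init is False
assert parser.parse_args(["--skip-tokenizer-init"]).skip_tokenizer_init is True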
@@ -197,6 +197,8 @@ def allocate_init_ports(
 def get_int_token_logit_bias(tokenizer, vocab_size):
     """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
+    if tokenizer == None:
+        return [-1e5] * vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):