# Source: sglang/python/sglang/srt/managers/tokenizer_manager.py
# (786 lines, 29 KiB, Python)
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""TokenizerManager is a process that tokenizes the text."""
import asyncio
import dataclasses
import json
import logging
import os
import signal
import sys
from typing import Dict, List, Optional, Tuple, Union
import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_processor,
get_tokenizer,
)
from sglang.srt.managers.image_processor import (
get_dummy_image_processor,
get_image_processor,
)
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
ProfileReq,
RewardReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
TokenizedRewardReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_zmq_socket,
is_generation_model,
is_multimodal_model,
kill_child_process,
)
# Install uvloop's event loop policy at import time so that asyncio loops
# created afterwards in this process use uvloop's implementation.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Module-level logger used by TokenizerManager and SignalHandler.
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    # Outputs delivered by the receive loop that the waiting coroutine has not
    # consumed yet (the waiter resets this list after reading it).
    out_list: List
    # True once the final output for this request has been received.
    finished: bool
    # Set whenever a new output is appended to ``out_list``.
    event: asyncio.Event
class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text.

    It tokenizes incoming requests, forwards them to the scheduler over ZMQ,
    and routes the outputs coming back from the detokenizer to the coroutines
    waiting on them.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        self.server_args = server_args

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_override_args=json.loads(server_args.json_model_override_args),
        )
        self.is_generation = is_generation_model(
            self.hf_config.architectures, self.server_args.is_embedding
        )
        self.context_len = server_args.context_length or get_context_length(
            self.hf_config
        )

        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if is_multimodal_model(self.hf_config.architectures):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.hf_config, server_args, self.processor
                )
            else:
                # Text-only models have no processor; define the attribute anyway
                # so it exists on every code path.
                self.processor = None
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For update model weights
        self.model_update_lock = asyncio.Lock()
        self.model_update_result: Optional[asyncio.Future] = None
        # Per-rank replies collected by handle_loop when dp_size != 1.
        self.model_update_tmp: List = []

        # For querying the memory pool size
        self.mem_pool_size: Optional[asyncio.Future] = None
        # Per-rank replies collected by handle_loop when dp_size != 1.
        self.mem_pool_size_tmp: List = []

        # Others
        self.gracefully_exit = False

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Tokenize ``obj``, send it to the scheduler and yield its outputs.

        Starts the receive loop on first use. While a weight update holds
        ``model_update_lock``, new requests wait.

        Args:
            obj: A single or batched generate / embedding / reward request.
            request: The originating HTTP request. If given, it is used to
                detect client disconnects.

        Raises:
            ValueError: If an embedding request is sent to a generation model,
                or if a prompt exceeds the context length.
        """
        if self.to_create_loop:
            self.create_handle_loop()

        # Hold new requests while the model weights are being swapped.
        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
            )

        obj.post_init()
        if obj.is_single:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            async for response in self._handle_batch_request(obj, request):
                yield response

    async def _send_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        index: Optional[int] = None,
        input_id_index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
        """Tokenize one (sub-)request of ``obj`` and send it to the scheduler.

        Args:
            obj: The request. For batched requests, its per-request fields are
                lists indexed by ``index`` / ``input_id_index``.
            index: Position of this sub-request in the per-request lists
                (``rid``, ``sampling_params``, ...). ``None`` for a single
                request.
            input_id_index: Position of the prompt in ``text`` / ``input_ids``.
                Under parallel sampling it differs from ``index``, because
                several sub-requests share one prompt.
            is_cache_for_prefill: If True, send a prefill-only request
                (``max_new_tokens = 0``) for the shared prompt that parallel
                sampling reuses.

        Returns:
            ``(rid, input_ids)`` of the request that was sent.

        Raises:
            ValueError: If the prompt does not fit into the context length.
        """
        if not is_cache_for_prefill:  # The normal case with a single prompt
            if index is None:
                rid = obj.rid
                if hasattr(obj, "conv"):
                    # reward model
                    conv = obj.conv
                    input_text = self.tokenizer.apply_chat_template(
                        conv, tokenize=False
                    )
                    input_ids = self.tokenizer.encode(input_text)
                elif obj.input_ids is None:
                    input_text = obj.text
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    input_text = obj.text
                    input_ids = obj.input_ids

                sampling_params = self._get_sampling_params(obj.sampling_params)
                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data, input_text or input_ids, obj
                    )
                    if image_inputs and "input_ids" in image_inputs:
                        input_ids = image_inputs["input_ids"]
                    return_logprob = obj.return_logprob
                    logprob_start_len = obj.logprob_start_len
                    top_logprobs_num = obj.top_logprobs_num
            else:
                rid = obj.rid[index]
                if hasattr(obj, "conv"):
                    # reward model
                    conv = obj.conv[index]
                    input_text = self.tokenizer.apply_chat_template(
                        conv, tokenize=False
                    )
                    input_ids = self.tokenizer.encode(input_text)
                elif obj.input_ids is None:
                    input_text = obj.text[input_id_index]
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    input_text = (
                        obj.text[input_id_index] if obj.text is not None else None
                    )
                    input_ids = obj.input_ids[input_id_index]

                sampling_params = self._get_sampling_params(obj.sampling_params[index])
                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data[index], input_text or input_ids, obj
                    )
                    if image_inputs and "input_ids" in image_inputs:
                        input_ids = image_inputs["input_ids"]
                    return_logprob = obj.return_logprob[index]
                    logprob_start_len = obj.logprob_start_len[index]
                    top_logprobs_num = obj.top_logprobs_num[index]
        else:  # A prefill request to cache the common prompt for parallel sampling
            assert self.is_generation
            if obj.text is not None:
                if isinstance(obj.text, list):
                    input_text = obj.text[input_id_index]
                    rid = obj.rid[index]
                else:
                    input_text = obj.text
                    rid = obj.rid[0]
                if self.tokenizer is not None:
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    assert obj.input_ids is not None
                    input_ids = obj.input_ids
                    if isinstance(obj.input_ids, list) and isinstance(
                        obj.input_ids[0], list
                    ):
                        # when obj["input_ids"] is List[List[int]]
                        input_ids = obj.input_ids[input_id_index]
                        rid = obj.rid[index]
                    else:
                        input_ids = obj.input_ids
                        rid = obj.rid[0]
            else:
                input_text = None
                if isinstance(obj.input_ids, list) and isinstance(
                    obj.input_ids[0], list
                ):
                    # when obj["input_ids"] is List[List[int]]
                    input_ids = obj.input_ids[input_id_index]
                    rid = obj.rid[index]
                else:
                    input_ids = obj.input_ids
                    rid = obj.rid[0]

            # Prefill only: no tokens are generated for the cache request.
            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            image_inputs = await self.image_processor.process_images_async(
                obj.image_data[0], input_text or input_ids, obj
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        # Reject prompts that do not fit the context window before anything is
        # sent to the scheduler. This applies to normal and prefill requests.
        self._validate_input_length(input_ids)

        # Send to the controller
        if self.is_generation:
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                (
                    obj.lora_path[input_id_index]
                    if isinstance(obj.lora_path, list)
                    else obj.lora_path
                ),
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        else:
            assert isinstance(obj, RewardReqInput)
            tokenized_obj = TokenizedRewardReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        self.send_to_scheduler.send_pyobj(tokenized_obj)
        return rid, input_ids

    async def _handle_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        input_id_index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
        """Send one (sub-)request and yield its responses.

        For a prefill-only cache request, this waits for completion and then
        yields the prompt's ``input_ids`` exactly once.
        """
        rid, input_ids = await self._send_single_request(
            obj,
            index,
            input_id_index=input_id_index,
            is_cache_for_prefill=is_cache_for_prefill,
        )

        # Recv results. No await happens between sending and registering the
        # state, so handle_loop cannot process a reply before it exists.
        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state

        if not is_cache_for_prefill:
            async for response in self._wait_for_response(state, obj, rid, request):
                yield response
        else:
            assert self.is_generation
            await self._wait_for_cache_prefill_response(state, obj, rid, request)
            yield input_ids

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Send every sub-request of a batched ``obj`` and yield the responses.

        With parallel sampling (generation models with
        ``parallel_sample_num != 1``), one prefill-only request per prompt is
        sent first to cache the shared prefix. If ``obj.stream`` is set, results
        are yielded as they arrive. Otherwise a single list is yielded, ordered
        by submission.
        """
        batch_size = obj.batch_size
        if self.is_generation:
            parallel_sample_num = obj.parallel_sample_num

            if parallel_sample_num != 1:
                # Send prefill requests to cache the common prefix
                parallel_sample_num += 1
                input_id_result = [] if obj.input_ids is None else None
                for i in range(batch_size):
                    async for input_id in self._handle_single_request(
                        obj,
                        request,
                        index=i,
                        input_id_index=i,
                        is_cache_for_prefill=True,
                    ):
                        if input_id_result is not None:
                            input_id_result.append(input_id)
                if input_id_result is not None:
                    # Reuse the tokenized prompts for the sampling requests.
                    obj.input_ids = input_id_result
        else:
            parallel_sample_num = 1

        # First send out all requests
        generators = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    # Slot 0 corresponds to the prefill request sent above.
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                    index += batch_size - 1 - i

                rid, _ = await self._send_single_request(
                    obj, index, input_id_index=i, is_cache_for_prefill=False
                )

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

                generators.append(
                    self._wait_for_response(
                        state,
                        obj,
                        rid,
                        request,
                        index=index,
                        response_index=len(generators),
                    )
                )

        # Then process the responses based on streaming option
        is_stream = hasattr(obj, "stream") and obj.stream

        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = [None] * len(tasks)

        # Fetch results
        while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            for task in done:
                cur_index = tasks.index(task)

                try:
                    result = task.result()

                    if is_stream:
                        yield result
                    else:
                        output_list[result["index"]] = result

                    # Schedule the next output of the same generator.
                    tasks[cur_index] = asyncio.create_task(
                        generators[cur_index].__anext__()
                    )
                except StopAsyncIteration:
                    del generators[cur_index]
                    del tasks[cur_index]

        if not is_stream:
            yield output_list

    def _validate_input_length(self, input_ids: List[int]):
        """Raise ValueError unless the prompt is shorter than the context length."""
        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

    def _get_sampling_params(self, sampling_params_data: dict):
        """Build SamplingParams. Normalize and verify them unless max_new_tokens == 0."""
        sampling_params = SamplingParams(**sampling_params_data)
        if sampling_params.max_new_tokens != 0:
            sampling_params.normalize(self.tokenizer)
            sampling_params.verify()
        return sampling_params

    async def _wait_for_response(
        self,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        rid: str,
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        response_index: int = 0,
    ):
        """Yield the outputs of request ``rid`` as handle_loop delivers them.

        Wakes every 4 seconds to check for a client disconnect. On disconnect,
        all sub-requests of ``obj`` are aborted and ValueError is raised. The
        request's state is removed once its final output has been received.

        Args:
            index: Position of this sub-request in ``obj``'s per-request lists;
                ``None`` for a single request.
            response_index: Value stored under ``"index"`` in every yielded
                output, used by the batch handler to order results.
        """
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    # Use a separate name so ``rid`` keeps identifying this request.
                    abort_rids = [obj.rid] if obj.is_single else obj.rid
                    for abort_rid in abort_rids:
                        self.abort_request(abort_rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob if index is None else obj.return_logprob[index],
                    (
                        obj.top_logprobs_num
                        if index is None
                        else obj.top_logprobs_num[index]
                    ),
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
                out = state.out_list[-1]

            out["index"] = response_index

            # Log requests
            if self.server_args.log_requests and state.finished:
                logger.info(f"in={obj}, out={out}")

            state.out_list = []
            if state.finished:
                # The state may already be gone if an abort raced with completion.
                self.rid_to_state.pop(rid, None)
                yield out
                break

            state.event.clear()
            yield out

    async def _wait_for_cache_prefill_response(
        self,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request: Optional[fastapi.Request] = None,
    ):
        """Wait until the prefill-only request ``rid`` finishes, then drop its state.

        Raises:
            ValueError: If the client disconnects while waiting. All
                sub-requests of ``obj`` are aborted first.
        """
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    # Use a separate name so ``rid`` keeps identifying this request.
                    for abort_rid in obj.rid:
                        self.abort_request(abort_rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        assert state.finished
        # The state may already be gone if an abort raced with completion.
        self.rid_to_state.pop(rid, None)

    def flush_cache(self):
        """Ask the scheduler to flush its cache."""
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        """Forget request ``rid`` and tell the scheduler to abort it (no-op if unknown)."""
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        """Send a START_PROFILE request to the scheduler."""
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        """Send a STOP_PROFILE request to the scheduler."""
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def get_memory_pool_size(self):
        """Query the memory pool size.

        Returns:
            The ``size`` reported by the scheduler when ``dp_size == 1``.
            Otherwise, a list with one ``size`` per data-parallel reply.
        """
        if self.to_create_loop:
            self.create_handle_loop()

        req = GetMemPoolSizeReq()
        # Create the future (and reset the per-rank buffer) before sending, so
        # handle_loop always has somewhere to deliver the reply.
        self.mem_pool_size = asyncio.get_running_loop().create_future()
        if self.server_args.dp_size != 1:
            self.mem_pool_size_tmp = []
        self.send_to_scheduler.send_pyobj(req)
        res = await self.mem_pool_size

        if self.server_args.dp_size == 1:
            return res.size
        return [r.size for r in res]

    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
        """Load new model weights once all in-flight requests have finished.

        Returns:
            ``(success, message)``. With ``dp_size != 1``, ``success`` is True
            only if every reply succeeded, and ``message`` joins the per-reply
            messages with ``" | "``. If another update is already running,
            returns ``(False, ...)`` immediately.
        """
        if self.to_create_loop:
            self.create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if self.model_update_lock.locked():
            return False, "Another update is in progress. Please try again later."

        async with self.model_update_lock:
            # wait for the previous generation requests to finish
            while self.rid_to_state:
                await asyncio.sleep(0.001)

            # Create the future (and reset the per-rank buffer) before sending.
            self.model_update_result = asyncio.get_running_loop().create_future()
            if self.server_args.dp_size != 1:
                # There will be dp_size number of responses from the detokenizer
                self.model_update_tmp = []
            self.send_to_scheduler.send_pyobj(obj)
            result = await self.model_update_result

            if self.server_args.dp_size == 1:
                success, message = result.success, result.message
            else:
                success = all(r.success for r in result)
                message = " | ".join(r.message for r in result)

            if success:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path

        return success, message

    def create_abort_task(self, obj: GenerateReqInput):
        """Return BackgroundTasks that abort every sub-request of ``obj`` after 1 s.

        Abort the request if the client is disconnected.
        """

        async def _abort_after_delay():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(_abort_after_delay)
        return background_tasks

    def create_handle_loop(self):
        """Start the receive loop, SIGTERM handler and watchdog (only once)."""
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        # Keep references to the background tasks: the event loop only holds
        # weak references, so unreferenced tasks may be garbage-collected.
        self._handle_loop_task = loop.create_task(self.handle_loop())

        signal_handler = SignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)

        self._sigterm_watchdog_task = loop.create_task(self.sigterm_watchdog())

    async def sigterm_watchdog(self):
        """After SIGTERM, wait for all requests to drain, then kill the process tree."""
        while not self.gracefully_exit:
            await asyncio.sleep(60)

        # drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_child_process(include_self=True)
        sys.exit(-1)

    async def handle_loop(self):
        """Receive outputs from the detokenizer and route them to their waiters."""

        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                if self.server_args.dp_size == 1:
                    self.model_update_result.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp.append(recv_obj)
                    # set future if all results are received
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
                continue
            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                if self.server_args.dp_size == 1:
                    self.mem_pool_size.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.mem_pool_size_tmp.append(recv_obj)
                    # set future if all results are received
                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    # Unknown or already-aborted request: drop the output.
                    continue

                meta_info = recv_obj.meta_info[i]
                meta_info["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": meta_info,
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    out_dict = {
                        "token_ids": recv_obj.output_ids[i],
                        "meta_info": meta_info,
                    }
                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": meta_info,
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        """Rewrite the logprob entries of ``ret["meta_info"]`` in place and return ``ret``.

        Each ``(logprob, token_id)`` entry becomes ``(logprob, token_id, text)``.
        ``text`` is ``None`` unless ``return_text_in_logprobs`` is set.
        """
        if return_logprob:
            meta_info = ret["meta_info"]
            meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
                meta_info["input_token_logprobs"], return_text_in_logprobs
            )
            meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
                meta_info["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                    meta_info["input_top_logprobs"],
                    return_text_in_logprobs,
                )
                meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                    meta_info["output_top_logprobs"], return_text_in_logprobs
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        """Map ``(logprob, token_id)`` pairs to ``(logprob, token_id, text_or_None)``."""
        # TODO(lianmin): This should run on DetokenizerManager
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        """Apply detokenize_logprob_tokens to each non-empty position, in place."""
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs
class SignalHandler:
    """Translate SIGTERM into a graceful-shutdown request for a TokenizerManager."""

    def __init__(self, tokenizer_manager):
        # The manager whose ``gracefully_exit`` flag this handler raises.
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        """Log the signal, then flag the manager so its watchdog drains and exits."""
        warning_text = (
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        logger.warning(warning_text)
        manager = self.tokenizer_manager
        manager.gracefully_exit = True