Simplify tokenizer manager (#1899)
@@ -114,8 +114,7 @@ class GenerateReqInput:
         if self.parallel_sample_num == 1:
             num = self.batch_size
         else:
-            # FIXME support cascade inference
-            # first bs samples are used for caching the prefix for parallel sampling
+            # The first bs samples are used for caching the prefix for parallel sampling
             num = self.batch_size + self.parallel_sample_num * self.batch_size

         if self.image_data is None:
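For context: when `parallel_sample_num > 1`, the tokenizer manager fans each batch out into `batch_size` prefix-caching requests plus `batch_size * parallel_sample_num` sampling requests. A standalone sketch of the sizing rule (illustrative values, not from the commit):

```python
def total_requests(batch_size: int, parallel_sample_num: int) -> int:
    """Mirrors the sizing logic in GenerateReqInput.post_init above."""
    if parallel_sample_num == 1:
        return batch_size
    # One prefix-caching request per batch item, plus all parallel samples.
    return batch_size + parallel_sample_num * batch_size

assert total_requests(batch_size=2, parallel_sample_num=1) == 2
assert total_requests(batch_size=2, parallel_sample_num=3) == 8  # 2 prefix + 6 samples
```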
@@ -196,6 +195,9 @@ class EmbeddingReqInput:
+    # Dummy sampling params for compatibility
+    sampling_params: Union[List[Dict], Dict] = None
+
     # Whether it is a single request or a batch request
     is_single: bool = True

     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
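The truncated `post_init` check above is an exactly-one-of validation: a request must carry either `text` or `input_ids`, never both and never neither. A minimal sketch of the same rule, assuming the (cut-off) body raises `ValueError`:

```python
def validate_inputs(text, input_ids):
    # Exactly one of `text` and `input_ids` must be provided.
    if (text is None and input_ids is None) or (
        text is not None and input_ids is not None
    ):
        raise ValueError("Either text or input_ids should be provided.")

validate_inputs("hello", None)    # ok: text only
validate_inputs(None, [1, 2, 3])  # ok: token ids only
# validate_inputs(None, None)     # would raise ValueError
```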
@@ -241,15 +243,21 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams


+RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]
+
+
 @dataclass
 class RewardReqInput:
-    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
-    conv: Union[List[List[Dict]], List[Dict]]
+    # The input prompt. It can be a single prompt or a batch of prompts. Can be either chat format or a string.
+    conv: RewardReqConv
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
+
+    # Whether it is a single request or a batch request
+    is_single: bool = True

     def post_init(self):
         self.is_single = isinstance(self.conv[0], dict)

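The widened `RewardReqConv` alias lets the `conv` field accept four shapes: one string, a batch of strings, one chat-format conversation, or a batch of conversations. Note that the dict check in `post_init` only distinguishes the two chat-format shapes; string inputs are normalized separately (see `_apply_chat_template` below). A standalone illustration of the dict check (toy data, not from the repo):

```python
from typing import Dict, List, Union

RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]

def looks_single_chat(conv: RewardReqConv) -> bool:
    # Mirrors `self.is_single = isinstance(self.conv[0], dict)`:
    # a single conversation is a list of message dicts.
    return isinstance(conv[0], dict)

single = [{"role": "user", "content": "Rate this answer."}]
batch = [single, single]
assert looks_single_chat(single) is True
assert looks_single_chat(batch) is False
```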
@@ -51,6 +51,7 @@ from sglang.srt.managers.io_struct import (
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
     ProfileReq,
+    RewardReqConv,
     RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -89,6 +90,7 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Parse args
         self.server_args = server_args

         # Init inter-process communication
@@ -114,6 +116,7 @@ class TokenizerManager:
         self.context_len = server_args.context_length or get_context_length(
             self.hf_config
         )

+        # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()

@@ -165,7 +168,8 @@ class TokenizerManager:

         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
-                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
             )

         obj.post_init()
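The reworded message relies on Python's implicit concatenation of adjacent string literals, which wraps a long message across source lines without changing the resulting string:

```python
msg = (
    "This model does not appear to be an embedding model by default. "
    "Please add `--is-embedding` when launching the server or try another model."
)
# Adjacent literals are joined at compile time into a single string.
assert "by default. Please add" in msg
```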
@@ -187,12 +191,8 @@ class TokenizerManager:
         if not is_cache_for_prefill:  # The normal case with a single prompt
             if index is None:
                 rid = obj.rid
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
+                if isinstance(obj, RewardReqInput):
+                    input_text = self._apply_chat_template(obj.conv)
                     input_ids = self.tokenizer.encode(input_text)
                 elif obj.input_ids is None:
                     input_text = obj.text
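Swapping the `hasattr(obj, "conv")` duck-typing probe for an explicit `isinstance(obj, RewardReqInput)` check makes the reward-model branch self-documenting, and it can no longer be triggered by an unrelated object that happens to carry a `conv` attribute. In miniature (toy classes, not from the repo):

```python
from dataclasses import dataclass

@dataclass
class RewardReq:
    conv: list

@dataclass
class Unrelated:
    conv: list  # also has a `conv` attribute

def dispatch(obj) -> str:
    if isinstance(obj, RewardReq):  # precise: only reward requests match
        return "reward"
    return "other"

assert dispatch(RewardReq(conv=[])) == "reward"
assert dispatch(Unrelated(conv=[])) == "other"  # hasattr("conv") would have matched
```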
@@ -213,12 +213,8 @@ class TokenizerManager:
                 top_logprobs_num = obj.top_logprobs_num
             else:
                 rid = obj.rid[index]
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv[index]
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
+                if isinstance(obj, RewardReqInput):
+                    input_text = self._apply_chat_template(obj.conv[input_id_index])
                     input_ids = self.tokenizer.encode(input_text)
                 elif obj.input_ids is None:
                     input_text = obj.text[input_id_index]
@@ -349,8 +345,9 @@ class TokenizerManager:
             async for response in self._wait_for_response(state, obj, rid, request):
                 yield response
         else:
-            assert self.is_generation
-            await self._wait_for_cache_prefill_response(state, obj, rid, request)
+            await state.event.wait()
+            assert state.finished
+            del self.rid_to_state[rid]
             yield input_ids

     async def _handle_batch_request(
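With the helper gone, the prefill-caching path is the bare `asyncio.Event` handshake: the receive loop sets the event when the request finishes, and this coroutine awaits it, asserts completion, and drops its state entry. A self-contained sketch of that handshake (names are illustrative):

```python
import asyncio

async def main():
    rid_to_state = {"req-1": {"event": asyncio.Event(), "finished": False}}
    state = rid_to_state["req-1"]

    async def producer():
        await asyncio.sleep(0.01)  # simulate the scheduler finishing the request
        state["finished"] = True
        state["event"].set()       # wake the waiting consumer

    asyncio.create_task(producer())
    await state["event"].wait()    # mirrors `await state.event.wait()`
    assert state["finished"]       # mirrors `assert state.finished`
    del rid_to_state["req-1"]      # mirrors `del self.rid_to_state[rid]`

asyncio.run(main())
```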
@@ -456,6 +453,15 @@ class TokenizerManager:
         sampling_params.verify()
         return sampling_params

+    def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
+        if isinstance(conv, str):
+            return conv
+        elif isinstance(conv, list):
+            if isinstance(conv[0], str):
+                return conv
+            else:
+                return self.tokenizer.apply_chat_template(conv, tokenize=False)
+
     async def _wait_for_response(
         self,
         state: ReqState,
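The new helper centralizes input normalization: strings and lists of strings pass through unchanged, while chat-format conversations are rendered to prompt text with the tokenizer's chat template. A standalone reimplementation with a stub tokenizer to show the dispatch (real HuggingFace templates vary by model):

```python
from typing import List, Union

class StubTokenizer:
    # Stand-in for a HF tokenizer; a real chat template is model-specific.
    def apply_chat_template(self, conv, tokenize=False):
        return " ".join(m["content"] for m in conv)

def apply_chat_template(tokenizer, conv) -> Union[str, List[str]]:
    # Same dispatch as TokenizerManager._apply_chat_template above.
    if isinstance(conv, str):
        return conv            # already a prompt string
    if isinstance(conv[0], str):
        return conv            # already a batch of prompt strings
    return tokenizer.apply_chat_template(conv, tokenize=False)

tok = StubTokenizer()
assert apply_chat_template(tok, "raw prompt") == "raw prompt"
assert apply_chat_template(tok, ["a", "b"]) == ["a", "b"]
assert apply_chat_template(tok, [{"role": "user", "content": "Hi"}]) == "Hi"
```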
@@ -491,12 +497,11 @@ class TokenizerManager:

                 out["index"] = response_index

-                # Log requests
-                if self.server_args.log_requests and state.finished:
-                    logger.info(f"in={obj}, out={out}")
-
                 state.out_list = []
                 if state.finished:
+                    # Log requests
+                    if self.server_args.log_requests:
+                        logger.info(f"in={obj}, out={out}")
                     del self.rid_to_state[rid]
                     yield out
                     break
@@ -504,27 +509,6 @@ class TokenizerManager:
                 state.event.clear()
                 yield out

-    async def _wait_for_cache_prefill_response(
-        self,
-        state: ReqState,
-        obj: GenerateReqInput,
-        rid: str,
-        request: Optional[fastapi.Request] = None,
-    ):
-        while True:
-            try:
-                await asyncio.wait_for(state.event.wait(), timeout=4)
-                break
-            except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
-                    for rid in obj.rid:
-                        self.abort_request(rid)
-                    raise ValueError(f"Abort request {rid}")
-                continue
-
-        assert state.finished
-        del self.rid_to_state[rid]
-
     def flush_cache(self):
         req = FlushCacheReq()
         self.send_to_scheduler.send_pyobj(req)
@@ -553,6 +537,7 @@ class TokenizerManager:
         self.send_to_scheduler.send_pyobj(req)
         self.mem_pool_size = asyncio.Future()

+        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
         if self.server_args.dp_size == 1:
             res = await self.mem_pool_size
             return res.size
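The FIXME flags a genuine hazard: with a single shared `self.mem_pool_size` future, two concurrent memory-pool queries race, and the second call silently replaces the first call's still-pending future. A conventional fix is one future per in-flight request, keyed by request id; a sketch under the assumption that responses carry that id (hypothetical names, not the repo's API):

```python
import asyncio

class MemPoolClient:
    def __init__(self):
        # One pending future per in-flight query, keyed by request id.
        self._pending: dict[str, asyncio.Future] = {}

    async def request_size(self, rid: str) -> int:
        fut = asyncio.get_running_loop().create_future()
        self._pending[rid] = fut
        # ... send the query to the scheduler here, tagged with `rid` ...
        return await fut

    def on_response(self, rid: str, size: int) -> None:
        # Resolve only the matching request; concurrent queries no longer race.
        self._pending.pop(rid).set_result(size)
```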
@@ -638,7 +623,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)
             logger.info(
-                f"gracefully exiting... remaining number of requests {remain_num_req}"
+                f"Gracefully exiting... remaining number of requests {remain_num_req}"
             )
             if remain_num_req > 0:
                 await asyncio.sleep(5)
@@ -695,7 +680,6 @@ class TokenizerManager:
                     "token_ids": recv_obj.output_ids[i],
                     "meta_info": recv_obj.meta_info[i],
                 }
-
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
@@ -747,7 +731,7 @@ class TokenizerManager:
         token_texts = self.tokenizer.batch_decode(token_ids)
         return [
             (logprob, token_id, token_text)
-            for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
         ]

     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
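The last change drops a stray trailing comma in the comprehension's target list. `for (logprob, token_id), token_text, in ...` happens to parse (Python tolerates a trailing comma there), so this is purely a readability fix; the nested unpacking over `zip` behaves the same on both sides of the change:

```python
token_logprobs = [(-0.1, 42), (-2.3, 7)]  # (logprob, token_id) pairs
token_texts = ["Hello", " world"]         # batch-decoded token strings

triples = [
    (logprob, token_id, token_text)
    for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
]
assert triples == [(-0.1, 42, "Hello"), (-2.3, 7, " world")]
```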