Simplify tokenizer manager (#1899)
This commit is contained in:
@@ -11,8 +11,10 @@ If needed, you can also override the chat template when launching the server:
|
|||||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
|
||||||
```
|
```
|
||||||
|
|
||||||
If the chat template you are looking for is missing, you are welcome to contribute it.
|
If the chat template you are looking for is missing, you are welcome to contribute it or load it from a file.
|
||||||
Meanwhile, you can also temporarily register your chat template as follows:
|
|
||||||
|
## JSON Format
|
||||||
|
You can load the JSON format, which is defined by `conversation.py`.
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@@ -29,3 +31,10 @@ Meanwhile, you can also temporarily register your chat template as follows:
|
|||||||
```
|
```
|
||||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Jinja Format
|
||||||
|
You can also use the Jinja template format, defined by Hugging Face transformers https://huggingface.co/docs/transformers/main/en/chat_templating
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.jinja
|
||||||
|
```
|
||||||
@@ -114,8 +114,7 @@ class GenerateReqInput:
|
|||||||
if self.parallel_sample_num == 1:
|
if self.parallel_sample_num == 1:
|
||||||
num = self.batch_size
|
num = self.batch_size
|
||||||
else:
|
else:
|
||||||
# FIXME support cascade inference
|
# The first bs samples are used for caching the prefix for parallel sampling
|
||||||
# first bs samples are used for caching the prefix for parallel sampling
|
|
||||||
num = self.batch_size + self.parallel_sample_num * self.batch_size
|
num = self.batch_size + self.parallel_sample_num * self.batch_size
|
||||||
|
|
||||||
if self.image_data is None:
|
if self.image_data is None:
|
||||||
@@ -196,6 +195,9 @@ class EmbeddingReqInput:
|
|||||||
# Dummy sampling params for compatibility
|
# Dummy sampling params for compatibility
|
||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Union[List[Dict], Dict] = None
|
||||||
|
|
||||||
|
# Whether it is a single request or a batch request
|
||||||
|
is_single: bool = True
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
if (self.text is None and self.input_ids is None) or (
|
if (self.text is None and self.input_ids is None) or (
|
||||||
self.text is not None and self.input_ids is not None
|
self.text is not None and self.input_ids is not None
|
||||||
@@ -241,15 +243,21 @@ class TokenizedEmbeddingReqInput:
|
|||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RewardReqInput:
|
class RewardReqInput:
|
||||||
# The input prompt in the chat format. It can be a single prompt or a batch of prompts.
|
# The input prompt. It can be a single prompt or a batch of prompts. Can be either chat format or a string.
|
||||||
conv: Union[List[List[Dict]], List[Dict]]
|
conv: RewardReqConv
|
||||||
# The request id.
|
# The request id.
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
# Dummy sampling params for compatibility
|
# Dummy sampling params for compatibility
|
||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Union[List[Dict], Dict] = None
|
||||||
|
|
||||||
|
# Whether it is a single request or a batch request
|
||||||
|
is_single: bool = True
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
self.is_single = isinstance(self.conv[0], dict)
|
self.is_single = isinstance(self.conv[0], dict)
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
GetMemPoolSizeReq,
|
GetMemPoolSizeReq,
|
||||||
GetMemPoolSizeReqOutput,
|
GetMemPoolSizeReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
|
RewardReqConv,
|
||||||
RewardReqInput,
|
RewardReqInput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
@@ -89,6 +90,7 @@ class TokenizerManager:
|
|||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
):
|
):
|
||||||
|
# Parse args
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
@@ -114,6 +116,7 @@ class TokenizerManager:
|
|||||||
self.context_len = server_args.context_length or get_context_length(
|
self.context_len = server_args.context_length or get_context_length(
|
||||||
self.hf_config
|
self.hf_config
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create image processor placeholder
|
# Create image processor placeholder
|
||||||
self.image_processor = get_dummy_image_processor()
|
self.image_processor = get_dummy_image_processor()
|
||||||
|
|
||||||
@@ -165,7 +168,8 @@ class TokenizerManager:
|
|||||||
|
|
||||||
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
|
"This model does not appear to be an embedding model by default. "
|
||||||
|
"Please add `--is-embedding` when launching the server or try another model."
|
||||||
)
|
)
|
||||||
|
|
||||||
obj.post_init()
|
obj.post_init()
|
||||||
@@ -187,12 +191,8 @@ class TokenizerManager:
|
|||||||
if not is_cache_for_prefill: # The normal case with a single prompt
|
if not is_cache_for_prefill: # The normal case with a single prompt
|
||||||
if index is None:
|
if index is None:
|
||||||
rid = obj.rid
|
rid = obj.rid
|
||||||
if hasattr(obj, "conv"):
|
if isinstance(obj, RewardReqInput):
|
||||||
# reward model
|
input_text = self._apply_chat_template(obj.conv)
|
||||||
conv = obj.conv
|
|
||||||
input_text = self.tokenizer.apply_chat_template(
|
|
||||||
conv, tokenize=False
|
|
||||||
)
|
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
elif obj.input_ids is None:
|
elif obj.input_ids is None:
|
||||||
input_text = obj.text
|
input_text = obj.text
|
||||||
@@ -213,12 +213,8 @@ class TokenizerManager:
|
|||||||
top_logprobs_num = obj.top_logprobs_num
|
top_logprobs_num = obj.top_logprobs_num
|
||||||
else:
|
else:
|
||||||
rid = obj.rid[index]
|
rid = obj.rid[index]
|
||||||
if hasattr(obj, "conv"):
|
if isinstance(obj, RewardReqInput):
|
||||||
# reward model
|
input_text = self._apply_chat_template(obj.conv[input_id_index])
|
||||||
conv = obj.conv[index]
|
|
||||||
input_text = self.tokenizer.apply_chat_template(
|
|
||||||
conv, tokenize=False
|
|
||||||
)
|
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
elif obj.input_ids is None:
|
elif obj.input_ids is None:
|
||||||
input_text = obj.text[input_id_index]
|
input_text = obj.text[input_id_index]
|
||||||
@@ -349,8 +345,9 @@ class TokenizerManager:
|
|||||||
async for response in self._wait_for_response(state, obj, rid, request):
|
async for response in self._wait_for_response(state, obj, rid, request):
|
||||||
yield response
|
yield response
|
||||||
else:
|
else:
|
||||||
assert self.is_generation
|
await state.event.wait()
|
||||||
await self._wait_for_cache_prefill_response(state, obj, rid, request)
|
assert state.finished
|
||||||
|
del self.rid_to_state[rid]
|
||||||
yield input_ids
|
yield input_ids
|
||||||
|
|
||||||
async def _handle_batch_request(
|
async def _handle_batch_request(
|
||||||
@@ -456,6 +453,15 @@ class TokenizerManager:
|
|||||||
sampling_params.verify()
|
sampling_params.verify()
|
||||||
return sampling_params
|
return sampling_params
|
||||||
|
|
||||||
|
def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
|
||||||
|
if isinstance(conv, str):
|
||||||
|
return conv
|
||||||
|
elif isinstance(conv, list):
|
||||||
|
if isinstance(conv[0], str):
|
||||||
|
return conv
|
||||||
|
else:
|
||||||
|
return self.tokenizer.apply_chat_template(conv, tokenize=False)
|
||||||
|
|
||||||
async def _wait_for_response(
|
async def _wait_for_response(
|
||||||
self,
|
self,
|
||||||
state: ReqState,
|
state: ReqState,
|
||||||
@@ -491,12 +497,11 @@ class TokenizerManager:
|
|||||||
|
|
||||||
out["index"] = response_index
|
out["index"] = response_index
|
||||||
|
|
||||||
# Log requests
|
|
||||||
if self.server_args.log_requests and state.finished:
|
|
||||||
logger.info(f"in={obj}, out={out}")
|
|
||||||
|
|
||||||
state.out_list = []
|
state.out_list = []
|
||||||
if state.finished:
|
if state.finished:
|
||||||
|
# Log requests
|
||||||
|
if self.server_args.log_requests:
|
||||||
|
logger.info(f"in={obj}, out={out}")
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[rid]
|
||||||
yield out
|
yield out
|
||||||
break
|
break
|
||||||
@@ -504,27 +509,6 @@ class TokenizerManager:
|
|||||||
state.event.clear()
|
state.event.clear()
|
||||||
yield out
|
yield out
|
||||||
|
|
||||||
async def _wait_for_cache_prefill_response(
|
|
||||||
self,
|
|
||||||
state: ReqState,
|
|
||||||
obj: GenerateReqInput,
|
|
||||||
rid: str,
|
|
||||||
request: Optional[fastapi.Request] = None,
|
|
||||||
):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(state.event.wait(), timeout=4)
|
|
||||||
break
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
if request is not None and await request.is_disconnected():
|
|
||||||
for rid in obj.rid:
|
|
||||||
self.abort_request(rid)
|
|
||||||
raise ValueError(f"Abort request {rid}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
assert state.finished
|
|
||||||
del self.rid_to_state[rid]
|
|
||||||
|
|
||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
req = FlushCacheReq()
|
req = FlushCacheReq()
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
@@ -553,6 +537,7 @@ class TokenizerManager:
|
|||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
self.mem_pool_size = asyncio.Future()
|
self.mem_pool_size = asyncio.Future()
|
||||||
|
|
||||||
|
# FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
|
||||||
if self.server_args.dp_size == 1:
|
if self.server_args.dp_size == 1:
|
||||||
res = await self.mem_pool_size
|
res = await self.mem_pool_size
|
||||||
return res.size
|
return res.size
|
||||||
@@ -638,7 +623,7 @@ class TokenizerManager:
|
|||||||
while True:
|
while True:
|
||||||
remain_num_req = len(self.rid_to_state)
|
remain_num_req = len(self.rid_to_state)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"gracefully exiting... remaining number of requests {remain_num_req}"
|
f"Gracefully exiting... remaining number of requests {remain_num_req}"
|
||||||
)
|
)
|
||||||
if remain_num_req > 0:
|
if remain_num_req > 0:
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
@@ -695,7 +680,6 @@ class TokenizerManager:
|
|||||||
"token_ids": recv_obj.output_ids[i],
|
"token_ids": recv_obj.output_ids[i],
|
||||||
"meta_info": recv_obj.meta_info[i],
|
"meta_info": recv_obj.meta_info[i],
|
||||||
}
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||||
out_dict = {
|
out_dict = {
|
||||||
@@ -747,7 +731,7 @@ class TokenizerManager:
|
|||||||
token_texts = self.tokenizer.batch_decode(token_ids)
|
token_texts = self.tokenizer.batch_decode(token_ids)
|
||||||
return [
|
return [
|
||||||
(logprob, token_id, token_text)
|
(logprob, token_id, token_text)
|
||||||
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
|
for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
|
||||||
]
|
]
|
||||||
|
|
||||||
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
|
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
|
||||||
|
|||||||
Reference in New Issue
Block a user