Simplify tokenizer manager (#1904)
@@ -16,6 +16,7 @@ limitations under the License.
 """TokenizerManager is a process that tokenizes the text."""

 import asyncio
+import copy
 import dataclasses
 import json
 import logging
@@ -51,11 +52,8 @@ from sglang.srt.managers.io_struct import (
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
     ProfileReq,
-    RewardReqConv,
-    RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
@@ -157,7 +155,7 @@ class TokenizerManager:

     async def generate_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
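For context, `generate_request` remains an async generator after this change, so a caller iterates over it the same way whether the request is a single prompt or a batch. A minimal, hypothetical caller sketch (not part of this commit), assuming only the `GenerateReqInput` fields visible in this diff (`text`, `sampling_params`):

# Hypothetical caller sketch: drain the async generator returned by
# TokenizerManager.generate_request. Field names are assumptions.
from sglang.srt.managers.io_struct import GenerateReqInput

async def generate_once(tokenizer_manager, prompt: str):
    obj = GenerateReqInput(text=prompt, sampling_params={"max_new_tokens": 16})
    last = None
    async for out in tokenizer_manager.generate_request(obj):
        last = out  # a streaming caller would forward each chunk instead
    return last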
@@ -172,122 +170,54 @@ class TokenizerManager:
                 "Please add `--is-embedding` when launching the server or try another model."
             )

-        obj.post_init()
+        obj.normalize_batch_and_arguments()
         is_single = obj.is_single
         if is_single:
-            async for response in self._handle_single_request(obj, request):
+            tokenized_obj = await self._tokenize_one_request(obj)
+            self.send_to_scheduler.send_pyobj(tokenized_obj)
+            async for response in self._wait_one_response(obj, request):
                 yield response
         else:
             async for response in self._handle_batch_request(obj, request):
                 yield response

-    async def _send_single_request(
+    async def _tokenize_one_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
     ):
-        if not is_cache_for_prefill:  # The normal case with a single prompt
-            if index is None:
-                rid = obj.rid
-                if isinstance(obj, RewardReqInput):
-                    input_text = self._apply_chat_template(obj.conv)
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = obj.text if obj.text is not None else None
-                    input_ids = obj.input_ids
+        """Tokenize one request."""
+        # Tokenize
+        input_text = obj.text
+        if obj.input_ids is None:
+            input_ids = self.tokenizer.encode(input_text)
+        else:
+            input_ids = obj.input_ids

-                sampling_params = self._get_sampling_params(obj.sampling_params)
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data, input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob
-                    logprob_start_len = obj.logprob_start_len
-                    top_logprobs_num = obj.top_logprobs_num
-            else:
-                rid = obj.rid[index]
-                if isinstance(obj, RewardReqInput):
-                    input_text = self._apply_chat_template(obj.conv[input_id_index])
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text[input_id_index]
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = (
-                        obj.text[input_id_index] if obj.text is not None else None
-                    )
-                    input_ids = obj.input_ids[input_id_index]
-
-                sampling_params = self._get_sampling_params(obj.sampling_params[index])
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob[index]
-                    logprob_start_len = obj.logprob_start_len[index]
-                    top_logprobs_num = obj.top_logprobs_num[index]
-
-            self._validate_input_length(input_ids)
-
-        else:  # A prefill request to cache the common prompt for parallel sampling
-            assert self.is_generation
-            if obj.text is not None:
-                if isinstance(obj.text, list):
-                    input_text = obj.text[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_text = obj.text
-                    rid = obj.rid[0]
-                if self.tokenizer is not None:
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    assert obj.input_ids is not None
-                    input_ids = obj.input_ids
-                    if isinstance(obj.input_ids, list) and isinstance(
-                        obj.input_ids[0], list
-                    ):
-                        # when obj["input_ids"] is List[List[int]]
-                        input_ids = obj.input_ids[input_id_index]
-                        rid = obj.rid[index]
-                    else:
-                        input_ids = obj.input_ids
-                        rid = obj.rid[0]
-            else:
-                input_text = None
-                if isinstance(obj.input_ids, list) and isinstance(
-                    obj.input_ids[0], list
-                ):
-                    # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_ids = obj.input_ids
-                    rid = obj.rid[0]
-
-            sampling_params = SamplingParams(**obj.sampling_params[0])
-            sampling_params.max_new_tokens = 0
         if self.is_generation:
             image_inputs = await self.image_processor.process_images_async(
-                obj.image_data[0], input_text or input_ids, obj
+                obj.image_data, input_text or input_ids, obj
             )
             if image_inputs and "input_ids" in image_inputs:
                 input_ids = image_inputs["input_ids"]
-            return_logprob = obj.return_logprob[0]
-            logprob_start_len = obj.logprob_start_len[0]
-            top_logprobs_num = obj.top_logprobs_num[0]
+            return_logprob = obj.return_logprob
+            logprob_start_len = obj.logprob_start_len
+            top_logprobs_num = obj.top_logprobs_num

-        # Send to the controller
-        if self.is_generation:
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)."
+            )
+
+        # Parse sampling parameters
+        sampling_params = SamplingParams(**obj.sampling_params)
+        sampling_params.normalize(self.tokenizer)
+        sampling_params.verify()
+
+        # Build return object
+        if isinstance(obj, GenerateReqInput):
             tokenized_obj = TokenizedGenerateReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 image_inputs,
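The single-request path introduced above reduces to three steps: tokenize, send to the scheduler, wait for the response. A toy, self-contained sketch of the same send-then-wait-on-event pattern, with illustrative names only (an asyncio.Queue and a dict of events stand in for the ZMQ socket and the rid_to_state table; this is not sglang code):

# Toy sketch of the tokenize -> send -> wait pattern.
import asyncio

async def handle_one(send_queue: asyncio.Queue, events: dict, req: dict):
    # 1) "Tokenize" the request (stand-in for _tokenize_one_request).
    tokenized = {"rid": req["rid"], "input_ids": list(req["text"].encode())}
    # 2) Send it to the scheduler (stand-in for send_to_scheduler.send_pyobj).
    await send_queue.put(tokenized)
    # 3) Wait until the scheduler/detokenizer signals completion
    #    (stand-in for _wait_one_response).
    events[req["rid"]] = asyncio.Event()
    await events[req["rid"]].wait()
    return {"rid": req["rid"], "finished": True}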
@@ -296,219 +226,126 @@ class TokenizerManager:
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-                (
-                    obj.lora_path[input_id_index]
-                    if isinstance(obj.lora_path, list)
-                    else obj.lora_path
-                ),
+                obj.lora_path,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                rid,
-                input_text,
-                input_ids,
-                sampling_params,
-            )
-        else:
-            assert isinstance(obj, RewardReqInput)
-            tokenized_obj = TokenizedRewardReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 sampling_params,
             )

-        self.send_to_scheduler.send_pyobj(tokenized_obj)
-        return rid, input_ids
+        return tokenized_obj

-    async def _handle_single_request(
+    async def _wait_one_response(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
     ):
-        rid, input_ids = await self._send_single_request(
-            obj,
-            index,
-            input_id_index=input_id_index,
-            is_cache_for_prefill=is_cache_for_prefill,
-        )
-
-        # Recv results
+        """Wait for the response of one request."""
         event = asyncio.Event()
         state = ReqState([], False, event)
-        self.rid_to_state[rid] = state
+        self.rid_to_state[obj.rid] = state

-        if not is_cache_for_prefill:
-            async for response in self._wait_for_response(state, obj, rid, request):
-                yield response
-        else:
-            await state.event.wait()
-            assert state.finished
-            del self.rid_to_state[rid]
-            yield input_ids
-
-    async def _handle_batch_request(
-        self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
-    ):
-        batch_size = obj.batch_size
-        if self.is_generation:
-            parallel_sample_num = obj.parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # Send prefill requests to cache the common prefix
-                parallel_sample_num += 1
-                input_id_result = [] if obj.input_ids is None else None
-                for i in range(batch_size):
-                    async for input_id in self._handle_single_request(
-                        obj,
-                        request,
-                        index=i,
-                        input_id_index=i,
-                        is_cache_for_prefill=True,
-                    ):
-                        if input_id_result is not None:
-                            input_id_result.append(input_id)
-                if input_id_result is not None:
-                    obj.input_ids = input_id_result
-        else:
-            parallel_sample_num = 1
-
-        # First send out all requests
-        generators = []
-        for i in range(batch_size):
-            for j in range(parallel_sample_num):
-                if j == 0 and parallel_sample_num != 1:
-                    continue
-                index = i * parallel_sample_num + j
-                if parallel_sample_num != 1:
-                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
-                    index += batch_size - 1 - i
-
-                rid, _ = await self._send_single_request(
-                    obj, index, input_id_index=i, is_cache_for_prefill=False
-                )
-
-                event = asyncio.Event()
-                state = ReqState([], False, event)
-                self.rid_to_state[rid] = state
-
-                generators.append(
-                    self._wait_for_response(
-                        state,
-                        obj,
-                        rid,
-                        request,
-                        index=index,
-                        response_index=len(generators),
-                    )
-                )
-
-        # Then process the responses based on streaming option
-        is_stream = hasattr(obj, "stream") and obj.stream
-
-        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
-        output_list = [None] * len(tasks)
-
-        # Fetch results
-        while tasks:
-            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-
-            for task in done:
-                cur_index = tasks.index(task)
-
-                try:
-                    result = task.result()
-
-                    if is_stream:
-                        yield result
-                    else:
-                        output_list[result["index"]] = result
-
-                    tasks[cur_index] = asyncio.create_task(
-                        generators[cur_index].__anext__()
-                    )
-                except StopAsyncIteration:
-                    del generators[cur_index]
-                    del tasks[cur_index]
-
-        if not is_stream:
-            yield output_list
-
-    def _validate_input_length(self, input_ids: List[int]):
-        if len(input_ids) >= self.context_len:
-            raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
-
-    def _get_sampling_params(self, sampling_params_data: dict):
-        sampling_params = SamplingParams(**sampling_params_data)
-        if sampling_params.max_new_tokens != 0:
-            sampling_params.normalize(self.tokenizer)
-            sampling_params.verify()
-        return sampling_params
-
-    def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
-        if isinstance(conv, str):
-            return conv
-        elif isinstance(conv, list):
-            if isinstance(conv[0], str):
-                return conv
-            else:
-                return self.tokenizer.apply_chat_template(conv, tokenize=False)
-
-    async def _wait_for_response(
-        self,
-        state: ReqState,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        rid: str,
-        request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        response_index: int = 0,
-    ):
         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
-                    for rid in [obj.rid] if obj.is_single else obj.rid:
-                        self.abort_request(rid)
-                    raise ValueError(f"Abort request {rid}")
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")
                 continue

             if self.is_generation:
                 if isinstance(obj, GenerateReqInput):
                     out = self.convert_logprob_style(
                         state.out_list[-1],
-                        obj.return_logprob if index is None else obj.return_logprob[index],
-                        (
-                            obj.top_logprobs_num
-                            if index is None
-                            else obj.top_logprobs_num[index]
-                        ),
+                        obj.return_logprob,
+                        obj.top_logprobs_num,
                         obj.return_text_in_logprobs,
                     )
-            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
+            else:  # isinstance(obj, (EmbeddingReqInput,))
                 out = state.out_list[-1]

-            out["index"] = response_index
-
             state.out_list = []
             if state.finished:
-                # Log requests
                 if self.server_args.log_requests:
+                    # Log requests
                     logger.info(f"in={obj}, out={out}")
-                del self.rid_to_state[rid]
+                del self.rid_to_state[obj.rid]
                 yield out
                 break

             state.event.clear()
             yield out

+    async def _handle_batch_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
+    ):
+        batch_size = obj.batch_size
+
+        generators = []
+        rids = []
+        if getattr(obj, "parallel_sample_num", 1) == 1:
+            # Send all requests
+            for i in range(batch_size):
+                tmp_obj = obj[i]
+                tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                generators.append(self._wait_one_response(tmp_obj, request))
+                rids.append(tmp_obj.rid)
+        else:
+            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+
+            # Tokenize all requests
+            objs = [obj[i] for i in range(batch_size)]
+            tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs))
+
+            # Cache the common prefix for parallel sampling
+            for i in range(batch_size):
+                tmp_obj = copy.copy(objs[i])
+                tokenized_obj = copy.copy(tokenized_objs[i])
+                tokenized_obj.rid = tmp_obj.regenerate_rid()
+                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
+                tokenized_obj.sampling_params.max_new_tokens = 0
+                tokenized_obj.stream = False
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                await self._wait_one_response(tmp_obj, request).__anext__()
+
+            # Expand requests, assign new rids for them, and send them
+            for i in range(batch_size):
+                for _ in range(obj.parallel_sample_num):
+                    tmp_obj = copy.copy(objs[i])
+                    tokenized_obj = copy.copy(tokenized_objs[i])
+                    tokenized_obj.rid = tmp_obj.regenerate_rid()
+                    self.send_to_scheduler.send_pyobj(tokenized_obj)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
+
+        # Wait for all requests
+        is_stream = hasattr(obj, "stream") and obj.stream
+        if not is_stream:
+            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
+            yield outputs
+        else:
+            rid_to_index = {rid: i for i, rid in enumerate(rids)}
+            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
+            while task_map:
+                done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
+
+                for task in done:
+                    gen = task_map.pop(task)
+                    try:
+                        result = task.result()
+                        result["index"] = rid_to_index[result["meta_info"]["id"]]
+                        yield result
+                        new_task = asyncio.create_task(gen.__anext__())
+                        task_map[new_task] = gen
+                    except StopAsyncIteration:
+                        pass
+
     def flush_cache(self):
         req = FlushCacheReq()
         self.send_to_scheduler.send_pyobj(req)
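The streaming branch of the new _handle_batch_request fans in results from many per-request async generators using asyncio.wait with FIRST_COMPLETED. A self-contained toy illustration of that fan-in pattern (illustrative names only, not sglang code):

# Merge several async generators, yielding items as soon as any of them
# produces one, in the same style as the diff above.
import asyncio

async def counter(name: str, n: int):
    for i in range(n):
        await asyncio.sleep(0.01 * (i + 1))
        yield f"{name}:{i}"

async def merge(generators):
    task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
    while task_map:
        done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = task_map.pop(task)
            try:
                item = task.result()
            except StopAsyncIteration:
                continue  # this generator is exhausted
            yield item
            task_map[asyncio.create_task(gen.__anext__())] = gen

async def main():
    async for item in merge([counter("a", 3), counter("b", 2)]):
        print(item)

asyncio.run(main())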