[Refactor] Simplify io_struct and tokenizer_manager (#1549)
@@ -159,58 +159,72 @@ class TokenizerManager:
         async for response in self._handle_batch_request(obj, request):
             yield response
 
-    async def _handle_single_request(
+    async def _send_single_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
+        input_id_index: Optional[int] = None,
         is_cache_for_prefill: Optional[bool] = False,
     ):
         if not is_cache_for_prefill:  # The normal case with a single prompt
-            not_use_index = index is None
-
-            rid = obj.rid if not_use_index else obj.rid[index]
-            input_text = obj.text if not_use_index else obj.text[index]
-            if hasattr(obj, "conv"):
-                # reward model
-                assert self.tokenizer is not None
-                conv = obj.conv if not_use_index else obj.conv[index]
-                input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
-                input_ids = self.tokenizer.encode(input_text)
-            elif obj.input_ids is None:
-                assert self.tokenizer is not None
-                input_ids = self.tokenizer.encode(input_text)
+            if index is None:
+                rid = obj.rid
+                if hasattr(obj, "conv"):
+                    # reward model
+                    conv = obj.conv
+                    input_text = self.tokenizer.apply_chat_template(
+                        conv, tokenize=False
+                    )
+                    input_ids = self.tokenizer.encode(input_text)
+                elif obj.input_ids is None:
+                    input_text = obj.text
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    input_text = obj.text if obj.text is not None else None
+                    input_ids = obj.input_ids
+
+                sampling_params = self._get_sampling_params(obj.sampling_params)
+                if self.is_generation:
+                    image_inputs = await self.image_processor.process_images_async(
+                        obj.image_data, obj
+                    )
+                    return_logprob = obj.return_logprob
+                    logprob_start_len = obj.logprob_start_len
+                    top_logprobs_num = obj.top_logprobs_num
             else:
-                input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
-
-            self._validate_input_length(input_ids)
-
-            sampling_params = self._get_sampling_params(
-                obj.sampling_params if not_use_index else obj.sampling_params[index]
-            )
-
-            if self.is_generation:
-                image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data if not_use_index else obj.image_data[index], obj
-                )
-                return_logprob = (
-                    obj.return_logprob if not_use_index else obj.return_logprob[index]
-                )
-                logprob_start_len = (
-                    obj.logprob_start_len
-                    if not_use_index
-                    else obj.logprob_start_len[index]
-                )
-                top_logprobs_num = (
-                    obj.top_logprobs_num
-                    if not_use_index
-                    else obj.top_logprobs_num[index]
-                )
+                rid = obj.rid[index]
+                if hasattr(obj, "conv"):
+                    # reward model
+                    conv = obj.conv[index]
+                    input_text = self.tokenizer.apply_chat_template(
+                        conv, tokenize=False
+                    )
+                    input_ids = self.tokenizer.encode(input_text)
+                elif obj.input_ids is None:
+                    input_text = obj.text[input_id_index]
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    input_text = (
+                        obj.text[input_id_index] if obj.text is not None else None
+                    )
+                    input_ids = obj.input_ids[input_id_index]
+
+                sampling_params = self._get_sampling_params(obj.sampling_params[index])
+                if self.is_generation:
+                    image_inputs = await self.image_processor.process_images_async(
+                        obj.image_data[index], obj
+                    )
+                    return_logprob = obj.return_logprob[index]
+                    logprob_start_len = obj.logprob_start_len[index]
+                    top_logprobs_num = obj.top_logprobs_num[index]
         else:  # A prefill request to cache the common prompt for parallel sampling
             assert self.is_generation
             if obj.text is not None:
                 if isinstance(obj.text, list):
-                    input_text = obj.text[index]
+                    input_text = obj.text[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_text = obj.text
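Review note: this hunk renames `_handle_single_request` to `_send_single_request`, drops its unused `request` parameter, adds `input_id_index`, and replaces the `not_use_index` conditional expressions with explicit `if index is None:` branches. A rough illustration of the two request shapes those branches distinguish; the values are made up, and constructing the dataclass directly like this is my assumption about io_struct, not something shown in this diff:

from sglang.srt.managers.io_struct import GenerateReqInput

# A single request carries scalar fields; a batched request carries
# parallel lists. index=None selects the scalar path (obj.rid, obj.text);
# index=1 selects the list path (obj.rid[1], obj.text[1]).
single = GenerateReqInput(text="What is the capital of France?")
batch = GenerateReqInput(text=["Hello", "Bonjour"])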
@@ -224,7 +238,7 @@ class TokenizerManager:
                     obj.input_ids[0], list
                 ):
                     # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[index]
+                    input_ids = obj.input_ids[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_ids = obj.input_ids
@@ -235,7 +249,7 @@ class TokenizerManager:
                     obj.input_ids[0], list
                 ):
                     # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[index]
+                    input_ids = obj.input_ids[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_ids = obj.input_ids
@@ -263,7 +277,7 @@ class TokenizerManager:
                     top_logprobs_num,
                     obj.stream,
                     (
-                        obj.lora_path[index]
+                        obj.lora_path[input_id_index]
                         if isinstance(obj.lora_path, list)
                         else obj.lora_path
                     ),
@@ -283,12 +297,30 @@ class TokenizerManager:
                 input_ids,
                 sampling_params,
             )
 
         self.send_to_scheduler.send_pyobj(tokenized_obj)
+        return rid, input_ids
+
+    async def _handle_single_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        request: Optional[fastapi.Request] = None,
+        index: Optional[int] = None,
+        input_id_index: Optional[int] = None,
+        is_cache_for_prefill: Optional[bool] = False,
+    ):
+        rid, input_ids = await self._send_single_request(
+            obj,
+            index,
+            input_id_index=input_id_index,
+            is_cache_for_prefill=is_cache_for_prefill,
+        )
 
         # Recv results
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
 
         if not is_cache_for_prefill:
             async for response in self._wait_for_response(state, obj, rid, request):
                 yield response
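Review note: the tokenize-and-dispatch half now lives in `_send_single_request`, which returns `(rid, input_ids)`, while the new `_handle_single_request` registers a `ReqState` and streams the response back. A minimal, generic sketch of this send/await split, using stand-in names and a plain future table rather than sglang's real classes:

import asyncio
import uuid

pending: dict[str, asyncio.Future] = {}

async def send_single_request(payload: str) -> str:
    # Dispatch half: assign an id, register a future, hand off the work.
    rid = uuid.uuid4().hex
    pending[rid] = asyncio.get_running_loop().create_future()
    # ... send (rid, payload) to a scheduler process here ...
    return rid

async def handle_single_request(payload: str) -> str:
    # Wait half: reuse the dispatch helper, then block on the result,
    # which a separate receive loop resolves via pending[rid].set_result(...).
    rid = await send_single_request(payload)
    return await pending[rid]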
@@ -312,14 +344,16 @@ class TokenizerManager:
             input_id_result = [] if obj.input_ids is None else None
             for i in range(batch_size):
                 async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_cache_for_prefill=True
+                    obj,
+                    request,
+                    index=i,
+                    input_id_index=i,
+                    is_cache_for_prefill=True,
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
-            if input_id_result is not None and len(input_id_result) > 1:
+            if input_id_result is not None:
                 obj.input_ids = input_id_result
-            elif input_id_result is not None:
-                obj.input_ids = input_id_result[0]
         else:
             parallel_sample_num = 1
 
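Review note: dropping the `len(input_id_result) > 1` branch changes what a single-prompt batch stores back into `obj.input_ids`:

# Illustrative values: one prompt's token ids collected during prefill caching.
input_id_result = [[1, 2, 3]]
# before this hunk: obj.input_ids = input_id_result[0]  ->  [1, 2, 3]
# after this hunk:  obj.input_ids = input_id_result     ->  [[1, 2, 3]]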
@@ -333,69 +367,10 @@ class TokenizerManager:
             if parallel_sample_num != 1:
                 # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                 index += batch_size - 1 - i
-            rid = obj.rid[index]
-            if parallel_sample_num == 1:
-                ## select operation
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv[i]
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text[i]
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = None
-                    input_ids = obj.input_ids[i]
-            else:
-                assert obj.input_ids is not None
-                if batch_size == 1:
-                    input_text = None
-                    input_ids = obj.input_ids
-                else:
-                    input_text = None
-                    input_ids = obj.input_ids[i]
-            sampling_params = self._get_sampling_params(obj.sampling_params[index])
-
-            if self.is_generation:
-                image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data[index], obj
-                )
-
-                tokenized_obj = TokenizedGenerateReqInput(
-                    rid,
-                    input_text,
-                    input_ids,
-                    image_inputs,
-                    sampling_params,
-                    obj.return_logprob[index],
-                    obj.logprob_start_len[index],
-                    obj.top_logprobs_num[index],
-                    obj.stream,
-                    (
-                        obj.lora_path[index]
-                        if isinstance(obj.lora_path, list)
-                        else obj.lora_path
-                    ),
-                )
-            elif isinstance(obj, EmbeddingReqInput):
-                tokenized_obj = TokenizedEmbeddingReqInput(
-                    rid,
-                    input_text,
-                    input_ids,
-                    sampling_params,
-                )
-            else:
-                assert isinstance(obj, RewardReqInput)
-                tokenized_obj = TokenizedRewardReqInput(
-                    rid,
-                    input_text,
-                    input_ids,
-                    sampling_params,
-                )
-            self.send_to_scheduler.send_pyobj(tokenized_obj)
+            rid, _ = await self._send_single_request(
+                obj, index, input_id_index=i, is_cache_for_prefill=False
+            )
 
             event = asyncio.Event()
             state = ReqState([], False, event)
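Review note: the index arithmetic in the comment above can be sanity-checked. Assuming the loop first computes `index = i * parallel_sample_num + j` before the shift (that starting expression is my assumption; it is not visible in this hunk), the shift reproduces the documented formula:

batch_size, parallel_sample_num = 2, 3       # made-up sizes
for i in range(batch_size):                  # batch item
    for j in range(parallel_sample_num):     # sampled copy
        index = i * parallel_sample_num + j
        index += batch_size - 1 - i          # the shift kept by this hunk
        assert index == j + i * (parallel_sample_num - 1) + batch_size - 1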
@@ -418,7 +393,7 @@ class TokenizerManager:
         tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
         output_list = [None] * len(tasks)
 
-        # Recv results
+        # Fetch results
         while tasks:
             done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
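Review note: the hunk above is the fan-in loop of `_handle_batch_request`: it races one `__anext__` task per generator and consumes whichever yields first. A self-contained sketch of that pattern with toy generators, not sglang's code:

import asyncio

async def numbers(tag: str, n: int):
    for k in range(n):
        await asyncio.sleep(0.01)
        yield f"{tag}:{k}"

async def fan_in():
    gens = [numbers("a", 2), numbers("b", 2)]
    # Map each in-flight __anext__ task back to its generator.
    tasks = {asyncio.create_task(g.__anext__()): g for g in gens}
    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = tasks.pop(task)
            try:
                print(task.result())
                # Re-arm the generator that just yielded.
                tasks[asyncio.create_task(gen.__anext__())] = gen
            except StopAsyncIteration:
                pass  # this generator is exhausted

asyncio.run(fan_in())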