[Refactor] Simplify io_struct and tokenizer_manager (#1549)
This commit is contained in:
@@ -36,7 +36,7 @@ class GenerateReqInput:
|
|||||||
# See also python/sglang/srt/utils.py:load_image.
|
# See also python/sglang/srt/utils.py:load_image.
|
||||||
image_data: Optional[Union[List[str], str]] = None
|
image_data: Optional[Union[List[str], str]] = None
|
||||||
# The sampling_params. See descriptions below.
|
# The sampling_params. See descriptions below.
|
||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||||
# The request id.
|
# The request id.
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
# Whether to return logprobs.
|
# Whether to return logprobs.
|
||||||
@@ -55,28 +55,47 @@ class GenerateReqInput:
|
|||||||
# LoRA related
|
# LoRA related
|
||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
|
|
||||||
# Whether it is a single request or a batch request
|
|
||||||
is_single: bool = True
|
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
if (self.text is None and self.input_ids is None) or (
|
if (self.text is None and self.input_ids is None) or (
|
||||||
self.text is not None and self.input_ids is not None
|
self.text is not None and self.input_ids is not None
|
||||||
):
|
):
|
||||||
raise ValueError("Either text or input_ids should be provided.")
|
raise ValueError("Either text or input_ids should be provided.")
|
||||||
|
|
||||||
if (
|
self.is_single = False
|
||||||
isinstance(self.sampling_params, dict)
|
|
||||||
and self.sampling_params.get("n", 1) != 1
|
|
||||||
):
|
|
||||||
is_single = False
|
|
||||||
else:
|
|
||||||
if self.text is not None:
|
if self.text is not None:
|
||||||
is_single = isinstance(self.text, str)
|
if isinstance(self.text, str):
|
||||||
|
self.is_single = True
|
||||||
|
self.batch_size = 1
|
||||||
else:
|
else:
|
||||||
is_single = isinstance(self.input_ids[0], int)
|
self.batch_size = len(self.text)
|
||||||
self.is_single = is_single
|
else:
|
||||||
|
if isinstance(self.input_ids[0], int):
|
||||||
|
self.is_single = True
|
||||||
|
self.batch_size = 1
|
||||||
|
else:
|
||||||
|
self.batch_size = len(self.input_ids)
|
||||||
|
|
||||||
if is_single:
|
if self.sampling_params is None:
|
||||||
|
self.parallel_sample_num = 1
|
||||||
|
if isinstance(self.sampling_params, dict):
|
||||||
|
self.parallel_sample_num = self.sampling_params.get("n", 1)
|
||||||
|
else: # isinstance(self.sampling_params, list):
|
||||||
|
self.parallel_sample_num = self.sampling_params[0].get("n", 1)
|
||||||
|
for sp in self.sampling_params:
|
||||||
|
# TODO cope with the case that the parallel_sample_num is different for different samples
|
||||||
|
assert self.parallel_sample_num == sp.get(
|
||||||
|
"n", 1
|
||||||
|
), "The parallel_sample_num should be the same for all samples in sample params."
|
||||||
|
|
||||||
|
if self.parallel_sample_num > 1:
|
||||||
|
if self.is_single:
|
||||||
|
self.is_single = False
|
||||||
|
if self.text is not None:
|
||||||
|
self.text = [self.text]
|
||||||
|
if self.input_ids is not None:
|
||||||
|
self.input_ids = [self.input_ids]
|
||||||
|
|
||||||
|
if self.is_single:
|
||||||
if self.sampling_params is None:
|
if self.sampling_params is None:
|
||||||
self.sampling_params = {}
|
self.sampling_params = {}
|
||||||
if self.rid is None:
|
if self.rid is None:
|
||||||
@@ -88,79 +107,54 @@ class GenerateReqInput:
|
|||||||
if self.top_logprobs_num is None:
|
if self.top_logprobs_num is None:
|
||||||
self.top_logprobs_num = 0
|
self.top_logprobs_num = 0
|
||||||
else:
|
else:
|
||||||
parallel_sample_num_list = []
|
if self.parallel_sample_num == 1:
|
||||||
if isinstance(self.sampling_params, dict):
|
num = self.batch_size
|
||||||
parallel_sample_num = self.sampling_params.get("n", 1)
|
|
||||||
elif isinstance(self.sampling_params, list):
|
|
||||||
for sp in self.sampling_params:
|
|
||||||
parallel_sample_num = sp.get("n", 1)
|
|
||||||
parallel_sample_num_list.append(parallel_sample_num)
|
|
||||||
parallel_sample_num = max(parallel_sample_num_list)
|
|
||||||
all_equal = all(
|
|
||||||
element == parallel_sample_num
|
|
||||||
for element in parallel_sample_num_list
|
|
||||||
)
|
|
||||||
if parallel_sample_num > 1 and (not all_equal):
|
|
||||||
# TODO cope with the case that the parallel_sample_num is different for different samples
|
|
||||||
raise ValueError(
|
|
||||||
"The parallel_sample_num should be the same for all samples in sample params."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
parallel_sample_num = 1
|
# FIXME support cascade inference
|
||||||
self.parallel_sample_num = parallel_sample_num
|
# first bs samples are used for caching the prefix for parallel sampling
|
||||||
|
num = self.batch_size + self.parallel_sample_num * self.batch_size
|
||||||
if parallel_sample_num != 1:
|
|
||||||
# parallel sampling +1 represents the original prefill stage
|
|
||||||
num = parallel_sample_num + 1
|
|
||||||
if isinstance(self.text, list):
|
|
||||||
# suppot batch operation
|
|
||||||
self.batch_size = len(self.text)
|
|
||||||
num = num * len(self.text)
|
|
||||||
elif isinstance(self.input_ids, list) and isinstance(
|
|
||||||
self.input_ids[0], list
|
|
||||||
):
|
|
||||||
self.batch_size = len(self.input_ids)
|
|
||||||
num = num * len(self.input_ids)
|
|
||||||
else:
|
|
||||||
self.batch_size = 1
|
|
||||||
else:
|
|
||||||
# support select operation
|
|
||||||
num = len(self.text) if self.text is not None else len(self.input_ids)
|
|
||||||
self.batch_size = num
|
|
||||||
|
|
||||||
if self.image_data is None:
|
if self.image_data is None:
|
||||||
self.image_data = [None] * num
|
self.image_data = [None] * num
|
||||||
elif not isinstance(self.image_data, list):
|
elif not isinstance(self.image_data, list):
|
||||||
self.image_data = [self.image_data] * num
|
self.image_data = [self.image_data] * num
|
||||||
elif isinstance(self.image_data, list):
|
elif isinstance(self.image_data, list):
|
||||||
# multi-image with n > 1
|
# FIXME incorrect order for duplication
|
||||||
self.image_data = self.image_data * num
|
self.image_data = self.image_data * num
|
||||||
|
|
||||||
if self.sampling_params is None:
|
if self.sampling_params is None:
|
||||||
self.sampling_params = [{}] * num
|
self.sampling_params = [{}] * num
|
||||||
elif not isinstance(self.sampling_params, list):
|
elif not isinstance(self.sampling_params, list):
|
||||||
self.sampling_params = [self.sampling_params] * num
|
self.sampling_params = [self.sampling_params] * num
|
||||||
|
else:
|
||||||
|
assert self.parallel_sample_num == 1
|
||||||
|
|
||||||
if self.rid is None:
|
if self.rid is None:
|
||||||
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
||||||
else:
|
else:
|
||||||
if not isinstance(self.rid, list):
|
assert isinstance(self.rid, list), "The rid should be a list."
|
||||||
raise ValueError("The rid should be a list.")
|
assert self.parallel_sample_num == 1
|
||||||
|
|
||||||
if self.return_logprob is None:
|
if self.return_logprob is None:
|
||||||
self.return_logprob = [False] * num
|
self.return_logprob = [False] * num
|
||||||
elif not isinstance(self.return_logprob, list):
|
elif not isinstance(self.return_logprob, list):
|
||||||
self.return_logprob = [self.return_logprob] * num
|
self.return_logprob = [self.return_logprob] * num
|
||||||
|
else:
|
||||||
|
assert self.parallel_sample_num == 1
|
||||||
|
|
||||||
if self.logprob_start_len is None:
|
if self.logprob_start_len is None:
|
||||||
self.logprob_start_len = [-1] * num
|
self.logprob_start_len = [-1] * num
|
||||||
elif not isinstance(self.logprob_start_len, list):
|
elif not isinstance(self.logprob_start_len, list):
|
||||||
self.logprob_start_len = [self.logprob_start_len] * num
|
self.logprob_start_len = [self.logprob_start_len] * num
|
||||||
|
else:
|
||||||
|
assert self.parallel_sample_num == 1
|
||||||
|
|
||||||
if self.top_logprobs_num is None:
|
if self.top_logprobs_num is None:
|
||||||
self.top_logprobs_num = [0] * num
|
self.top_logprobs_num = [0] * num
|
||||||
elif not isinstance(self.top_logprobs_num, list):
|
elif not isinstance(self.top_logprobs_num, list):
|
||||||
self.top_logprobs_num = [self.top_logprobs_num] * num
|
self.top_logprobs_num = [self.top_logprobs_num] * num
|
||||||
|
else:
|
||||||
|
assert self.parallel_sample_num == 1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -199,8 +193,6 @@ class EmbeddingReqInput:
|
|||||||
# Dummy sampling params for compatibility
|
# Dummy sampling params for compatibility
|
||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Union[List[Dict], Dict] = None
|
||||||
|
|
||||||
is_single: bool = True
|
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
if (self.text is None and self.input_ids is None) or (
|
if (self.text is None and self.input_ids is None) or (
|
||||||
self.text is not None and self.input_ids is not None
|
self.text is not None and self.input_ids is not None
|
||||||
@@ -255,8 +247,6 @@ class RewardReqInput:
|
|||||||
# Dummy sampling params for compatibility
|
# Dummy sampling params for compatibility
|
||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Union[List[Dict], Dict] = None
|
||||||
|
|
||||||
is_single: bool = True
|
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
self.is_single = isinstance(self.conv[0], dict)
|
self.is_single = isinstance(self.conv[0], dict)
|
||||||
|
|
||||||
|
|||||||
@@ -159,58 +159,72 @@ class TokenizerManager:
|
|||||||
async for response in self._handle_batch_request(obj, request):
|
async for response in self._handle_batch_request(obj, request):
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
async def _handle_single_request(
|
async def _send_single_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||||
request: Optional[fastapi.Request] = None,
|
|
||||||
index: Optional[int] = None,
|
index: Optional[int] = None,
|
||||||
|
input_id_index: Optional[int] = None,
|
||||||
is_cache_for_prefill: Optional[bool] = False,
|
is_cache_for_prefill: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
if not is_cache_for_prefill: # The normal case with a single prompt
|
if not is_cache_for_prefill: # The normal case with a single prompt
|
||||||
not_use_index = index is None
|
if index is None:
|
||||||
|
rid = obj.rid
|
||||||
rid = obj.rid if not_use_index else obj.rid[index]
|
|
||||||
input_text = obj.text if not_use_index else obj.text[index]
|
|
||||||
if hasattr(obj, "conv"):
|
if hasattr(obj, "conv"):
|
||||||
# reward model
|
# reward model
|
||||||
assert self.tokenizer is not None
|
conv = obj.conv
|
||||||
conv = obj.conv if not_use_index else obj.conv[index]
|
input_text = self.tokenizer.apply_chat_template(
|
||||||
input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
|
conv, tokenize=False
|
||||||
|
)
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
elif obj.input_ids is None:
|
elif obj.input_ids is None:
|
||||||
assert self.tokenizer is not None
|
input_text = obj.text
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
else:
|
else:
|
||||||
input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
|
input_text = obj.text if obj.text is not None else None
|
||||||
|
input_ids = obj.input_ids
|
||||||
|
|
||||||
|
sampling_params = self._get_sampling_params(obj.sampling_params)
|
||||||
|
if self.is_generation:
|
||||||
|
image_inputs = await self.image_processor.process_images_async(
|
||||||
|
obj.image_data, obj
|
||||||
|
)
|
||||||
|
return_logprob = obj.return_logprob
|
||||||
|
logprob_start_len = obj.logprob_start_len
|
||||||
|
top_logprobs_num = obj.top_logprobs_num
|
||||||
|
else:
|
||||||
|
rid = obj.rid[index]
|
||||||
|
if hasattr(obj, "conv"):
|
||||||
|
# reward model
|
||||||
|
conv = obj.conv[index]
|
||||||
|
input_text = self.tokenizer.apply_chat_template(
|
||||||
|
conv, tokenize=False
|
||||||
|
)
|
||||||
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
|
elif obj.input_ids is None:
|
||||||
|
input_text = obj.text[input_id_index]
|
||||||
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
|
else:
|
||||||
|
input_text = (
|
||||||
|
obj.text[input_id_index] if obj.text is not None else None
|
||||||
|
)
|
||||||
|
input_ids = obj.input_ids[input_id_index]
|
||||||
|
|
||||||
|
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||||
|
if self.is_generation:
|
||||||
|
image_inputs = await self.image_processor.process_images_async(
|
||||||
|
obj.image_data[index], obj
|
||||||
|
)
|
||||||
|
return_logprob = obj.return_logprob[index]
|
||||||
|
logprob_start_len = obj.logprob_start_len[index]
|
||||||
|
top_logprobs_num = obj.top_logprobs_num[index]
|
||||||
|
|
||||||
self._validate_input_length(input_ids)
|
self._validate_input_length(input_ids)
|
||||||
|
|
||||||
sampling_params = self._get_sampling_params(
|
|
||||||
obj.sampling_params if not_use_index else obj.sampling_params[index]
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.is_generation:
|
|
||||||
image_inputs = await self.image_processor.process_images_async(
|
|
||||||
obj.image_data if not_use_index else obj.image_data[index], obj
|
|
||||||
)
|
|
||||||
return_logprob = (
|
|
||||||
obj.return_logprob if not_use_index else obj.return_logprob[index]
|
|
||||||
)
|
|
||||||
logprob_start_len = (
|
|
||||||
obj.logprob_start_len
|
|
||||||
if not_use_index
|
|
||||||
else obj.logprob_start_len[index]
|
|
||||||
)
|
|
||||||
top_logprobs_num = (
|
|
||||||
obj.top_logprobs_num
|
|
||||||
if not_use_index
|
|
||||||
else obj.top_logprobs_num[index]
|
|
||||||
)
|
|
||||||
else: # A prefill request to cache the common prompt for parallel sampling
|
else: # A prefill request to cache the common prompt for parallel sampling
|
||||||
assert self.is_generation
|
assert self.is_generation
|
||||||
if obj.text is not None:
|
if obj.text is not None:
|
||||||
if isinstance(obj.text, list):
|
if isinstance(obj.text, list):
|
||||||
input_text = obj.text[index]
|
input_text = obj.text[input_id_index]
|
||||||
rid = obj.rid[index]
|
rid = obj.rid[index]
|
||||||
else:
|
else:
|
||||||
input_text = obj.text
|
input_text = obj.text
|
||||||
@@ -224,7 +238,7 @@ class TokenizerManager:
|
|||||||
obj.input_ids[0], list
|
obj.input_ids[0], list
|
||||||
):
|
):
|
||||||
# when obj["input_ids"] is List[List[int]]
|
# when obj["input_ids"] is List[List[int]]
|
||||||
input_ids = obj.input_ids[index]
|
input_ids = obj.input_ids[input_id_index]
|
||||||
rid = obj.rid[index]
|
rid = obj.rid[index]
|
||||||
else:
|
else:
|
||||||
input_ids = obj.input_ids
|
input_ids = obj.input_ids
|
||||||
@@ -235,7 +249,7 @@ class TokenizerManager:
|
|||||||
obj.input_ids[0], list
|
obj.input_ids[0], list
|
||||||
):
|
):
|
||||||
# when obj["input_ids"] is List[List[int]]
|
# when obj["input_ids"] is List[List[int]]
|
||||||
input_ids = obj.input_ids[index]
|
input_ids = obj.input_ids[input_id_index]
|
||||||
rid = obj.rid[index]
|
rid = obj.rid[index]
|
||||||
else:
|
else:
|
||||||
input_ids = obj.input_ids
|
input_ids = obj.input_ids
|
||||||
@@ -263,7 +277,7 @@ class TokenizerManager:
|
|||||||
top_logprobs_num,
|
top_logprobs_num,
|
||||||
obj.stream,
|
obj.stream,
|
||||||
(
|
(
|
||||||
obj.lora_path[index]
|
obj.lora_path[input_id_index]
|
||||||
if isinstance(obj.lora_path, list)
|
if isinstance(obj.lora_path, list)
|
||||||
else obj.lora_path
|
else obj.lora_path
|
||||||
),
|
),
|
||||||
@@ -283,12 +297,30 @@ class TokenizerManager:
|
|||||||
input_ids,
|
input_ids,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||||
|
return rid, input_ids
|
||||||
|
|
||||||
|
async def _handle_single_request(
|
||||||
|
self,
|
||||||
|
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||||
|
request: Optional[fastapi.Request] = None,
|
||||||
|
index: Optional[int] = None,
|
||||||
|
input_id_index: Optional[int] = None,
|
||||||
|
is_cache_for_prefill: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
rid, input_ids = await self._send_single_request(
|
||||||
|
obj,
|
||||||
|
index,
|
||||||
|
input_id_index=input_id_index,
|
||||||
|
is_cache_for_prefill=is_cache_for_prefill,
|
||||||
|
)
|
||||||
|
|
||||||
# Recv results
|
# Recv results
|
||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
state = ReqState([], False, event)
|
state = ReqState([], False, event)
|
||||||
self.rid_to_state[rid] = state
|
self.rid_to_state[rid] = state
|
||||||
|
|
||||||
if not is_cache_for_prefill:
|
if not is_cache_for_prefill:
|
||||||
async for response in self._wait_for_response(state, obj, rid, request):
|
async for response in self._wait_for_response(state, obj, rid, request):
|
||||||
yield response
|
yield response
|
||||||
@@ -312,14 +344,16 @@ class TokenizerManager:
|
|||||||
input_id_result = [] if obj.input_ids is None else None
|
input_id_result = [] if obj.input_ids is None else None
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
async for input_id in self._handle_single_request(
|
async for input_id in self._handle_single_request(
|
||||||
obj, request, index=i, is_cache_for_prefill=True
|
obj,
|
||||||
|
request,
|
||||||
|
index=i,
|
||||||
|
input_id_index=i,
|
||||||
|
is_cache_for_prefill=True,
|
||||||
):
|
):
|
||||||
if input_id_result is not None:
|
if input_id_result is not None:
|
||||||
input_id_result.append(input_id)
|
input_id_result.append(input_id)
|
||||||
if input_id_result is not None and len(input_id_result) > 1:
|
if input_id_result is not None:
|
||||||
obj.input_ids = input_id_result
|
obj.input_ids = input_id_result
|
||||||
elif input_id_result is not None:
|
|
||||||
obj.input_ids = input_id_result[0]
|
|
||||||
else:
|
else:
|
||||||
parallel_sample_num = 1
|
parallel_sample_num = 1
|
||||||
|
|
||||||
@@ -333,70 +367,11 @@ class TokenizerManager:
|
|||||||
if parallel_sample_num != 1:
|
if parallel_sample_num != 1:
|
||||||
# Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
|
# Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
|
||||||
index += batch_size - 1 - i
|
index += batch_size - 1 - i
|
||||||
rid = obj.rid[index]
|
|
||||||
if parallel_sample_num == 1:
|
|
||||||
## select operation
|
|
||||||
if hasattr(obj, "conv"):
|
|
||||||
# reward model
|
|
||||||
conv = obj.conv[i]
|
|
||||||
input_text = self.tokenizer.apply_chat_template(
|
|
||||||
conv, tokenize=False
|
|
||||||
)
|
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
|
||||||
elif obj.input_ids is None:
|
|
||||||
input_text = obj.text[i]
|
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
|
||||||
else:
|
|
||||||
input_text = None
|
|
||||||
input_ids = obj.input_ids[i]
|
|
||||||
else:
|
|
||||||
assert obj.input_ids is not None
|
|
||||||
if batch_size == 1:
|
|
||||||
input_text = None
|
|
||||||
input_ids = obj.input_ids
|
|
||||||
else:
|
|
||||||
input_text = None
|
|
||||||
input_ids = obj.input_ids[i]
|
|
||||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
|
||||||
|
|
||||||
if self.is_generation:
|
rid, _ = await self._send_single_request(
|
||||||
image_inputs = await self.image_processor.process_images_async(
|
obj, index, input_id_index=i, is_cache_for_prefill=False
|
||||||
obj.image_data[index], obj
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
|
||||||
rid,
|
|
||||||
input_text,
|
|
||||||
input_ids,
|
|
||||||
image_inputs,
|
|
||||||
sampling_params,
|
|
||||||
obj.return_logprob[index],
|
|
||||||
obj.logprob_start_len[index],
|
|
||||||
obj.top_logprobs_num[index],
|
|
||||||
obj.stream,
|
|
||||||
(
|
|
||||||
obj.lora_path[index]
|
|
||||||
if isinstance(obj.lora_path, list)
|
|
||||||
else obj.lora_path
|
|
||||||
),
|
|
||||||
)
|
|
||||||
elif isinstance(obj, EmbeddingReqInput):
|
|
||||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
|
||||||
rid,
|
|
||||||
input_text,
|
|
||||||
input_ids,
|
|
||||||
sampling_params,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert isinstance(obj, RewardReqInput)
|
|
||||||
tokenized_obj = TokenizedRewardReqInput(
|
|
||||||
rid,
|
|
||||||
input_text,
|
|
||||||
input_ids,
|
|
||||||
sampling_params,
|
|
||||||
)
|
|
||||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
|
||||||
|
|
||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
state = ReqState([], False, event)
|
state = ReqState([], False, event)
|
||||||
self.rid_to_state[rid] = state
|
self.rid_to_state[rid] = state
|
||||||
@@ -418,7 +393,7 @@ class TokenizerManager:
|
|||||||
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
|
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
|
||||||
output_list = [None] * len(tasks)
|
output_list = [None] * len(tasks)
|
||||||
|
|
||||||
# Recv results
|
# Fetch results
|
||||||
while tasks:
|
while tasks:
|
||||||
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user