Fix package loss for small models (#2717)
Co-authored-by: sdli1995 <mmlmonkey@163.com>
This commit is contained in:
@@ -1364,11 +1364,11 @@ class Scheduler:
|
||||
embeddings = []
|
||||
prompt_tokens = []
|
||||
for req in reqs:
|
||||
assert req.finished()
|
||||
rids.append(req.rid)
|
||||
finished_reasons.append(req.finished_reason.to_json())
|
||||
embeddings.append(req.embedding)
|
||||
prompt_tokens.append(len(req.origin_input_ids))
|
||||
if req.finished():
|
||||
rids.append(req.rid)
|
||||
finished_reasons.append(req.finished_reason.to_json())
|
||||
embeddings.append(req.embedding)
|
||||
prompt_tokens.append(len(req.origin_input_ids))
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
|
||||
)
|
||||
|
||||
@@ -222,10 +222,8 @@ class TokenizerManager:
|
||||
is_single = obj.is_single
|
||||
if is_single:
|
||||
tokenized_obj = await self._tokenize_one_request(obj)
|
||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||
async for response in self._wait_one_response(
|
||||
obj, request, created_time
|
||||
):
|
||||
self._send_one_request(obj, tokenized_obj, created_time)
|
||||
async for response in self._wait_one_response(obj, request):
|
||||
yield response
|
||||
else:
|
||||
async for response in self._handle_batch_request(
|
||||
@@ -306,16 +304,24 @@ class TokenizerManager:
|
||||
|
||||
return tokenized_obj
|
||||
|
||||
def _send_one_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
||||
created_time: Optional[float] = None,
|
||||
):
|
||||
event = asyncio.Event()
|
||||
state = ReqState([], False, event, obj, created_time=created_time)
|
||||
self.rid_to_state[obj.rid] = state
|
||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||
|
||||
async def _wait_one_response(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
created_time: Optional[float] = None,
|
||||
):
|
||||
"""Wait for the response of one request."""
|
||||
event = asyncio.Event()
|
||||
state = ReqState([], False, event, obj, created_time=created_time)
|
||||
self.rid_to_state[obj.rid] = state
|
||||
state = self.rid_to_state[obj.rid]
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -361,10 +367,8 @@ class TokenizerManager:
|
||||
for i in range(batch_size):
|
||||
tmp_obj = obj[i]
|
||||
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||
generators.append(
|
||||
self._wait_one_response(tmp_obj, request, created_time)
|
||||
)
|
||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||
generators.append(self._wait_one_response(tmp_obj, request))
|
||||
rids.append(tmp_obj.rid)
|
||||
else:
|
||||
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
||||
@@ -389,10 +393,8 @@ class TokenizerManager:
|
||||
tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
|
||||
tokenized_obj.sampling_params.max_new_tokens = 0
|
||||
tokenized_obj.stream = False
|
||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||
await self._wait_one_response(
|
||||
tmp_obj, request, created_time
|
||||
).__anext__()
|
||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||
await self._wait_one_response(tmp_obj, request).__anext__()
|
||||
|
||||
# Expand requests, assign new rids for them, and send them
|
||||
for i in range(batch_size):
|
||||
@@ -400,10 +402,8 @@ class TokenizerManager:
|
||||
tmp_obj = copy.copy(objs[i])
|
||||
tokenized_obj = copy.copy(tokenized_objs[i])
|
||||
tokenized_obj.rid = tmp_obj.regenerate_rid()
|
||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||
generators.append(
|
||||
self._wait_one_response(tmp_obj, request, created_time)
|
||||
)
|
||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||
generators.append(self._wait_one_response(tmp_obj, request))
|
||||
rids.append(tmp_obj.rid)
|
||||
|
||||
# Wait for all requests
|
||||
|
||||
Reference in New Issue
Block a user