fix black in pre-commit (#1940)
This commit is contained in:
@@ -215,7 +215,7 @@ class TokenizerManager:
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
obj.stream,
|
||||
obj.lora_path
|
||||
obj.lora_path,
|
||||
)
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
@@ -290,7 +290,9 @@ class TokenizerManager:
|
||||
|
||||
# Tokenize all requests
|
||||
objs = [obj[i] for i in range(batch_size)]
|
||||
tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs))
|
||||
tokenized_objs = await asyncio.gather(
|
||||
*(self._tokenize_one_request(obj) for obj in objs)
|
||||
)
|
||||
|
||||
# Cache the common prefix for parallel sampling
|
||||
for i in range(batch_size):
|
||||
@@ -322,7 +324,9 @@ class TokenizerManager:
|
||||
rid_to_index = {rid: i for i, rid in enumerate(rids)}
|
||||
task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
|
||||
while task_map:
|
||||
done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
|
||||
done, _ = await asyncio.wait(
|
||||
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
for task in done:
|
||||
gen = task_map.pop(task)
|
||||
@@ -367,7 +371,7 @@ class TokenizerManager:
|
||||
if self.server_args.dp_size == 1:
|
||||
res = await self.mem_pool_size
|
||||
return res.size
|
||||
else: # self.server_args.dp_size > 1
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.mem_pool_size_tmp = []
|
||||
res = await self.mem_pool_size
|
||||
ret = [r.size for r in res]
|
||||
@@ -399,7 +403,7 @@ class TokenizerManager:
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
return result.success, result.message
|
||||
else: # self.server_args.dp_size > 1
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.model_update_tmp = []
|
||||
result = await self.model_update_result
|
||||
|
||||
@@ -470,7 +474,7 @@ class TokenizerManager:
|
||||
if isinstance(recv_obj, UpdateWeightReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.model_update_result.set_result(recv_obj)
|
||||
else: # self.server_args.dp_size > 1
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.model_update_tmp.append(recv_obj)
|
||||
# set future once all the results are received
|
||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||
@@ -479,7 +483,7 @@ class TokenizerManager:
|
||||
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.mem_pool_size.set_result(recv_obj)
|
||||
else: # self.server_args.dp_size > 1
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.mem_pool_size_tmp.append(recv_obj)
|
||||
# set future once all the results are received
|
||||
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
|
||||
|
||||
Reference in New Issue
Block a user