fix black in pre-commit (#1940)

This commit is contained in:
Chayenne
2024-11-07 15:42:47 -08:00
committed by GitHub
parent dca87ec348
commit c77c1e05ba
29 changed files with 641 additions and 508 deletions

View File

@@ -215,7 +215,7 @@ class TokenizerManager:
logprob_start_len,
top_logprobs_num,
obj.stream,
obj.lora_path
obj.lora_path,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
@@ -290,7 +290,9 @@ class TokenizerManager:
# Tokenize all requests
objs = [obj[i] for i in range(batch_size)]
tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs))
tokenized_objs = await asyncio.gather(
*(self._tokenize_one_request(obj) for obj in objs)
)
# Cache the common prefix for parallel sampling
for i in range(batch_size):
@@ -322,7 +324,9 @@ class TokenizerManager:
rid_to_index = {rid: i for i, rid in enumerate(rids)}
task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
while task_map:
done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
done, _ = await asyncio.wait(
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in done:
gen = task_map.pop(task)
@@ -367,7 +371,7 @@ class TokenizerManager:
if self.server_args.dp_size == 1:
res = await self.mem_pool_size
return res.size
else: # self.server_args.dp_size > 1
else: # self.server_args.dp_size > 1
self.mem_pool_size_tmp = []
res = await self.mem_pool_size
ret = [r.size for r in res]
@@ -399,7 +403,7 @@ class TokenizerManager:
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
return result.success, result.message
else: # self.server_args.dp_size > 1
else: # self.server_args.dp_size > 1
self.model_update_tmp = []
result = await self.model_update_result
@@ -470,7 +474,7 @@ class TokenizerManager:
if isinstance(recv_obj, UpdateWeightReqOutput):
if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1
else: # self.server_args.dp_size > 1
self.model_update_tmp.append(recv_obj)
# set future once all results are received
if len(self.model_update_tmp) == self.server_args.dp_size:
@@ -479,7 +483,7 @@ class TokenizerManager:
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
if self.server_args.dp_size == 1:
self.mem_pool_size.set_result(recv_obj)
else: # self.server_args.dp_size > 1
else: # self.server_args.dp_size > 1
self.mem_pool_size_tmp.append(recv_obj)
# set future once all results are received
if len(self.mem_pool_size_tmp) == self.server_args.dp_size: