[Fix] Fix the case where prompt_len = 0 (#1593)
This commit is contained in:
@@ -626,9 +626,11 @@ class Scheduler:
             else:
                 logits_output = None
                 if self.tokenizer is not None:
-                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+                    next_token_ids = torch.full(
+                        (batch.batch_size(),), self.tokenizer.eos_token_id
+                    )
                 else:
-                    next_token_ids = [0] * len(batch.reqs)
+                    next_token_ids = torch.full((batch.batch_size(),), 0)
             return logits_output, next_token_ids
         else: # embedding or reward model
             assert batch.extend_num_tokens != 0
@@ -526,7 +526,7 @@ class TokenizerManager:
             async with self.model_update_lock:
                 # wait for the previous generation requests to finish
                 while len(self.rid_to_state) > 0:
-                    await asyncio.sleep(0)
+                    await asyncio.sleep(0.001)
                 self.send_to_scheduler.send_pyobj(obj)
                 self.model_update_result = asyncio.Future()
                 result = await self.model_update_result
@@ -624,6 +624,6 @@ def broadcast_pyobj(
         tensor_data = torch.empty(size, dtype=torch.uint8)
         dist.broadcast(tensor_data, src=0, group=dist_group)
 
-        serialized_data = bytes(tensor_data.tolist())
+        serialized_data = bytes(tensor_data.cpu().numpy())
         data = pickle.loads(serialized_data)
         return data
Reference in New Issue
Block a user