Fix logprob in the overlapped mode (#1795)

This commit is contained in:
Lianmin Zheng
2024-10-25 11:06:57 -07:00
committed by GitHub
parent c555ce2ca2
commit e646c5901e
7 changed files with 62 additions and 29 deletions

View File

@@ -103,6 +103,8 @@ class TpModelWorkerClient:
while True:
self.has_inflight_batch = False
model_worker_batch, future_token_ids_ct = self.input_queue.get()
if not model_worker_batch:
break
self.has_inflight_batch = True
self.launch_event = threading.Event()
@@ -122,19 +124,48 @@ class TpModelWorkerClient:
] = next_token_ids
# Copy results to the CPU
if model_worker_batch.return_logprob:
logits_output.next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].to("cpu", non_blocking=True)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.to(
"cpu", non_blocking=True
)
)
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_event = torch.cuda.Event(blocking=True)
copy_event.record()
self.launch_event.set()
self.copy_queue.put((copy_event, next_token_ids))
self.copy_queue.put((copy_event, logits_output, next_token_ids))
def copy_thread_func(self):
while True:
copy_event, next_token_ids = self.copy_queue.get()
copy_event, logits_output, next_token_ids = self.copy_queue.get()
if not copy_event:
break
while not copy_event.query():
time.sleep(1e-5)
self.output_queue.put((None, next_token_ids.tolist()))
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)
self.output_queue.put((logits_output, next_token_ids.tolist()))
def resulve_batch_result(self, bid: int):
logits_output, next_token_ids = self.output_queue.get()
@@ -172,3 +203,7 @@ class TpModelWorkerClient:
recv_req.model_path, recv_req.load_format
)
return success, message
def __delete__(self):
self.input_queue.put((None, None))
self.copy_queue.put((None, None, None))