Use cuda event wait and synchronization instead of busy waiting (#2089)
This commit is contained in:
@@ -96,19 +96,22 @@ class TpModelWorkerClient:
|
||||
@torch.no_grad()
|
||||
def forward_thread_func_(self):
|
||||
while True:
|
||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||
model_worker_batch, future_token_ids_ct, compute_info_done = (
|
||||
self.input_queue.get()
|
||||
)
|
||||
if not model_worker_batch:
|
||||
break
|
||||
self.launch_event = threading.Event()
|
||||
copy_event = torch.cuda.Event()
|
||||
self.launch_done = threading.Event()
|
||||
copy_done = torch.cuda.Event()
|
||||
|
||||
# Resolve future tokens in the input
|
||||
input_ids = model_worker_batch.input_ids
|
||||
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
||||
|
||||
# Run forward
|
||||
compute_info_done.wait()
|
||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
||||
model_worker_batch, self.launch_event
|
||||
model_worker_batch, self.launch_done
|
||||
)
|
||||
|
||||
# Update the future token ids map
|
||||
@@ -133,15 +136,14 @@ class TpModelWorkerClient:
|
||||
)
|
||||
)
|
||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||
copy_event.record()
|
||||
copy_done.record()
|
||||
|
||||
self.output_queue.put((copy_event, logits_output, next_token_ids))
|
||||
self.output_queue.put((copy_done, logits_output, next_token_ids))
|
||||
|
||||
def resolve_batch_result(self, bid: int):
|
||||
copy_event, logits_output, next_token_ids = self.output_queue.get()
|
||||
while not copy_event.query():
|
||||
time.sleep(1e-5)
|
||||
self.launch_event.wait()
|
||||
copy_done, logits_output, next_token_ids = self.output_queue.get()
|
||||
copy_done.synchronize()
|
||||
self.launch_done.wait()
|
||||
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
logits_output.next_token_logprobs = (
|
||||
@@ -162,7 +164,11 @@ class TpModelWorkerClient:
|
||||
model_worker_batch.sampling_info = dataclasses.replace(
|
||||
model_worker_batch.sampling_info
|
||||
)
|
||||
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
|
||||
compute_info_done = torch.cuda.Event()
|
||||
compute_info_done.record()
|
||||
self.input_queue.put(
|
||||
(model_worker_batch, self.future_token_ids_ct, compute_info_done)
|
||||
)
|
||||
|
||||
# Allocate output future objects
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
|
||||
Reference in New Issue
Block a user