Use CUDA event wait and synchronization instead of busy waiting (#2089)

This commit is contained in:
Lianmin Zheng
2024-11-19 00:21:46 -08:00
committed by GitHub
parent b110453802
commit b7a065eae3
6 changed files with 28 additions and 26 deletions

View File

@@ -1063,7 +1063,7 @@ class ScheduleBatch:
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
sampling_info=dataclasses.replace(self.sampling_info),
sampling_info=self.sampling_info,
)
def __str__(self):

View File

@@ -387,9 +387,6 @@ class Scheduler:
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
_ = batch.seq_lens[0].item()
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))

View File

@@ -142,12 +142,12 @@ class TpModelWorker:
def forward_batch_generation(
self,
model_worker_batch: ModelWorkerBatch,
launch_event: Optional[threading.Event] = None,
launch_done: Optional[threading.Event] = None,
):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
if launch_event:
launch_event.set()
if launch_done:
launch_done.set()
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
return logits_output, next_token_ids

View File

@@ -96,19 +96,22 @@ class TpModelWorkerClient:
@torch.no_grad()
def forward_thread_func_(self):
while True:
model_worker_batch, future_token_ids_ct = self.input_queue.get()
model_worker_batch, future_token_ids_ct, compute_info_done = (
self.input_queue.get()
)
if not model_worker_batch:
break
self.launch_event = threading.Event()
copy_event = torch.cuda.Event()
self.launch_done = threading.Event()
copy_done = torch.cuda.Event()
# Resolve future tokens in the input
input_ids = model_worker_batch.input_ids
resolve_future_token_ids(input_ids, self.future_token_ids_map)
# Run forward
compute_info_done.wait()
logits_output, next_token_ids = self.worker.forward_batch_generation(
model_worker_batch, self.launch_event
model_worker_batch, self.launch_done
)
# Update the future token ids map
@@ -133,15 +136,14 @@ class TpModelWorkerClient:
)
)
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_event.record()
copy_done.record()
self.output_queue.put((copy_event, logits_output, next_token_ids))
self.output_queue.put((copy_done, logits_output, next_token_ids))
def resolve_batch_result(self, bid: int):
copy_event, logits_output, next_token_ids = self.output_queue.get()
while not copy_event.query():
time.sleep(1e-5)
self.launch_event.wait()
copy_done, logits_output, next_token_ids = self.output_queue.get()
copy_done.synchronize()
self.launch_done.wait()
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
@@ -162,7 +164,11 @@ class TpModelWorkerClient:
model_worker_batch.sampling_info = dataclasses.replace(
model_worker_batch.sampling_info
)
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
compute_info_done = torch.cuda.Event()
compute_info_done.record()
self.input_queue.put(
(model_worker_batch, self.future_token_ids_ct, compute_info_done)
)
# Allocate output future objects
bs = len(model_worker_batch.seq_lens)