Faster overlap mode scheduler (#1738)
This commit is contained in:
@@ -55,7 +55,7 @@ class TpModelWorkerClient:
|
||||
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# Launch a thread
|
||||
# Launch threads
|
||||
self.input_queue = Queue()
|
||||
self.output_queue = Queue()
|
||||
self.forward_stream = torch.cuda.Stream()
|
||||
@@ -64,6 +64,12 @@ class TpModelWorkerClient:
|
||||
)
|
||||
self.forward_thread.start()
|
||||
|
||||
self.copy_queue = Queue()
|
||||
self.copy_thread = threading.Thread(
|
||||
target=self.copy_thread_func,
|
||||
)
|
||||
self.copy_thread.start()
|
||||
|
||||
def get_worker_info(self):
|
||||
return self.worker.get_worker_info()
|
||||
|
||||
@@ -86,7 +92,10 @@ class TpModelWorkerClient:
|
||||
@torch.inference_mode()
|
||||
def forward_thread_func_(self):
|
||||
while True:
|
||||
self.has_inflight_batch = False
|
||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||
self.has_inflight_batch = True
|
||||
self.launch_event = threading.Event()
|
||||
|
||||
# Resolve future tokens in the input
|
||||
input_ids = model_worker_batch.input_ids
|
||||
@@ -100,6 +109,7 @@ class TpModelWorkerClient:
|
||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
self.launch_event.set()
|
||||
|
||||
# Update the future token ids map
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
@@ -113,13 +123,23 @@ class TpModelWorkerClient:
|
||||
torch.int32
|
||||
)
|
||||
|
||||
# Set the result
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
assert logits_output.next_token_logprobs is None, "Not supported"
|
||||
self.output_queue.put((None, next_token_ids))
|
||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||
copy_event = torch.cuda.Event(blocking=True)
|
||||
copy_event.record()
|
||||
self.copy_queue.put((copy_event, next_token_ids))
|
||||
|
||||
def copy_thread_func(self):
    """Forward finished token-id copies from the copy queue to the output queue.

    Runs forever. Each item taken from ``self.copy_queue`` is a
    ``(copy_event, next_token_ids)`` pair. The loop polls
    ``copy_event.query()`` until it reports completion and then publishes
    ``(None, next_token_ids.tolist())`` on ``self.output_queue``.
    """
    while True:
        pending_event, token_ids = self.copy_queue.get()

        # Poll instead of blocking on the event, sleeping briefly between
        # checks.
        while True:
            if pending_event.query():
                break
            time.sleep(1e-5)

        token_list = token_ids.tolist()
        self.output_queue.put((None, token_list))
|
||||
|
||||
def resulve_batch_result(self, bid: int):
    """Return the next batch result taken from the output queue.

    Blocks until an item is available on ``self.output_queue``. If
    ``self.has_inflight_batch`` is truthy at that point, it additionally
    waits on ``self.launch_event`` before returning.

    Args:
        bid: Batch id; not used by this method.

    Returns:
        The ``(logits_output, next_token_ids)`` pair taken from the queue.
    """
    logits, token_ids = self.output_queue.get()

    if not self.has_inflight_batch:
        return logits, token_ids

    # Wait until the batch is launched
    self.launch_event.wait()
    return logits, token_ids
|
||||
|
||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||
|
||||
Reference in New Issue
Block a user