Make constrained decoding work for overlap scheduler (#2095)

This commit is contained in:
Lianmin Zheng
2024-11-19 15:04:43 -08:00
committed by GitHub
parent 55bd97f3e5
commit ffd20fcd03
8 changed files with 119 additions and 95 deletions

View File

@@ -18,7 +18,6 @@ limitations under the License.
import dataclasses
import logging
import threading
import time
from queue import Queue
from typing import Optional
@@ -96,9 +95,7 @@ class TpModelWorkerClient:
@torch.no_grad()
def forward_thread_func_(self):
while True:
model_worker_batch, future_token_ids_ct, compute_info_done = (
self.input_queue.get()
)
model_worker_batch, future_token_ids_ct = self.input_queue.get()
if not model_worker_batch:
break
self.launch_done = threading.Event()
@@ -109,7 +106,6 @@ class TpModelWorkerClient:
resolve_future_token_ids(input_ids, self.future_token_ids_map)
# Run forward
compute_info_done.wait()
logits_output, next_token_ids = self.worker.forward_batch_generation(
model_worker_batch, self.launch_done
)
@@ -160,15 +156,16 @@ class TpModelWorkerClient:
return logits_output, next_token_ids
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
# A cuda stream sync here to avoid the cuda illegal memory access error.
_ = model_worker_batch.seq_lens[0].item()
# Push a new batch to the queue
model_worker_batch.sampling_info = dataclasses.replace(
model_worker_batch.sampling_info
)
compute_info_done = torch.cuda.Event()
compute_info_done.record()
self.input_queue.put(
(model_worker_batch, self.future_token_ids_ct, compute_info_done)
model_worker_batch.sampling_info,
sampling_info_done=threading.Event(),
)
self.cur_sampling_info = model_worker_batch.sampling_info
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
# Allocate output future objects
bs = len(model_worker_batch.seq_lens)