Enable overlap scheduler by default for the triton attention backend (#2105)

This commit is contained in:
Lianmin Zheng
2024-11-20 02:58:35 -08:00
committed by GitHub
parent 56a347f7d3
commit 722530fa01
6 changed files with 21 additions and 24 deletions

View File

@@ -94,10 +94,21 @@ class TpModelWorkerClient:
@torch.no_grad()
def forward_thread_func_(self):
batch_pt = 0
batch_lists = [None] * 2
while True:
model_worker_batch, future_token_ids_ct = self.input_queue.get()
if not model_worker_batch:
break
# Keep a reference to model_worker_batch by storing it in a list.
# Otherwise, the tensor members of model_worker_batch may be freed
# by PyTorch, causing CUDA illegal memory access errors.
batch_lists[batch_pt % 2] = model_worker_batch
batch_pt += 1
# Create events
self.launch_done = threading.Event()
copy_done = torch.cuda.Event()