Simplify batch result resolution (#1735)

Lianmin Zheng
2024-10-20 19:47:14 -07:00
committed by GitHub
parent e12358dc91
commit b121bc03a3
5 changed files with 64 additions and 90 deletions
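The change replaces the per-batch event/dict bookkeeping (future_event_map, future_token_ids_output, future_logits_output_dict) with a pair of queues plus a single ring-buffer map of future token ids: negative ids act as placeholders for tokens the background forward thread has not produced yet. A minimal standalone sketch of that placeholder scheme, with hypothetical sizes and values (the real map lives on the worker's CUDA device):

    import torch

    # Hypothetical sizes; the real code derives them from the worker.
    max_running_requests = 4
    future_token_ids_limit = max_running_requests * 3
    future_token_ids_map = torch.empty(
        (max_running_requests * 5,), dtype=torch.int32
    )

    # Scheduler side: hand out bs negative placeholder ids for a new batch.
    ct, bs = 0, 3
    future_next_token_ids = torch.arange(-(ct + bs), -ct, dtype=torch.int32)
    print(future_next_token_ids.tolist())  # [-3, -2, -1]
    ct = (ct + bs) % future_token_ids_limit

    # Forward-thread side: after sampling, negate the placeholders to get
    # map slots and write the real token ids into them.
    next_token_ids = torch.tensor([11, 22, 33], dtype=torch.int32)
    future_token_ids_map[(-future_next_token_ids).long()] = next_token_ids
    assert future_token_ids_map[3].item() == 11  # placeholder -3 resolves to 11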


@@ -48,19 +48,16 @@ class TpModelWorkerClient:
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device

-        # Create future mappings
-        self.future_logits_output_dict = dict()
-        self.future_logits_output_ct = 0
+        # Init future mappings
         self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
             (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
         )
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_output = dict()

         # Launch a thread
-        self.future_event_map = dict()
-        self.forward_queue = Queue()
+        self.input_queue = Queue()
+        self.output_queue = Queue()
         self.forward_stream = torch.cuda.Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
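With this hunk the constructor wires up a plain producer/consumer pipeline: the scheduler pushes work onto input_queue, the dedicated forward thread drains it, and finished results come back on output_queue. A stripped-down sketch of the same pattern, using a hypothetical OverlappedWorker class (the real thread additionally runs on its own CUDA stream):

    import threading
    from queue import Queue

    class OverlappedWorker:
        def __init__(self):
            self.input_queue = Queue()
            self.output_queue = Queue()
            self.thread = threading.Thread(target=self._loop, daemon=True)
            self.thread.start()

        def _loop(self):
            while True:
                batch = self.input_queue.get()   # blocks until work arrives
                result = [x + 1 for x in batch]  # stand-in for the GPU forward pass
                self.output_queue.put(result)

        def submit(self, batch):
            self.input_queue.put(batch)          # returns immediately

        def resolve(self):
            return self.output_queue.get()       # blocks until a result is ready

    worker = OverlappedWorker()
    worker.submit([1, 2, 3])
    print(worker.resolve())  # [2, 3, 4]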
@@ -90,9 +87,7 @@ class TpModelWorkerClient:
     def forward_thread_func_(self):
         while True:
             tic1 = time.time()
-            model_worker_batch, future_logits_output, future_next_token_ids = (
-                self.forward_queue.get()
-            )
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()

             # Resolve future tokens in the input
             tic2 = time.time()
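Note that each queue item now carries a snapshot of future_token_ids_ct taken at submission time, so the forward thread can recompute exactly which placeholder range was handed out for this batch even after the shared counter has advanced. A toy illustration of that snapshot idea (names and values hypothetical):

    from queue import Queue

    input_queue: Queue = Queue()
    ct = 0

    # Scheduler submits two batches back to back; each carries its own snapshot.
    for bs in (3, 2):
        input_queue.put(({"bs": bs}, ct))  # stand-in for (model_worker_batch, ct)
        ct += bs

    # Forward thread: the snapshot, not the live counter, defines the range.
    while not input_queue.empty():
        batch, snapshot = input_queue.get()
        lo, hi = -(snapshot + batch["bs"]), -snapshot
        print(f"placeholders for this batch: {list(range(lo, hi))}")
    # placeholders for this batch: [-3, -2, -1]
    # placeholders for this batch: [-5, -4]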
@@ -107,17 +102,22 @@
                 model_worker_batch
             )

-            # Set future values
-            if model_worker_batch.return_logprob:
-                self.future_logits_output_dict[future_logits_output] = logits_output
-
+            # Update the future token ids map
+            bs = len(model_worker_batch.seq_lens)
+            future_next_token_ids = torch.arange(
+                -(future_token_ids_ct + bs),
+                -(future_token_ids_ct),
+                dtype=torch.int32,
+                device=self.device,
+            )
             self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
                 torch.int32
             )
-            self.future_token_ids_output[model_worker_batch.bid] = (
-                next_token_ids.tolist()
-            )
-            self.future_event_map[model_worker_batch.bid].set()
+
+            # Set the result
+            next_token_ids = next_token_ids.tolist()
+            assert logits_output.next_token_logprobs is None, "Not supported"
+            self.output_queue.put((None, next_token_ids))

             if False:
                 tic3 = time.time()
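Here the forward thread rebuilds the placeholder range from the counter snapshot, writes the freshly sampled ids into future_token_ids_map, and pushes a plain Python list to output_queue; logits are not forwarded yet, hence the assert guard. The counterpart step, referenced by the `# Resolve future tokens in the input` comment above, is to swap any negative ids in a later batch's input for their entries in the map. A hedged sketch of that lookup (tensor values are made up):

    import torch

    # Map slots 1..3 have been filled by an earlier batch.
    future_token_ids_map = torch.zeros(16, dtype=torch.int32)
    future_token_ids_map[1:4] = torch.tensor([33, 22, 11], dtype=torch.int32)

    input_ids = torch.tensor([-3, 105, -1])  # -3 and -1 are unresolved placeholders
    mask = input_ids < 0
    input_ids[mask] = future_token_ids_map[-input_ids[mask]].long()
    print(input_ids.tolist())  # [11, 105, 33]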
@@ -128,38 +128,26 @@
                     f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
                 )

-    def resolve_future_token_ids(self, bid: int):
-        self.future_event_map[bid].wait()
-        ret = self.future_token_ids_output[bid]
-        del self.future_event_map[bid]
-        return ret
-
-    def resolve_future_logits_output(self, future_obj):
-        return self.future_logits_output_dict.pop(future_obj)
+    def resulve_batch_result(self, bid: int):
+        logits_output, next_token_ids = self.output_queue.get()
+        return logits_output, next_token_ids

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
-        # Allocate output future objects
-        future_logits_output = self.future_logits_output_ct
-        self.future_logits_output_ct += 1
+        # Push a new batch to the queue
+        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))

+        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        with torch.cuda.stream(self.forward_stream):
-            future_next_token_ids = -torch.arange(
-                self.future_token_ids_ct + 1,
-                self.future_token_ids_ct + 1 + bs,
-                dtype=torch.int32,
-                device=self.device,
-            )
-            self.future_token_ids_ct = (
-                self.future_token_ids_ct + bs
-            ) % self.future_token_ids_limit
-        ret = future_logits_output, future_next_token_ids
-
-        self.future_event_map[model_worker_batch.bid] = threading.Event()
-        self.forward_queue.put(
-            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
-        )
-        return ret
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + bs),
+            -(self.future_token_ids_ct),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.future_token_ids_ct = (
+            self.future_token_ids_ct + bs
+        ) % self.future_token_ids_limit
+        return None, future_next_token_ids

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
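One arithmetic detail in the new forward_batch_generation: the counter wraps at max_running_requests * 3 while the map holds max_running_requests * 5 slots, so slot indices (which can reach the pre-wrap counter plus a full batch) stay inside the map; the extra headroom presumably keeps placeholders of in-flight batches from colliding. A small sketch of the wraparound arithmetic (sizes hypothetical):

    import torch

    max_running_requests = 4                  # hypothetical
    future_token_ids_limit = max_running_requests * 3
    ct = 10                                   # counter close to the wrap point
    for bs in (3, 4):                         # two successive batches
        placeholders = torch.arange(-(ct + bs), -ct, dtype=torch.int32)
        print(placeholders.tolist(), "-> slots", (-placeholders).tolist())
        ct = (ct + bs) % future_token_ids_limit

    # [-13, -12, -11] -> slots [13, 12, 11]   (13 < 4 * 5, still inside the map)
    # [-5, -4, -3, -2] -> slots [5, 4, 3, 2]  (counter wrapped from 13 to 1)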