Simplify batch result resolution (#1735)
This commit is contained in:
@@ -48,19 +48,16 @@ class TpModelWorkerClient:
|
||||
self.max_running_requests = self.worker.max_running_requests
|
||||
self.device = self.worker.device
|
||||
|
||||
# Create future mappings
|
||||
self.future_logits_output_dict = dict()
|
||||
self.future_logits_output_ct = 0
|
||||
# Init future mappings
|
||||
self.future_token_ids_ct = 0
|
||||
self.future_token_ids_limit = self.max_running_requests * 3
|
||||
self.future_token_ids_map = torch.empty(
|
||||
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.future_token_ids_limit = self.max_running_requests * 3
|
||||
self.future_token_ids_output = dict()
|
||||
|
||||
# Launch a thread
|
||||
self.future_event_map = dict()
|
||||
self.forward_queue = Queue()
|
||||
self.input_queue = Queue()
|
||||
self.output_queue = Queue()
|
||||
self.forward_stream = torch.cuda.Stream()
|
||||
self.forward_thread = threading.Thread(
|
||||
target=self.forward_thread_func,
|
||||
@@ -90,9 +87,7 @@ class TpModelWorkerClient:
|
||||
def forward_thread_func_(self):
|
||||
while True:
|
||||
tic1 = time.time()
|
||||
model_worker_batch, future_logits_output, future_next_token_ids = (
|
||||
self.forward_queue.get()
|
||||
)
|
||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||
|
||||
# Resolve future tokens in the input
|
||||
tic2 = time.time()
|
||||
@@ -107,17 +102,22 @@ class TpModelWorkerClient:
|
||||
model_worker_batch
|
||||
)
|
||||
|
||||
# Set future values
|
||||
if model_worker_batch.return_logprob:
|
||||
self.future_logits_output_dict[future_logits_output] = logits_output
|
||||
|
||||
# Update the future token ids map
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
future_next_token_ids = torch.arange(
|
||||
-(future_token_ids_ct + bs),
|
||||
-(future_token_ids_ct),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
|
||||
torch.int32
|
||||
)
|
||||
self.future_token_ids_output[model_worker_batch.bid] = (
|
||||
next_token_ids.tolist()
|
||||
)
|
||||
self.future_event_map[model_worker_batch.bid].set()
|
||||
|
||||
# Set the result
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
assert logits_output.next_token_logprobs is None, "Not supported"
|
||||
self.output_queue.put((None, next_token_ids))
|
||||
|
||||
if False:
|
||||
tic3 = time.time()
|
||||
@@ -128,38 +128,26 @@ class TpModelWorkerClient:
|
||||
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
|
||||
)
|
||||
|
||||
def resolve_future_token_ids(self, bid: int):
|
||||
self.future_event_map[bid].wait()
|
||||
ret = self.future_token_ids_output[bid]
|
||||
del self.future_event_map[bid]
|
||||
return ret
|
||||
|
||||
def resolve_future_logits_output(self, future_obj):
|
||||
return self.future_logits_output_dict.pop(future_obj)
|
||||
def resulve_batch_result(self, bid: int):
|
||||
logits_output, next_token_ids = self.output_queue.get()
|
||||
return logits_output, next_token_ids
|
||||
|
||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||
# Allocate output future objects
|
||||
future_logits_output = self.future_logits_output_ct
|
||||
self.future_logits_output_ct += 1
|
||||
# Push a new batch to the queue
|
||||
self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
|
||||
|
||||
# Allocate output future objects
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
with torch.cuda.stream(self.forward_stream):
|
||||
future_next_token_ids = -torch.arange(
|
||||
self.future_token_ids_ct + 1,
|
||||
self.future_token_ids_ct + 1 + bs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
future_next_token_ids = torch.arange(
|
||||
-(self.future_token_ids_ct + bs),
|
||||
-(self.future_token_ids_ct),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.future_token_ids_ct = (
|
||||
self.future_token_ids_ct + bs
|
||||
) % self.future_token_ids_limit
|
||||
ret = future_logits_output, future_next_token_ids
|
||||
|
||||
self.future_event_map[model_worker_batch.bid] = threading.Event()
|
||||
self.forward_queue.put(
|
||||
(model_worker_batch.copy(), future_logits_output, future_next_token_ids)
|
||||
)
|
||||
return ret
|
||||
return None, future_next_token_ids
|
||||
|
||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
|
||||
Reference in New Issue
Block a user