Make token mapping non-blocking in the overlapped mode (#1740)
This commit is contained in:
@@ -86,16 +86,15 @@ class TpModelWorkerClient:
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_thread_func_(self):
|
def forward_thread_func_(self):
|
||||||
while True:
|
while True:
|
||||||
tic1 = time.time()
|
|
||||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||||
|
|
||||||
# Resolve future tokens in the input
|
# Resolve future tokens in the input
|
||||||
tic2 = time.time()
|
input_ids = model_worker_batch.input_ids
|
||||||
resolved_input_ids = model_worker_batch.input_ids
|
input_ids[:] = torch.where(
|
||||||
future_mask = resolved_input_ids < 0
|
input_ids < 0,
|
||||||
resolved_input_ids[future_mask] = self.future_token_ids_map[
|
self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
|
||||||
-resolved_input_ids[future_mask]
|
input_ids,
|
||||||
]
|
)
|
||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
||||||
@@ -119,15 +118,6 @@ class TpModelWorkerClient:
|
|||||||
assert logits_output.next_token_logprobs is None, "Not supported"
|
assert logits_output.next_token_logprobs is None, "Not supported"
|
||||||
self.output_queue.put((None, next_token_ids))
|
self.output_queue.put((None, next_token_ids))
|
||||||
|
|
||||||
if False:
|
|
||||||
tic3 = time.time()
|
|
||||||
self.acc_time_with_waiting += tic3 - tic1
|
|
||||||
self.acc_time_without_waiting += tic3 - tic2
|
|
||||||
if self.forward_queue.qsize() == 0:
|
|
||||||
logger.info(
|
|
||||||
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def resulve_batch_result(self, bid: int):
|
def resulve_batch_result(self, bid: int):
|
||||||
logits_output, next_token_ids = self.output_queue.get()
|
logits_output, next_token_ids = self.output_queue.get()
|
||||||
return logits_output, next_token_ids
|
return logits_output, next_token_ids
|
||||||
|
|||||||
@@ -1 +1,2 @@
|
|||||||
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
|
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
|
||||||
|
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
|
||||||
|
|||||||
Reference in New Issue
Block a user