[router][grpc] Support parallel queue puts in grpc_request_manager and remove mutex for grpc_client (#11798)
This commit is contained in:
@@ -443,10 +443,11 @@ class GrpcRequestManager:
|
||||
recv_obj = await self.recv_from_scheduler.recv_pyobj()
|
||||
self.last_receive_tstamp = time.time()
|
||||
|
||||
# Check for pause
|
||||
async with self.is_pause_cond:
|
||||
while self.is_pause:
|
||||
await self.is_pause_cond.wait()
|
||||
# Check for pause (optimized: check flag before acquiring lock)
|
||||
if self.is_pause:
|
||||
async with self.is_pause_cond:
|
||||
while self.is_pause:
|
||||
await self.is_pause_cond.wait()
|
||||
|
||||
# Handle different output types
|
||||
if isinstance(recv_obj, BatchTokenIDOutput):
|
||||
@@ -531,6 +532,11 @@ class GrpcRequestManager:
|
||||
|
||||
async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
|
||||
"""Handle batch generation output from scheduler."""
|
||||
# Collect all queue.put() tasks for parallel execution
|
||||
put_tasks = []
|
||||
cleanup_tasks = []
|
||||
now = time.time()
|
||||
|
||||
# Process each request in the batch
|
||||
for i, rid in enumerate(batch_out.rids):
|
||||
if rid not in self.rid_to_state:
|
||||
@@ -544,7 +550,6 @@ class GrpcRequestManager:
|
||||
continue
|
||||
|
||||
# Update metrics
|
||||
now = time.time()
|
||||
if state.first_token_time == 0.0:
|
||||
state.first_token_time = now
|
||||
state.last_time = now
|
||||
@@ -638,7 +643,8 @@ class GrpcRequestManager:
|
||||
if output_data["token_ids"]:
|
||||
state.output_ids.extend(output_data["token_ids"])
|
||||
|
||||
await state.out_queue.put(output_data)
|
||||
# Add queue.put() to parallel task list
|
||||
put_tasks.append(state.out_queue.put(output_data))
|
||||
|
||||
# Handle completion
|
||||
if output_data["finished"]:
|
||||
@@ -648,12 +654,16 @@ class GrpcRequestManager:
|
||||
state.event.set()
|
||||
|
||||
# Remove from tracking after a delay
|
||||
async def cleanup():
|
||||
async def cleanup(request_id):
|
||||
await asyncio.sleep(5.0)
|
||||
if rid in self.rid_to_state:
|
||||
del self.rid_to_state[rid]
|
||||
if request_id in self.rid_to_state:
|
||||
del self.rid_to_state[request_id]
|
||||
|
||||
asyncio.create_task(cleanup())
|
||||
cleanup_tasks.append(asyncio.create_task(cleanup(rid)))
|
||||
|
||||
# Execute all queue.put() operations in parallel
|
||||
if put_tasks:
|
||||
await asyncio.gather(*put_tasks, return_exceptions=True)
|
||||
|
||||
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
|
||||
"""Handle batch embedding output from scheduler."""
|
||||
|
||||
Reference in New Issue
Block a user