[router][grpc] Support parallel queue puts in grpc_request_manager and remove mutex for grpc_client (#11798)

This commit is contained in:
Chang Su
2025-10-17 20:49:43 -07:00
committed by GitHub
parent 6c7c92eb02
commit ca240eefb4
4 changed files with 30 additions and 27 deletions

View File

@@ -443,10 +443,11 @@ class GrpcRequestManager:
recv_obj = await self.recv_from_scheduler.recv_pyobj()
self.last_receive_tstamp = time.time()
# Check for pause
async with self.is_pause_cond:
while self.is_pause:
await self.is_pause_cond.wait()
# Check for pause (optimized: check flag before acquiring lock)
if self.is_pause:
async with self.is_pause_cond:
while self.is_pause:
await self.is_pause_cond.wait()
# Handle different output types
if isinstance(recv_obj, BatchTokenIDOutput):
@@ -531,6 +532,11 @@ class GrpcRequestManager:
async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
"""Handle batch generation output from scheduler."""
# Collect all queue.put() tasks for parallel execution
put_tasks = []
cleanup_tasks = []
now = time.time()
# Process each request in the batch
for i, rid in enumerate(batch_out.rids):
if rid not in self.rid_to_state:
@@ -544,7 +550,6 @@ class GrpcRequestManager:
continue
# Update metrics
now = time.time()
if state.first_token_time == 0.0:
state.first_token_time = now
state.last_time = now
@@ -638,7 +643,8 @@ class GrpcRequestManager:
if output_data["token_ids"]:
state.output_ids.extend(output_data["token_ids"])
await state.out_queue.put(output_data)
# Add queue.put() to parallel task list
put_tasks.append(state.out_queue.put(output_data))
# Handle completion
if output_data["finished"]:
@@ -648,12 +654,16 @@ class GrpcRequestManager:
state.event.set()
# Remove from tracking after a delay
async def cleanup():
async def cleanup(request_id):
await asyncio.sleep(5.0)
if rid in self.rid_to_state:
del self.rid_to_state[rid]
if request_id in self.rid_to_state:
del self.rid_to_state[request_id]
asyncio.create_task(cleanup())
cleanup_tasks.append(asyncio.create_task(cleanup(rid)))
# Execute all queue.put() operations in parallel
if put_tasks:
await asyncio.gather(*put_tasks, return_exceptions=True)
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
"""Handle batch embedding output from scheduler."""