From ca240eefb459e16603ec7698bf94bb4264a15c1b Mon Sep 17 00:00:00 2001 From: Chang Su Date: Fri, 17 Oct 2025 20:49:43 -0700 Subject: [PATCH] [router][grpc] Support parallel queue puts in grpc_request_manager and remove mutex for grpc_client (#11798) --- .../sglang/srt/grpc/grpc_request_manager.py | 30 ++++++++++++------- sgl-router/src/core/worker.rs | 18 +++++------ sgl-router/src/core/worker_builder.rs | 6 ++-- sgl-router/src/routers/grpc/utils.rs | 3 +- 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/grpc/grpc_request_manager.py b/python/sglang/srt/grpc/grpc_request_manager.py index 81845388b..d22e4a576 100644 --- a/python/sglang/srt/grpc/grpc_request_manager.py +++ b/python/sglang/srt/grpc/grpc_request_manager.py @@ -443,10 +443,11 @@ class GrpcRequestManager: recv_obj = await self.recv_from_scheduler.recv_pyobj() self.last_receive_tstamp = time.time() - # Check for pause - async with self.is_pause_cond: - while self.is_pause: - await self.is_pause_cond.wait() + # Check for pause (optimized: check flag before acquiring lock) + if self.is_pause: + async with self.is_pause_cond: + while self.is_pause: + await self.is_pause_cond.wait() # Handle different output types if isinstance(recv_obj, BatchTokenIDOutput): @@ -531,6 +532,11 @@ class GrpcRequestManager: async def _handle_batch_output(self, batch_out: BatchTokenIDOutput): """Handle batch generation output from scheduler.""" + # Collect all queue.put() tasks for parallel execution + put_tasks = [] + cleanup_tasks = [] + now = time.time() + # Process each request in the batch for i, rid in enumerate(batch_out.rids): if rid not in self.rid_to_state: @@ -544,7 +550,6 @@ class GrpcRequestManager: continue # Update metrics - now = time.time() if state.first_token_time == 0.0: state.first_token_time = now state.last_time = now @@ -638,7 +643,8 @@ class GrpcRequestManager: if output_data["token_ids"]: state.output_ids.extend(output_data["token_ids"]) - await state.out_queue.put(output_data) + # Add queue.put() to parallel task list + put_tasks.append(state.out_queue.put(output_data)) # Handle completion if output_data["finished"]: @@ -648,12 +654,16 @@ class GrpcRequestManager: state.event.set() # Remove from tracking after a delay - async def cleanup(): + async def cleanup(request_id): await asyncio.sleep(5.0) - if rid in self.rid_to_state: - del self.rid_to_state[rid] + if request_id in self.rid_to_state: + del self.rid_to_state[request_id] - asyncio.create_task(cleanup()) + cleanup_tasks.append(asyncio.create_task(cleanup(rid))) + + # Execute all queue.put() operations in parallel + if put_tasks: + await asyncio.gather(*put_tasks, return_exceptions=True) async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput): """Handle batch embedding output from scheduler.""" diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 2284b789d..a08379471 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -10,10 +10,7 @@ use std::{ use async_trait::async_trait; use futures; use serde_json; -use tokio::{ - sync::{Mutex, RwLock}, - time, -}; +use tokio::{sync::RwLock, time}; use super::{CircuitBreaker, WorkerError, WorkerResult}; use crate::{ @@ -232,7 +229,7 @@ pub trait Worker: Send + Sync + fmt::Debug { /// Get or create a gRPC client for this worker /// Returns None for HTTP workers, Some(client) for gRPC workers - async fn get_grpc_client(&self) -> WorkerResult>>>; + async fn get_grpc_client(&self) -> WorkerResult>>; /// Reset the gRPC client connection (for reconnection scenarios) /// No-op for HTTP workers @@ -367,7 +364,7 @@ pub struct BasicWorker { pub consecutive_successes: Arc, pub circuit_breaker: CircuitBreaker, /// Lazily initialized gRPC client for gRPC workers - pub grpc_client: Arc>>>>, + pub grpc_client: Arc>>>, } impl fmt::Debug for BasicWorker { @@ -505,7 +502,7 @@ impl Worker for BasicWorker { &self.circuit_breaker } - async fn get_grpc_client(&self) -> WorkerResult>>> { + async fn get_grpc_client(&self) -> WorkerResult>> { match self.metadata.connection_mode { ConnectionMode::Http => Ok(None), ConnectionMode::Grpc { .. } => { @@ -528,7 +525,7 @@ impl Worker for BasicWorker { ); match SglangSchedulerClient::connect(&self.metadata.url).await { Ok(client) => { - let client_arc = Arc::new(Mutex::new(client)); + let client_arc = Arc::new(client); *client_guard = Some(client_arc.clone()); tracing::info!( "Successfully connected gRPC client for worker: {}", @@ -577,8 +574,7 @@ impl Worker for BasicWorker { return Ok(false); }; - let client = grpc_client.lock().await; - match time::timeout(timeout, client.health_check()).await { + match time::timeout(timeout, grpc_client.health_check()).await { Ok(Ok(resp)) => { tracing::debug!( "gRPC health OK for {}: healthy={}", @@ -749,7 +745,7 @@ impl Worker for DPAwareWorker { format!("{}{}", self.base_url, route) } - async fn get_grpc_client(&self) -> WorkerResult>>> { + async fn get_grpc_client(&self) -> WorkerResult>> { self.base_worker.get_grpc_client().await } diff --git a/sgl-router/src/core/worker_builder.rs b/sgl-router/src/core/worker_builder.rs index fd30c4bd8..ebd9f7d16 100644 --- a/sgl-router/src/core/worker_builder.rs +++ b/sgl-router/src/core/worker_builder.rs @@ -104,7 +104,7 @@ impl BasicWorkerBuilder { Arc, }; - use tokio::sync::{Mutex, RwLock}; + use tokio::sync::RwLock; let bootstrap_host = match url::Url::parse(&self.url) { Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(), @@ -145,9 +145,7 @@ impl BasicWorkerBuilder { bootstrap_port, }; - let grpc_client = Arc::new(RwLock::new( - self.grpc_client.map(|client| Arc::new(Mutex::new(client))), - )); + let grpc_client = Arc::new(RwLock::new(self.grpc_client.map(Arc::new))); BasicWorker { metadata, diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs index e2c06b2a8..e39c9ff9e 100644 --- a/sgl-router/src/routers/grpc/utils.rs +++ b/sgl-router/src/routers/grpc/utils.rs @@ -42,8 +42,7 @@ pub async fn get_grpc_client_from_worker( .map_err(|e| internal_error_message(format!("Failed to get gRPC client: {}", e)))? .ok_or_else(|| internal_error_static("Selected worker is not configured for gRPC"))?; - let client = client_arc.lock().await.clone(); - Ok(client) + Ok((*client_arc).clone()) } /// Process tool call arguments in messages