[router][grpc] Support parallel queue puts in grpc_request_manager and remove mutex for grpc_client (#11798)
This commit is contained in:
@@ -443,10 +443,11 @@ class GrpcRequestManager:
|
|||||||
recv_obj = await self.recv_from_scheduler.recv_pyobj()
|
recv_obj = await self.recv_from_scheduler.recv_pyobj()
|
||||||
self.last_receive_tstamp = time.time()
|
self.last_receive_tstamp = time.time()
|
||||||
|
|
||||||
# Check for pause
|
# Check for pause (optimized: check flag before acquiring lock)
|
||||||
async with self.is_pause_cond:
|
if self.is_pause:
|
||||||
while self.is_pause:
|
async with self.is_pause_cond:
|
||||||
await self.is_pause_cond.wait()
|
while self.is_pause:
|
||||||
|
await self.is_pause_cond.wait()
|
||||||
|
|
||||||
# Handle different output types
|
# Handle different output types
|
||||||
if isinstance(recv_obj, BatchTokenIDOutput):
|
if isinstance(recv_obj, BatchTokenIDOutput):
|
||||||
@@ -531,6 +532,11 @@ class GrpcRequestManager:
|
|||||||
|
|
||||||
async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
|
async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
|
||||||
"""Handle batch generation output from scheduler."""
|
"""Handle batch generation output from scheduler."""
|
||||||
|
# Collect all queue.put() tasks for parallel execution
|
||||||
|
put_tasks = []
|
||||||
|
cleanup_tasks = []
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
# Process each request in the batch
|
# Process each request in the batch
|
||||||
for i, rid in enumerate(batch_out.rids):
|
for i, rid in enumerate(batch_out.rids):
|
||||||
if rid not in self.rid_to_state:
|
if rid not in self.rid_to_state:
|
||||||
@@ -544,7 +550,6 @@ class GrpcRequestManager:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Update metrics
|
# Update metrics
|
||||||
now = time.time()
|
|
||||||
if state.first_token_time == 0.0:
|
if state.first_token_time == 0.0:
|
||||||
state.first_token_time = now
|
state.first_token_time = now
|
||||||
state.last_time = now
|
state.last_time = now
|
||||||
@@ -638,7 +643,8 @@ class GrpcRequestManager:
|
|||||||
if output_data["token_ids"]:
|
if output_data["token_ids"]:
|
||||||
state.output_ids.extend(output_data["token_ids"])
|
state.output_ids.extend(output_data["token_ids"])
|
||||||
|
|
||||||
await state.out_queue.put(output_data)
|
# Add queue.put() to parallel task list
|
||||||
|
put_tasks.append(state.out_queue.put(output_data))
|
||||||
|
|
||||||
# Handle completion
|
# Handle completion
|
||||||
if output_data["finished"]:
|
if output_data["finished"]:
|
||||||
@@ -648,12 +654,16 @@ class GrpcRequestManager:
|
|||||||
state.event.set()
|
state.event.set()
|
||||||
|
|
||||||
# Remove from tracking after a delay
|
# Remove from tracking after a delay
|
||||||
async def cleanup():
|
async def cleanup(request_id):
|
||||||
await asyncio.sleep(5.0)
|
await asyncio.sleep(5.0)
|
||||||
if rid in self.rid_to_state:
|
if request_id in self.rid_to_state:
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[request_id]
|
||||||
|
|
||||||
asyncio.create_task(cleanup())
|
cleanup_tasks.append(asyncio.create_task(cleanup(rid)))
|
||||||
|
|
||||||
|
# Execute all queue.put() operations in parallel
|
||||||
|
if put_tasks:
|
||||||
|
await asyncio.gather(*put_tasks, return_exceptions=True)
|
||||||
|
|
||||||
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
|
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
|
||||||
"""Handle batch embedding output from scheduler."""
|
"""Handle batch embedding output from scheduler."""
|
||||||
|
|||||||
@@ -10,10 +10,7 @@ use std::{
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures;
|
use futures;
|
||||||
use serde_json;
|
use serde_json;
|
||||||
use tokio::{
|
use tokio::{sync::RwLock, time};
|
||||||
sync::{Mutex, RwLock},
|
|
||||||
time,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::{CircuitBreaker, WorkerError, WorkerResult};
|
use super::{CircuitBreaker, WorkerError, WorkerResult};
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -232,7 +229,7 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
|||||||
|
|
||||||
/// Get or create a gRPC client for this worker
|
/// Get or create a gRPC client for this worker
|
||||||
/// Returns None for HTTP workers, Some(client) for gRPC workers
|
/// Returns None for HTTP workers, Some(client) for gRPC workers
|
||||||
async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<Mutex<SglangSchedulerClient>>>>;
|
async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<SglangSchedulerClient>>>;
|
||||||
|
|
||||||
/// Reset the gRPC client connection (for reconnection scenarios)
|
/// Reset the gRPC client connection (for reconnection scenarios)
|
||||||
/// No-op for HTTP workers
|
/// No-op for HTTP workers
|
||||||
@@ -367,7 +364,7 @@ pub struct BasicWorker {
|
|||||||
pub consecutive_successes: Arc<AtomicUsize>,
|
pub consecutive_successes: Arc<AtomicUsize>,
|
||||||
pub circuit_breaker: CircuitBreaker,
|
pub circuit_breaker: CircuitBreaker,
|
||||||
/// Lazily initialized gRPC client for gRPC workers
|
/// Lazily initialized gRPC client for gRPC workers
|
||||||
pub grpc_client: Arc<RwLock<Option<Arc<Mutex<SglangSchedulerClient>>>>>,
|
pub grpc_client: Arc<RwLock<Option<Arc<SglangSchedulerClient>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Debug for BasicWorker {
|
impl fmt::Debug for BasicWorker {
|
||||||
@@ -505,7 +502,7 @@ impl Worker for BasicWorker {
|
|||||||
&self.circuit_breaker
|
&self.circuit_breaker
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<Mutex<SglangSchedulerClient>>>> {
|
async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<SglangSchedulerClient>>> {
|
||||||
match self.metadata.connection_mode {
|
match self.metadata.connection_mode {
|
||||||
ConnectionMode::Http => Ok(None),
|
ConnectionMode::Http => Ok(None),
|
||||||
ConnectionMode::Grpc { .. } => {
|
ConnectionMode::Grpc { .. } => {
|
||||||
@@ -528,7 +525,7 @@ impl Worker for BasicWorker {
|
|||||||
);
|
);
|
||||||
match SglangSchedulerClient::connect(&self.metadata.url).await {
|
match SglangSchedulerClient::connect(&self.metadata.url).await {
|
||||||
Ok(client) => {
|
Ok(client) => {
|
||||||
let client_arc = Arc::new(Mutex::new(client));
|
let client_arc = Arc::new(client);
|
||||||
*client_guard = Some(client_arc.clone());
|
*client_guard = Some(client_arc.clone());
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
"Successfully connected gRPC client for worker: {}",
|
"Successfully connected gRPC client for worker: {}",
|
||||||
@@ -577,8 +574,7 @@ impl Worker for BasicWorker {
|
|||||||
return Ok(false);
|
return Ok(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = grpc_client.lock().await;
|
match time::timeout(timeout, grpc_client.health_check()).await {
|
||||||
match time::timeout(timeout, client.health_check()).await {
|
|
||||||
Ok(Ok(resp)) => {
|
Ok(Ok(resp)) => {
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"gRPC health OK for {}: healthy={}",
|
"gRPC health OK for {}: healthy={}",
|
||||||
@@ -749,7 +745,7 @@ impl Worker for DPAwareWorker {
|
|||||||
format!("{}{}", self.base_url, route)
|
format!("{}{}", self.base_url, route)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<Mutex<SglangSchedulerClient>>>> {
|
async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<SglangSchedulerClient>>> {
|
||||||
self.base_worker.get_grpc_client().await
|
self.base_worker.get_grpc_client().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ impl BasicWorkerBuilder {
|
|||||||
Arc,
|
Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
use tokio::sync::{Mutex, RwLock};
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
let bootstrap_host = match url::Url::parse(&self.url) {
|
let bootstrap_host = match url::Url::parse(&self.url) {
|
||||||
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
|
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
|
||||||
@@ -145,9 +145,7 @@ impl BasicWorkerBuilder {
|
|||||||
bootstrap_port,
|
bootstrap_port,
|
||||||
};
|
};
|
||||||
|
|
||||||
let grpc_client = Arc::new(RwLock::new(
|
let grpc_client = Arc::new(RwLock::new(self.grpc_client.map(Arc::new)));
|
||||||
self.grpc_client.map(|client| Arc::new(Mutex::new(client))),
|
|
||||||
));
|
|
||||||
|
|
||||||
BasicWorker {
|
BasicWorker {
|
||||||
metadata,
|
metadata,
|
||||||
|
|||||||
@@ -42,8 +42,7 @@ pub async fn get_grpc_client_from_worker(
|
|||||||
.map_err(|e| internal_error_message(format!("Failed to get gRPC client: {}", e)))?
|
.map_err(|e| internal_error_message(format!("Failed to get gRPC client: {}", e)))?
|
||||||
.ok_or_else(|| internal_error_static("Selected worker is not configured for gRPC"))?;
|
.ok_or_else(|| internal_error_static("Selected worker is not configured for gRPC"))?;
|
||||||
|
|
||||||
let client = client_arc.lock().await.clone();
|
Ok((*client_arc).clone())
|
||||||
Ok(client)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process tool call arguments in messages
|
/// Process tool call arguments in messages
|
||||||
|
|||||||
Reference in New Issue
Block a user