diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 4bb32df40..249d71592 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -247,6 +247,20 @@ pub enum ConnectionMode { }, } +impl ConnectionMode { + /// Check if this connection mode matches another, with special handling for gRPC + /// This allows matching any gRPC connection regardless of port when comparing + /// Grpc { port: None } as a wildcard + pub fn matches(&self, filter: &ConnectionMode) -> bool { + match (self, filter) { + (ConnectionMode::Http, ConnectionMode::Http) => true, + (ConnectionMode::Grpc { .. }, ConnectionMode::Grpc { port: None }) => true, + (ConnectionMode::Grpc { port: p1 }, ConnectionMode::Grpc { port: p2 }) => p1 == p2, + _ => false, + } + } +} + impl fmt::Display for ConnectionMode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/sgl-router/src/core/worker_registry.rs b/sgl-router/src/core/worker_registry.rs index e4d65a491..0e58681e6 100644 --- a/sgl-router/src/core/worker_registry.rs +++ b/sgl-router/src/core/worker_registry.rs @@ -308,9 +308,9 @@ impl WorkerRegistry { } } - // Check connection_mode if specified + // Check connection_mode if specified (using matches for flexible gRPC matching) if let Some(ref conn) = connection_mode { - if w.connection_mode() != *conn { + if !w.connection_mode().matches(conn) { return false; } } diff --git a/sgl-router/src/routers/grpc/pipeline.rs b/sgl-router/src/routers/grpc/pipeline.rs index dc595f673..9b21a8281 100644 --- a/sgl-router/src/routers/grpc/pipeline.rs +++ b/sgl-router/src/routers/grpc/pipeline.rs @@ -340,45 +340,32 @@ impl WorkerSelectionStage { model_id: Option<&str>, text: Option<&str>, ) -> Option<(Arc, Arc)> { - // Get prefill workers - use None for WorkerType filter to get all types, - // then filter manually (since Prefill is a struct variant) let all_workers = self.worker_registry.get_workers_filtered( model_id, - None, // Get all types - Some(ConnectionMode::Grpc { port: None }), + None, + Some(ConnectionMode::Grpc { port: None }), // Match any gRPC worker false, ); - let prefill_workers: Vec<_> = all_workers - .iter() - .filter(|w| matches!(w.metadata().worker_type, WorkerType::Prefill { .. })) - .cloned() - .collect(); - - let available_prefill: Vec<_> = prefill_workers - .iter() - .filter(|w| w.is_available()) - .cloned() - .collect(); + let (available_prefill, available_decode): (Vec<_>, Vec<_>) = + all_workers + .into_iter() + .fold((Vec::new(), Vec::new()), |mut acc, w| { + if w.is_available() { + match w.metadata().worker_type { + WorkerType::Prefill { .. } => acc.0.push(w), + WorkerType::Decode => acc.1.push(w), + _ => {} + } + } + acc + }); if available_prefill.is_empty() { warn!("No available prefill workers"); return None; } - // Get decode workers from the same all_workers list - let decode_workers: Vec<_> = all_workers - .iter() - .filter(|w| matches!(w.metadata().worker_type, WorkerType::Decode)) - .cloned() - .collect(); - - let available_decode: Vec<_> = decode_workers - .iter() - .filter(|w| w.is_available()) - .cloned() - .collect(); - if available_decode.is_empty() { warn!("No available decode workers"); return None;