[router] add grpc pd and regular router init (#9893)

This commit is contained in:
Chang Su
2025-09-01 20:06:15 -07:00
committed by GitHub
parent b5245064f6
commit 9a0cac1be0
14 changed files with 783 additions and 58 deletions

View File

@@ -7,6 +7,9 @@ use std::collections::HashMap;
pub struct RouterConfig {
/// Routing mode configuration
pub mode: RoutingMode,
/// Worker connection mode
#[serde(default)]
pub connection_mode: ConnectionMode,
/// Policy configuration
pub policy: PolicyConfig,
/// Server host address
@@ -60,6 +63,20 @@ pub struct RouterConfig {
/// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
#[serde(default)]
pub enable_igw: bool,
/// Model path for loading tokenizer (can be a HuggingFace model ID or local path)
pub model_path: Option<String>,
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
pub tokenizer_path: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(tag = "type")]
pub enum ConnectionMode {
#[default]
#[serde(rename = "http")]
Http,
#[serde(rename = "grpc")]
Grpc,
}
/// Routing mode configuration
@@ -336,6 +353,9 @@ impl Default for RouterConfig {
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
}
}
}
@@ -478,6 +498,9 @@ mod tests {
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
};
let json = serde_json::to_string(&config).unwrap();
@@ -914,6 +937,9 @@ mod tests {
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
};
assert!(config.mode.is_pd_mode());
@@ -974,6 +1000,9 @@ mod tests {
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
};
assert!(!config.mode.is_pd_mode());
@@ -1030,6 +1059,9 @@ mod tests {
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
};
assert!(config.has_service_discovery());

View File

@@ -349,6 +349,16 @@ impl ConfigValidator {
return Ok(());
}
// Validate gRPC connection mode requires tokenizer configuration
if config.connection_mode == ConnectionMode::Grpc
&& config.tokenizer_path.is_none()
&& config.model_path.is_none()
{
return Err(ConfigError::ValidationFailed {
reason: "gRPC connection mode requires either --tokenizer-path or --model-path to be specified".to_string(),
});
}
// All policies are now supported for both router types thanks to the unified trait design
// No mode/policy restrictions needed anymore
@@ -419,11 +429,14 @@ impl ConfigValidator {
});
}
if !url.starts_with("http://") && !url.starts_with("https://") {
if !url.starts_with("http://")
&& !url.starts_with("https://")
&& !url.starts_with("grpc://")
{
return Err(ConfigError::InvalidValue {
field: "worker_url".to_string(),
value: url.clone(),
reason: "URL must start with http:// or https://".to_string(),
reason: "URL must start with http://, https://, or grpc://".to_string(),
});
}
@@ -684,4 +697,60 @@ mod tests {
assert!(e.to_string().contains("prefill requires at least 2"));
}
}
#[test]
fn test_validate_grpc_requires_tokenizer() {
// Test that gRPC connection mode requires tokenizer configuration
let mut config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec!["grpc://worker:50051".to_string()],
},
PolicyConfig::Random,
);
// Set connection mode to gRPC without tokenizer config
config.connection_mode = ConnectionMode::Grpc;
config.tokenizer_path = None;
config.model_path = None;
let result = ConfigValidator::validate(&config);
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("gRPC connection mode requires"));
}
}
#[test]
fn test_validate_grpc_with_model_path() {
// Test that gRPC works with model_path
let mut config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec!["grpc://worker:50051".to_string()],
},
PolicyConfig::Random,
);
config.connection_mode = ConnectionMode::Grpc;
config.model_path = Some("meta-llama/Llama-3-8B".to_string());
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
#[test]
fn test_validate_grpc_with_tokenizer_path() {
// Test that gRPC works with tokenizer_path
let mut config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec!["grpc://worker:50051".to_string()],
},
PolicyConfig::Random,
);
config.connection_mode = ConnectionMode::Grpc;
config.tokenizer_path = Some("/path/to/tokenizer.json".to_string());
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
}