[router] add grpc pd and regular router init (#9893)
This commit is contained in:
@@ -7,6 +7,9 @@ use std::collections::HashMap;
|
||||
pub struct RouterConfig {
|
||||
/// Routing mode configuration
|
||||
pub mode: RoutingMode,
|
||||
/// Worker connection mode
|
||||
#[serde(default)]
|
||||
pub connection_mode: ConnectionMode,
|
||||
/// Policy configuration
|
||||
pub policy: PolicyConfig,
|
||||
/// Server host address
|
||||
@@ -60,6 +63,20 @@ pub struct RouterConfig {
|
||||
/// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
|
||||
#[serde(default)]
|
||||
pub enable_igw: bool,
|
||||
/// Model path for loading tokenizer (can be a HuggingFace model ID or local path)
|
||||
pub model_path: Option<String>,
|
||||
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
|
||||
pub tokenizer_path: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ConnectionMode {
|
||||
#[default]
|
||||
#[serde(rename = "http")]
|
||||
Http,
|
||||
#[serde(rename = "grpc")]
|
||||
Grpc,
|
||||
}
|
||||
|
||||
/// Routing mode configuration
|
||||
@@ -336,6 +353,9 @@ impl Default for RouterConfig {
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -478,6 +498,9 @@ mod tests {
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
@@ -914,6 +937,9 @@ mod tests {
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
};
|
||||
|
||||
assert!(config.mode.is_pd_mode());
|
||||
@@ -974,6 +1000,9 @@ mod tests {
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
};
|
||||
|
||||
assert!(!config.mode.is_pd_mode());
|
||||
@@ -1030,6 +1059,9 @@ mod tests {
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
};
|
||||
|
||||
assert!(config.has_service_discovery());
|
||||
|
||||
@@ -349,6 +349,16 @@ impl ConfigValidator {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Validate gRPC connection mode requires tokenizer configuration
|
||||
if config.connection_mode == ConnectionMode::Grpc
|
||||
&& config.tokenizer_path.is_none()
|
||||
&& config.model_path.is_none()
|
||||
{
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "gRPC connection mode requires either --tokenizer-path or --model-path to be specified".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// All policies are now supported for both router types thanks to the unified trait design
|
||||
// No mode/policy restrictions needed anymore
|
||||
|
||||
@@ -419,11 +429,14 @@ impl ConfigValidator {
|
||||
});
|
||||
}
|
||||
|
||||
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||||
if !url.starts_with("http://")
|
||||
&& !url.starts_with("https://")
|
||||
&& !url.starts_with("grpc://")
|
||||
{
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_url".to_string(),
|
||||
value: url.clone(),
|
||||
reason: "URL must start with http:// or https://".to_string(),
|
||||
reason: "URL must start with http://, https://, or grpc://".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -684,4 +697,60 @@ mod tests {
|
||||
assert!(e.to_string().contains("prefill requires at least 2"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_requires_tokenizer() {
|
||||
// Test that gRPC connection mode requires tokenizer configuration
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
// Set connection mode to gRPC without tokenizer config
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.tokenizer_path = None;
|
||||
config.model_path = None;
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert!(e.to_string().contains("gRPC connection mode requires"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_with_model_path() {
|
||||
// Test that gRPC works with model_path
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.model_path = Some("meta-llama/Llama-3-8B".to_string());
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_with_tokenizer_path() {
|
||||
// Test that gRPC works with tokenizer_path
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.tokenizer_path = Some("/path/to/tokenizer.json".to_string());
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user