[router] add reasoning and tool parser argument in router (#11290)
This commit is contained in:
@@ -86,6 +86,9 @@ class RouterArgs:
|
||||
# Tokenizer configuration
|
||||
model_path: Optional[str] = None
|
||||
tokenizer_path: Optional[str] = None
|
||||
# Parser configuration
|
||||
reasoning_parser: Optional[str] = None
|
||||
tool_call_parser: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
@@ -446,6 +449,18 @@ class RouterArgs:
|
||||
default=None,
|
||||
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}reasoning-parser",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify the parser for reasoning models (e.g., deepseek-r1, qwen3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}tool-call-parser",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify the parser for handling tool-call interactions",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(
|
||||
|
||||
@@ -73,6 +73,10 @@ pub struct RouterConfig {
|
||||
/// Oracle history backend configuration (required when `history_backend` = "oracle")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oracle: Option<OracleConfig>,
|
||||
/// Parser for reasoning models (e.g., deepseek-r1, qwen3)
|
||||
pub reasoning_parser: Option<String>,
|
||||
/// Parser for handling tool-call interactions
|
||||
pub tool_call_parser: Option<String>,
|
||||
}
|
||||
|
||||
fn default_history_backend() -> HistoryBackend {
|
||||
@@ -448,6 +452,8 @@ impl Default for RouterConfig {
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -990,6 +996,8 @@ mod tests {
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
assert!(config.mode.is_pd_mode());
|
||||
@@ -1055,6 +1063,8 @@ mod tests {
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
assert!(!config.mode.is_pd_mode());
|
||||
@@ -1116,6 +1126,8 @@ mod tests {
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
assert!(config.has_service_discovery());
|
||||
|
||||
@@ -90,6 +90,8 @@ struct Router {
|
||||
connection_mode: config::ConnectionMode,
|
||||
model_path: Option<String>,
|
||||
tokenizer_path: Option<String>,
|
||||
reasoning_parser: Option<String>,
|
||||
tool_call_parser: Option<String>,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
@@ -216,6 +218,8 @@ impl Router {
|
||||
tokenizer_path: self.tokenizer_path.clone(),
|
||||
history_backend: config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
reasoning_parser: self.reasoning_parser.clone(),
|
||||
tool_call_parser: self.tool_call_parser.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -280,6 +284,8 @@ impl Router {
|
||||
rate_limit_tokens_per_second = None,
|
||||
model_path = None,
|
||||
tokenizer_path = None,
|
||||
reasoning_parser = None,
|
||||
tool_call_parser = None,
|
||||
))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
@@ -339,6 +345,8 @@ impl Router {
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
model_path: Option<String>,
|
||||
tokenizer_path: Option<String>,
|
||||
reasoning_parser: Option<String>,
|
||||
tool_call_parser: Option<String>,
|
||||
) -> PyResult<Self> {
|
||||
let mut all_urls = worker_urls.clone();
|
||||
|
||||
@@ -412,6 +420,8 @@ impl Router {
|
||||
connection_mode,
|
||||
model_path,
|
||||
tokenizer_path,
|
||||
reasoning_parser,
|
||||
tool_call_parser,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -281,6 +281,12 @@ struct CliArgs {
|
||||
|
||||
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
|
||||
oracle_pool_timeout_secs: Option<u64>,
|
||||
|
||||
#[arg(long)]
|
||||
reasoning_parser: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tool_call_parser: Option<String>,
|
||||
}
|
||||
|
||||
enum OracleConnectSource {
|
||||
@@ -557,6 +563,8 @@ impl CliArgs {
|
||||
tokenizer_path: self.tokenizer_path.clone(),
|
||||
history_backend,
|
||||
oracle,
|
||||
reasoning_parser: self.reasoning_parser.clone(),
|
||||
tool_call_parser: self.tool_call_parser.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -53,6 +53,8 @@ pub struct GrpcPDRouter {
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
configured_reasoning_parser: Option<String>,
|
||||
configured_tool_parser: Option<String>,
|
||||
}
|
||||
|
||||
impl GrpcPDRouter {
|
||||
@@ -88,6 +90,8 @@ impl GrpcPDRouter {
|
||||
dp_aware: ctx.router_config.dp_aware,
|
||||
api_key: ctx.router_config.api_key.clone(),
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
|
||||
configured_tool_parser: ctx.configured_tool_parser.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1179,9 +1183,13 @@ impl GrpcPDRouter {
|
||||
created: u64,
|
||||
) -> (String, Option<ChatCompletionStreamResponse>, bool) {
|
||||
// Get or create parser for this index
|
||||
reasoning_parsers
|
||||
.entry(index)
|
||||
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model));
|
||||
reasoning_parsers.entry(index).or_insert_with(|| {
|
||||
utils::get_reasoning_parser(
|
||||
&self.reasoning_parser_factory,
|
||||
self.configured_reasoning_parser.as_ref(),
|
||||
model,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
|
||||
let (parse_result, in_reasoning) = {
|
||||
@@ -1248,9 +1256,13 @@ impl GrpcPDRouter {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
// Get or create parser for this index
|
||||
tool_parsers
|
||||
.entry(index)
|
||||
.or_insert_with(|| self.tool_parser_factory.get_pooled(model));
|
||||
tool_parsers.entry(index).or_insert_with(|| {
|
||||
utils::get_tool_parser(
|
||||
&self.tool_parser_factory,
|
||||
self.configured_tool_parser.as_ref(),
|
||||
model,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(pooled_parser) = tool_parsers.get(&index) {
|
||||
let mut parser = pooled_parser.lock().await;
|
||||
@@ -1737,9 +1749,11 @@ impl GrpcPDRouter {
|
||||
|
||||
// Check if reasoning parsing is enabled and separate_reasoning is requested
|
||||
if original_request.separate_reasoning {
|
||||
let pooled_parser = self
|
||||
.reasoning_parser_factory
|
||||
.get_pooled(&original_request.model);
|
||||
let pooled_parser = utils::get_reasoning_parser(
|
||||
&self.reasoning_parser_factory,
|
||||
self.configured_reasoning_parser.as_ref(),
|
||||
&original_request.model,
|
||||
);
|
||||
|
||||
let mut parser = pooled_parser
|
||||
.lock()
|
||||
@@ -1860,7 +1874,11 @@ impl GrpcPDRouter {
|
||||
history_tool_calls_count: usize,
|
||||
) -> (Option<Vec<ToolCall>>, String) {
|
||||
// Get pooled parser for this model
|
||||
let pooled_parser = self.tool_parser_factory.get_pooled(model);
|
||||
let pooled_parser = utils::get_tool_parser(
|
||||
&self.tool_parser_factory,
|
||||
self.configured_tool_parser.as_ref(),
|
||||
model,
|
||||
);
|
||||
|
||||
// Check format detection first
|
||||
let can_parse = {
|
||||
|
||||
@@ -53,6 +53,8 @@ pub struct GrpcRouter {
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
configured_reasoning_parser: Option<String>,
|
||||
configured_tool_parser: Option<String>,
|
||||
}
|
||||
|
||||
impl GrpcRouter {
|
||||
@@ -87,6 +89,8 @@ impl GrpcRouter {
|
||||
dp_aware: ctx.router_config.dp_aware,
|
||||
api_key: ctx.router_config.api_key.clone(),
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
|
||||
configured_tool_parser: ctx.configured_tool_parser.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -301,7 +305,11 @@ impl GrpcRouter {
|
||||
history_tool_calls_count: usize,
|
||||
) -> (Option<Vec<ToolCall>>, String) {
|
||||
// Get pooled parser for this model
|
||||
let pooled_parser = self.tool_parser_factory.get_pooled(model);
|
||||
let pooled_parser = utils::get_tool_parser(
|
||||
&self.tool_parser_factory,
|
||||
self.configured_tool_parser.as_ref(),
|
||||
model,
|
||||
);
|
||||
|
||||
// Check format detection first
|
||||
let can_parse = {
|
||||
@@ -496,9 +504,13 @@ impl GrpcRouter {
|
||||
created: u64,
|
||||
) -> (String, Option<ChatCompletionStreamResponse>, bool) {
|
||||
// Get or create parser for this index
|
||||
reasoning_parsers
|
||||
.entry(index)
|
||||
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model));
|
||||
reasoning_parsers.entry(index).or_insert_with(|| {
|
||||
utils::get_reasoning_parser(
|
||||
&self.reasoning_parser_factory,
|
||||
self.configured_reasoning_parser.as_ref(),
|
||||
model,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
|
||||
let (parse_result, in_reasoning) = {
|
||||
@@ -569,9 +581,13 @@ impl GrpcRouter {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
// Get or create parser for this index
|
||||
tool_parsers
|
||||
.entry(index)
|
||||
.or_insert_with(|| self.tool_parser_factory.get_pooled(model));
|
||||
tool_parsers.entry(index).or_insert_with(|| {
|
||||
utils::get_tool_parser(
|
||||
&self.tool_parser_factory,
|
||||
self.configured_tool_parser.as_ref(),
|
||||
model,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(pooled_parser) = tool_parsers.get(&index) {
|
||||
let mut parser = pooled_parser.lock().await;
|
||||
@@ -1615,9 +1631,11 @@ impl GrpcRouter {
|
||||
|
||||
// Check if reasoning parsing is enabled and separate_reasoning is requested
|
||||
if original_request.separate_reasoning {
|
||||
let pooled_parser = self
|
||||
.reasoning_parser_factory
|
||||
.get_pooled(&original_request.model);
|
||||
let pooled_parser = utils::get_reasoning_parser(
|
||||
&self.reasoning_parser_factory,
|
||||
self.configured_reasoning_parser.as_ref(),
|
||||
&original_request.model,
|
||||
);
|
||||
|
||||
let mut parser = pooled_parser
|
||||
.lock()
|
||||
|
||||
@@ -641,6 +641,64 @@ pub fn generate_tool_call_id(
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the appropriate reasoning parser for a model
|
||||
///
|
||||
/// If a parser name is explicitly configured, use that parser.
|
||||
/// Otherwise, auto-detect based on the model name.
|
||||
pub fn get_reasoning_parser(
|
||||
reasoning_parser_factory: &crate::reasoning_parser::ReasoningParserFactory,
|
||||
configured_parser: Option<&String>,
|
||||
model: &str,
|
||||
) -> crate::reasoning_parser::PooledParser {
|
||||
use tracing::warn;
|
||||
|
||||
if let Some(parser_name) = configured_parser {
|
||||
// Use configured parser if specified
|
||||
reasoning_parser_factory
|
||||
.registry()
|
||||
.get_pooled_parser(parser_name)
|
||||
.unwrap_or_else(|| {
|
||||
warn!(
|
||||
"Configured reasoning parser '{}' not found, falling back to model-based selection",
|
||||
parser_name
|
||||
);
|
||||
reasoning_parser_factory.get_pooled(model)
|
||||
})
|
||||
} else {
|
||||
// Auto-detect based on model
|
||||
reasoning_parser_factory.get_pooled(model)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the appropriate tool parser for a model
|
||||
///
|
||||
/// If a parser name is explicitly configured, use that parser.
|
||||
/// Otherwise, auto-detect based on the model name.
|
||||
pub fn get_tool_parser(
|
||||
tool_parser_factory: &crate::tool_parser::ToolParserFactory,
|
||||
configured_parser: Option<&String>,
|
||||
model: &str,
|
||||
) -> crate::tool_parser::PooledToolParser {
|
||||
use tracing::warn;
|
||||
|
||||
if let Some(parser_name) = configured_parser {
|
||||
// Use configured parser if specified
|
||||
tool_parser_factory
|
||||
.registry()
|
||||
.get_pooled_parser(parser_name)
|
||||
.unwrap_or_else(|| {
|
||||
warn!(
|
||||
"Configured tool parser '{}' not found, falling back to model-based selection",
|
||||
parser_name
|
||||
);
|
||||
tool_parser_factory.get_pooled(model)
|
||||
})
|
||||
} else {
|
||||
// Auto-detect based on model
|
||||
tool_parser_factory.get_pooled(model)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -52,6 +52,8 @@ pub struct AppContext {
|
||||
pub router_manager: Option<Arc<RouterManager>>,
|
||||
pub response_storage: SharedResponseStorage,
|
||||
pub load_monitor: Option<Arc<LoadMonitor>>,
|
||||
pub configured_reasoning_parser: Option<String>,
|
||||
pub configured_tool_parser: Option<String>,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
@@ -115,6 +117,9 @@ impl AppContext {
|
||||
router_config.worker_startup_check_interval_secs,
|
||||
)));
|
||||
|
||||
let configured_reasoning_parser = router_config.reasoning_parser.clone();
|
||||
let configured_tool_parser = router_config.tool_call_parser.clone();
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
router_config,
|
||||
@@ -127,6 +132,8 @@ impl AppContext {
|
||||
router_manager,
|
||||
response_storage,
|
||||
load_monitor,
|
||||
configured_reasoning_parser,
|
||||
configured_tool_parser,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -543,6 +543,8 @@ mod tests {
|
||||
router_manager: None,
|
||||
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
||||
load_monitor: None,
|
||||
configured_reasoning_parser: None,
|
||||
configured_tool_parser: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,8 @@ impl TestContext {
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
Self::new_with_config(config, worker_configs).await
|
||||
@@ -1396,6 +1398,8 @@ mod error_tests {
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = TestContext::new_with_config(
|
||||
@@ -1755,6 +1759,8 @@ mod pd_mode_tests {
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
// Create app context
|
||||
@@ -1915,6 +1921,8 @@ mod request_id_tests {
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = TestContext::new_with_config(
|
||||
|
||||
@@ -76,6 +76,8 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
// Create router and context
|
||||
@@ -508,6 +510,8 @@ async fn test_multi_turn_loop_with_mcp() {
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
@@ -686,6 +690,8 @@ async fn test_max_tool_calls_limit() {
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
@@ -826,6 +832,8 @@ async fn setup_streaming_mcp_test() -> (
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
|
||||
@@ -826,6 +826,8 @@ fn oracle_config_validation_requires_config_when_enabled() {
|
||||
},
|
||||
history_backend: HistoryBackend::Oracle,
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
||||
@@ -195,6 +195,8 @@ mod test_pd_routing {
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
reasoning_parser: None,
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let app_context =
|
||||
|
||||
Reference in New Issue
Block a user