[router] add py binding and readme for openai router and history backend (#11453)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Keyang Ru
2025-10-14 09:42:34 -07:00
committed by GitHub
parent 5ea96ac7cc
commit eb8cac6fe2
8 changed files with 488 additions and 25 deletions

View File

@@ -1,7 +1,7 @@
from typing import Optional
from sglang_router.router_args import RouterArgs
from sglang_router_rs import PolicyType
from sglang_router_rs import BackendType, HistoryBackendType, PolicyType, PyOracleConfig
from sglang_router_rs import Router as _Router
@@ -18,6 +18,39 @@ def policy_from_str(policy_str: Optional[str]) -> PolicyType:
return policy_map[policy_str]
def backend_from_str(backend_str: Optional[str]) -> BackendType:
"""Convert backend string to BackendType enum."""
if isinstance(backend_str, BackendType):
return backend_str
if backend_str is None:
return BackendType.Sglang
backend_map = {"sglang": BackendType.Sglang, "openai": BackendType.Openai}
backend_lower = backend_str.lower()
if backend_lower not in backend_map:
raise ValueError(
f"Unknown backend: {backend_str}. Valid options: {', '.join(backend_map.keys())}"
)
return backend_map[backend_lower]
def history_backend_from_str(backend_str: Optional[str]) -> HistoryBackendType:
"""Convert history backend string to HistoryBackendType enum."""
if isinstance(backend_str, HistoryBackendType):
return backend_str
if backend_str is None:
return HistoryBackendType.Memory
backend_lower = backend_str.lower()
if backend_lower == "memory":
return HistoryBackendType.Memory
elif backend_lower == "none":
# Use getattr to access 'None' which is a Python keyword
return getattr(HistoryBackendType, "None")
elif backend_lower == "oracle":
return HistoryBackendType.Oracle
else:
raise ValueError(f"Unknown history backend: {backend_str}")
class Router:
"""
A high-performance router for distributing requests across worker nodes.
@@ -119,8 +152,49 @@ class Router:
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
# remove mini_lb parameter
args_dict.pop("mini_lb")
# Convert backend
args_dict["backend"] = backend_from_str(args_dict.get("backend"))
# Convert history_backend to enum first
history_backend_raw = args_dict.get("history_backend", "memory")
history_backend = history_backend_from_str(history_backend_raw)
# Convert Oracle config if needed
oracle_config = None
if history_backend == HistoryBackendType.Oracle:
# Prioritize TNS alias over connect descriptor
tns_alias = args_dict.get("oracle_tns_alias")
connect_descriptor = args_dict.get("oracle_connect_descriptor")
# Use TNS alias if provided, otherwise use connect descriptor
final_descriptor = tns_alias if tns_alias else connect_descriptor
oracle_config = PyOracleConfig(
password=args_dict.get("oracle_password"),
username=args_dict.get("oracle_username"),
connect_descriptor=final_descriptor,
wallet_path=args_dict.get("oracle_wallet_path"),
pool_min=args_dict.get("oracle_pool_min", 1),
pool_max=args_dict.get("oracle_pool_max", 16),
pool_timeout_secs=args_dict.get("oracle_pool_timeout_secs", 30),
)
args_dict["oracle_config"] = oracle_config
args_dict["history_backend"] = history_backend
# Remove fields that shouldn't be passed to Rust Router constructor
fields_to_remove = [
"mini_lb",
"oracle_wallet_path",
"oracle_tns_alias",
"oracle_connect_descriptor",
"oracle_username",
"oracle_password",
"oracle_pool_min",
"oracle_pool_max",
"oracle_pool_timeout_secs",
]
for field in fields_to_remove:
args_dict.pop(field, None)
return Router(_Router(**args_dict))

View File

@@ -1,6 +1,7 @@
import argparse
import dataclasses
import logging
import os
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
@@ -88,6 +89,18 @@ class RouterArgs:
chat_template: Optional[str] = None
reasoning_parser: Optional[str] = None
tool_call_parser: Optional[str] = None
# Backend selection
backend: str = "sglang"
# History backend configuration
history_backend: str = "memory"
oracle_wallet_path: Optional[str] = None
oracle_tns_alias: Optional[str] = None
oracle_connect_descriptor: Optional[str] = None
oracle_username: Optional[str] = None
oracle_password: Optional[str] = None
oracle_pool_min: int = 1
oracle_pool_max: int = 16
oracle_pool_timeout_secs: int = 30
@staticmethod
def add_cli_args(
@@ -466,6 +479,73 @@ class RouterArgs:
default=None,
help="Specify the parser for handling tool-call interactions",
)
# Backend selection
parser.add_argument(
f"--{prefix}backend",
type=str,
default=RouterArgs.backend,
choices=["sglang", "openai"],
help="Backend runtime to use (default: sglang)",
)
# History backend configuration
parser.add_argument(
f"--{prefix}history-backend",
type=str,
default=RouterArgs.history_backend,
choices=["memory", "none", "oracle"],
help="History storage backend for conversations and responses (default: memory)",
)
# Oracle configuration
parser.add_argument(
f"--{prefix}oracle-wallet-path",
type=str,
default=os.getenv("ATP_WALLET_PATH"),
help="Path to Oracle ATP wallet directory (env: ATP_WALLET_PATH)",
)
parser.add_argument(
f"--{prefix}oracle-tns-alias",
type=str,
default=os.getenv("ATP_TNS_ALIAS"),
help="Oracle TNS alias from tnsnames.ora (env: ATP_TNS_ALIAS).",
)
parser.add_argument(
f"--{prefix}oracle-connect-descriptor",
type=str,
default=os.getenv("ATP_DSN"),
help="Oracle connection descriptor/DSN (full connection string) (env: ATP_DSN)",
)
parser.add_argument(
f"--{prefix}oracle-username",
type=str,
default=os.getenv("ATP_USER"),
help="Oracle database username (env: ATP_USER)",
)
parser.add_argument(
f"--{prefix}oracle-password",
type=str,
default=os.getenv("ATP_PASSWORD"),
help="Oracle database password (env: ATP_PASSWORD)",
)
parser.add_argument(
f"--{prefix}oracle-pool-min",
type=int,
default=int(os.getenv("ATP_POOL_MIN", RouterArgs.oracle_pool_min)),
help="Minimum Oracle connection pool size (default: 1, env: ATP_POOL_MIN)",
)
parser.add_argument(
f"--{prefix}oracle-pool-max",
type=int,
default=int(os.getenv("ATP_POOL_MAX", RouterArgs.oracle_pool_max)),
help="Maximum Oracle connection pool size (default: 16, env: ATP_POOL_MAX)",
)
parser.add_argument(
f"--{prefix}oracle-pool-timeout-secs",
type=int,
default=int(
os.getenv("ATP_POOL_TIMEOUT_SECS", RouterArgs.oracle_pool_timeout_secs)
),
help="Oracle connection pool timeout in seconds (default: 30, env: ATP_POOL_TIMEOUT_SECS)",
)
@classmethod
def from_cli_args(