[router] add py binding and readme for openai router and history backend (#11453)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from sglang_router.router_args import RouterArgs
|
||||
from sglang_router_rs import PolicyType
|
||||
from sglang_router_rs import BackendType, HistoryBackendType, PolicyType, PyOracleConfig
|
||||
from sglang_router_rs import Router as _Router
|
||||
|
||||
|
||||
@@ -18,6 +18,39 @@ def policy_from_str(policy_str: Optional[str]) -> PolicyType:
|
||||
return policy_map[policy_str]
|
||||
|
||||
|
||||
def backend_from_str(backend_str: Optional[str]) -> BackendType:
    """Resolve a backend name to its ``BackendType`` member.

    Args:
        backend_str: Case-insensitive backend name (``"sglang"`` or
            ``"openai"``), an existing ``BackendType`` (returned unchanged),
            or ``None`` (resolved to ``BackendType.Sglang``).

    Returns:
        The matching ``BackendType`` member.

    Raises:
        ValueError: If the name is not one of the known backends.
    """
    # Values that are already enum members pass straight through.
    if isinstance(backend_str, BackendType):
        return backend_str
    if backend_str is None:
        return BackendType.Sglang

    known_backends = {"sglang": BackendType.Sglang, "openai": BackendType.Openai}
    resolved = known_backends.get(backend_str.lower())
    if resolved is None:
        valid_options = ", ".join(known_backends)
        raise ValueError(
            f"Unknown backend: {backend_str}. Valid options: {valid_options}"
        )
    return resolved
|
||||
|
||||
|
||||
def history_backend_from_str(backend_str: Optional[str]) -> HistoryBackendType:
    """Resolve a history backend name to its ``HistoryBackendType`` member.

    Args:
        backend_str: Case-insensitive name (``"memory"``, ``"none"`` or
            ``"oracle"``), an existing ``HistoryBackendType`` (returned
            unchanged), or ``None`` (resolved to ``HistoryBackendType.Memory``).

    Returns:
        The matching ``HistoryBackendType`` member.

    Raises:
        ValueError: If the name is not a recognised history backend.
    """
    # Values that are already enum members pass straight through.
    if isinstance(backend_str, HistoryBackendType):
        return backend_str
    if backend_str is None:
        return HistoryBackendType.Memory

    # Accepted names mapped to enum member names. Members are fetched with
    # getattr because "None" is a Python keyword and cannot be written as a
    # plain attribute access.
    member_names = {"memory": "Memory", "none": "None", "oracle": "Oracle"}
    member_name = member_names.get(backend_str.lower())
    if member_name is None:
        raise ValueError(f"Unknown history backend: {backend_str}")
    return getattr(HistoryBackendType, member_name)
|
||||
|
||||
|
||||
class Router:
|
||||
"""
|
||||
A high-performance router for distributing requests across worker nodes.
|
||||
@@ -119,8 +152,49 @@ class Router:
|
||||
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
|
||||
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
|
||||
|
||||
# remove mini_lb parameter
|
||||
args_dict.pop("mini_lb")
|
||||
# Convert backend
|
||||
args_dict["backend"] = backend_from_str(args_dict.get("backend"))
|
||||
|
||||
# Convert history_backend to enum first
|
||||
history_backend_raw = args_dict.get("history_backend", "memory")
|
||||
history_backend = history_backend_from_str(history_backend_raw)
|
||||
|
||||
# Convert Oracle config if needed
|
||||
oracle_config = None
|
||||
if history_backend == HistoryBackendType.Oracle:
|
||||
# Prioritize TNS alias over connect descriptor
|
||||
tns_alias = args_dict.get("oracle_tns_alias")
|
||||
connect_descriptor = args_dict.get("oracle_connect_descriptor")
|
||||
|
||||
# Use TNS alias if provided, otherwise use connect descriptor
|
||||
final_descriptor = tns_alias if tns_alias else connect_descriptor
|
||||
|
||||
oracle_config = PyOracleConfig(
|
||||
password=args_dict.get("oracle_password"),
|
||||
username=args_dict.get("oracle_username"),
|
||||
connect_descriptor=final_descriptor,
|
||||
wallet_path=args_dict.get("oracle_wallet_path"),
|
||||
pool_min=args_dict.get("oracle_pool_min", 1),
|
||||
pool_max=args_dict.get("oracle_pool_max", 16),
|
||||
pool_timeout_secs=args_dict.get("oracle_pool_timeout_secs", 30),
|
||||
)
|
||||
args_dict["oracle_config"] = oracle_config
|
||||
args_dict["history_backend"] = history_backend
|
||||
|
||||
# Remove fields that shouldn't be passed to Rust Router constructor
|
||||
fields_to_remove = [
|
||||
"mini_lb",
|
||||
"oracle_wallet_path",
|
||||
"oracle_tns_alias",
|
||||
"oracle_connect_descriptor",
|
||||
"oracle_username",
|
||||
"oracle_password",
|
||||
"oracle_pool_min",
|
||||
"oracle_pool_max",
|
||||
"oracle_pool_timeout_secs",
|
||||
]
|
||||
for field in fields_to_remove:
|
||||
args_dict.pop(field, None)
|
||||
|
||||
return Router(_Router(**args_dict))
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -88,6 +89,18 @@ class RouterArgs:
|
||||
chat_template: Optional[str] = None
|
||||
reasoning_parser: Optional[str] = None
|
||||
tool_call_parser: Optional[str] = None
|
||||
# Backend selection
|
||||
backend: str = "sglang"
|
||||
# History backend configuration
|
||||
history_backend: str = "memory"
|
||||
oracle_wallet_path: Optional[str] = None
|
||||
oracle_tns_alias: Optional[str] = None
|
||||
oracle_connect_descriptor: Optional[str] = None
|
||||
oracle_username: Optional[str] = None
|
||||
oracle_password: Optional[str] = None
|
||||
oracle_pool_min: int = 1
|
||||
oracle_pool_max: int = 16
|
||||
oracle_pool_timeout_secs: int = 30
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
@@ -466,6 +479,73 @@ class RouterArgs:
|
||||
default=None,
|
||||
help="Specify the parser for handling tool-call interactions",
|
||||
)
|
||||
# Backend selection
|
||||
parser.add_argument(
|
||||
f"--{prefix}backend",
|
||||
type=str,
|
||||
default=RouterArgs.backend,
|
||||
choices=["sglang", "openai"],
|
||||
help="Backend runtime to use (default: sglang)",
|
||||
)
|
||||
# History backend configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}history-backend",
|
||||
type=str,
|
||||
default=RouterArgs.history_backend,
|
||||
choices=["memory", "none", "oracle"],
|
||||
help="History storage backend for conversations and responses (default: memory)",
|
||||
)
|
||||
# Oracle configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}oracle-wallet-path",
|
||||
type=str,
|
||||
default=os.getenv("ATP_WALLET_PATH"),
|
||||
help="Path to Oracle ATP wallet directory (env: ATP_WALLET_PATH)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}oracle-tns-alias",
|
||||
type=str,
|
||||
default=os.getenv("ATP_TNS_ALIAS"),
|
||||
help="Oracle TNS alias from tnsnames.ora (env: ATP_TNS_ALIAS).",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}oracle-connect-descriptor",
|
||||
type=str,
|
||||
default=os.getenv("ATP_DSN"),
|
||||
help="Oracle connection descriptor/DSN (full connection string) (env: ATP_DSN)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}oracle-username",
|
||||
type=str,
|
||||
default=os.getenv("ATP_USER"),
|
||||
help="Oracle database username (env: ATP_USER)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}oracle-password",
|
||||
type=str,
|
||||
default=os.getenv("ATP_PASSWORD"),
|
||||
help="Oracle database password (env: ATP_PASSWORD)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}oracle-pool-min",
|
||||
type=int,
|
||||
default=int(os.getenv("ATP_POOL_MIN", RouterArgs.oracle_pool_min)),
|
||||
help="Minimum Oracle connection pool size (default: 1, env: ATP_POOL_MIN)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}oracle-pool-max",
|
||||
type=int,
|
||||
default=int(os.getenv("ATP_POOL_MAX", RouterArgs.oracle_pool_max)),
|
||||
help="Maximum Oracle connection pool size (default: 16, env: ATP_POOL_MAX)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}oracle-pool-timeout-secs",
|
||||
type=int,
|
||||
default=int(
|
||||
os.getenv("ATP_POOL_TIMEOUT_SECS", RouterArgs.oracle_pool_timeout_secs)
|
||||
),
|
||||
help="Oracle connection pool timeout in seconds (default: 30, env: ATP_POOL_TIMEOUT_SECS)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(
|
||||
|
||||
Reference in New Issue
Block a user