[router] add py binding and readme for openai router and history backend (#11453)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from sglang_router.router_args import RouterArgs
|
||||
from sglang_router_rs import PolicyType
|
||||
from sglang_router_rs import BackendType, HistoryBackendType, PolicyType, PyOracleConfig
|
||||
from sglang_router_rs import Router as _Router
|
||||
|
||||
|
||||
@@ -18,6 +18,39 @@ def policy_from_str(policy_str: Optional[str]) -> PolicyType:
|
||||
return policy_map[policy_str]
|
||||
|
||||
|
||||
def backend_from_str(backend_str: Optional[str]) -> BackendType:
    """Resolve a backend name to its ``BackendType`` member.

    Args:
        backend_str: Backend name, matched case-insensitively ("sglang" or
            "openai"). A value that is already a ``BackendType`` is returned
            unchanged, and ``None`` selects ``BackendType.Sglang``.

    Returns:
        The ``BackendType`` member for the given name.

    Raises:
        ValueError: If the name is not one of the known backends.
    """
    # Enum members pass straight through so callers may supply either form.
    if isinstance(backend_str, BackendType):
        return backend_str
    if backend_str is None:
        return BackendType.Sglang

    backend_map = {
        "sglang": BackendType.Sglang,
        "openai": BackendType.Openai,
    }
    resolved = backend_map.get(backend_str.lower())
    if resolved is None:
        raise ValueError(
            f"Unknown backend: {backend_str}. Valid options: {', '.join(backend_map.keys())}"
        )
    return resolved
|
||||
|
||||
|
||||
def history_backend_from_str(backend_str: Optional[str]) -> HistoryBackendType:
    """Convert a history backend name to a ``HistoryBackendType`` member.

    Args:
        backend_str: History backend name, matched case-insensitively
            ("memory", "none" or "oracle"). A value that is already a
            ``HistoryBackendType`` is returned unchanged, and ``None``
            selects ``HistoryBackendType.Memory``.

    Returns:
        The ``HistoryBackendType`` member for the given name.

    Raises:
        ValueError: If the name is not a known history backend. The message
            lists the valid options, matching ``backend_from_str``.
    """
    if isinstance(backend_str, HistoryBackendType):
        return backend_str
    if backend_str is None:
        return HistoryBackendType.Memory

    # Accepted names mapped to enum attribute names. Members are fetched with
    # getattr because the 'None' variant cannot be written as an attribute
    # access (None is a Python keyword); the lookup stays lazy so only the
    # requested member is ever resolved.
    member_names = {"memory": "Memory", "none": "None", "oracle": "Oracle"}
    member_name = member_names.get(backend_str.lower())
    if member_name is None:
        raise ValueError(
            f"Unknown history backend: {backend_str}. "
            f"Valid options: {', '.join(member_names)}"
        )
    return getattr(HistoryBackendType, member_name)
|
||||
|
||||
|
||||
class Router:
|
||||
"""
|
||||
A high-performance router for distributing requests across worker nodes.
|
||||
@@ -119,8 +152,49 @@ class Router:
|
||||
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
|
||||
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
|
||||
|
||||
# remove mini_lb parameter
|
||||
args_dict.pop("mini_lb")
|
||||
# Convert backend
|
||||
args_dict["backend"] = backend_from_str(args_dict.get("backend"))
|
||||
|
||||
# Convert history_backend to enum first
|
||||
history_backend_raw = args_dict.get("history_backend", "memory")
|
||||
history_backend = history_backend_from_str(history_backend_raw)
|
||||
|
||||
# Convert Oracle config if needed
|
||||
oracle_config = None
|
||||
if history_backend == HistoryBackendType.Oracle:
|
||||
# Prioritize TNS alias over connect descriptor
|
||||
tns_alias = args_dict.get("oracle_tns_alias")
|
||||
connect_descriptor = args_dict.get("oracle_connect_descriptor")
|
||||
|
||||
# Use TNS alias if provided, otherwise use connect descriptor
|
||||
final_descriptor = tns_alias if tns_alias else connect_descriptor
|
||||
|
||||
oracle_config = PyOracleConfig(
|
||||
password=args_dict.get("oracle_password"),
|
||||
username=args_dict.get("oracle_username"),
|
||||
connect_descriptor=final_descriptor,
|
||||
wallet_path=args_dict.get("oracle_wallet_path"),
|
||||
pool_min=args_dict.get("oracle_pool_min", 1),
|
||||
pool_max=args_dict.get("oracle_pool_max", 16),
|
||||
pool_timeout_secs=args_dict.get("oracle_pool_timeout_secs", 30),
|
||||
)
|
||||
args_dict["oracle_config"] = oracle_config
|
||||
args_dict["history_backend"] = history_backend
|
||||
|
||||
# Remove fields that shouldn't be passed to Rust Router constructor
|
||||
fields_to_remove = [
|
||||
"mini_lb",
|
||||
"oracle_wallet_path",
|
||||
"oracle_tns_alias",
|
||||
"oracle_connect_descriptor",
|
||||
"oracle_username",
|
||||
"oracle_password",
|
||||
"oracle_pool_min",
|
||||
"oracle_pool_max",
|
||||
"oracle_pool_timeout_secs",
|
||||
]
|
||||
for field in fields_to_remove:
|
||||
args_dict.pop(field, None)
|
||||
|
||||
return Router(_Router(**args_dict))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user