[router] Add gRPC E2E test suite (#11790)
This commit is contained in:
258
sgl-router/py_test/e2e_grpc/util.py
Normal file
258
sgl-router/py_test/e2e_grpc/util.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Standalone utilities for e2e_grpc tests.
|
||||
|
||||
This module provides all necessary utilities without depending on sglang Python package.
|
||||
Extracted and adapted from:
|
||||
- sglang.srt.utils.kill_process_tree
|
||||
- sglang.srt.utils.hf_transformers_utils.get_tokenizer
|
||||
- sglang.test.test_utils (constants and CustomTestCase)
|
||||
"""
|
||||
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import psutil
|
||||
|
||||
try:
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"transformers is required for tokenizer utilities. "
|
||||
"Install with: pip install transformers"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Constants
|
||||
# ============================================================================
|
||||
|
||||
# Server and timeout constants
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
|
||||
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 20000
|
||||
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
|
||||
|
||||
# File name constants for test output
|
||||
STDOUT_FILENAME = "/tmp/sglang_test_stdout.txt"
|
||||
STDERR_FILENAME = "/tmp/sglang_test_stderr.txt"
|
||||
|
||||
# Model base path - can be overridden via environment variable
|
||||
# By default, use HuggingFace model identifiers (no local path prefix)
|
||||
# Set ROUTER_LOCAL_MODEL_PATH to use local models (e.g., "/home/ubuntu/models")
|
||||
ROUTER_LOCAL_MODEL_PATH = os.environ.get("ROUTER_LOCAL_MODEL_PATH", "")
|
||||
|
||||
|
||||
# Helper function to build model paths
|
||||
def _get_model_path(model_identifier: str) -> str:
|
||||
"""
|
||||
Build model path from base path and model identifier.
|
||||
|
||||
If ROUTER_LOCAL_MODEL_PATH is set, prepend it to the identifier.
|
||||
Otherwise, return the identifier as-is (for HuggingFace download).
|
||||
"""
|
||||
if ROUTER_LOCAL_MODEL_PATH:
|
||||
return os.path.join(ROUTER_LOCAL_MODEL_PATH, model_identifier)
|
||||
return model_identifier
|
||||
|
||||
|
||||
# Model paths used in e2e_grpc tests
|
||||
# These can be either HuggingFace identifiers or local paths (depending on ROUTER_LOCAL_MODEL_PATH)
|
||||
|
||||
# Main test model - Llama 3.1 8B Instruct
|
||||
DEFAULT_MODEL_PATH = _get_model_path("meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
# Small models for function calling tests
|
||||
DEFAULT_SMALL_MODEL_PATH = _get_model_path("meta-llama/Llama-3.2-1B-Instruct")
|
||||
|
||||
# Reasoning models
|
||||
DEFAULT_REASONING_MODEL_PATH = _get_model_path(
|
||||
"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
|
||||
)
|
||||
|
||||
# Thinking-enabled models
|
||||
DEFAULT_ENABLE_THINKING_MODEL_PATH = _get_model_path("Qwen/Qwen3-30B-A3B")
|
||||
|
||||
# Function calling models
|
||||
DEFAULT_QWEN_FUNCTION_CALLING_MODEL_PATH = _get_model_path("Qwen/Qwen2.5-7B-Instruct")
|
||||
DEFAULT_MISTRAL_FUNCTION_CALLING_MODEL_PATH = _get_model_path(
|
||||
"mistralai/Mistral-7B-Instruct-v0.3"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Process Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
|
||||
"""
|
||||
Kill the process and all its child processes.
|
||||
|
||||
Args:
|
||||
parent_pid: PID of the parent process
|
||||
include_parent: Whether to kill the parent process itself
|
||||
skip_pid: Optional PID to skip during cleanup
|
||||
"""
|
||||
# Remove sigchld handler to avoid spammy logs
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
|
||||
|
||||
if parent_pid is None:
|
||||
parent_pid = os.getpid()
|
||||
include_parent = False
|
||||
|
||||
try:
|
||||
itself = psutil.Process(parent_pid)
|
||||
except psutil.NoSuchProcess:
|
||||
return
|
||||
|
||||
children = itself.children(recursive=True)
|
||||
for child in children:
|
||||
if child.pid == skip_pid:
|
||||
continue
|
||||
try:
|
||||
child.kill()
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
if include_parent:
|
||||
try:
|
||||
itself.kill()
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tokenizer Utilities
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def check_gguf_file(model_path: str) -> bool:
|
||||
"""Check if the model path points to a GGUF file."""
|
||||
if not isinstance(model_path, str):
|
||||
return False
|
||||
return model_path.endswith(".gguf")
|
||||
|
||||
|
||||
def is_remote_url(path: str) -> bool:
|
||||
"""Check if the path is a remote URL."""
|
||||
if not isinstance(path, str):
|
||||
return False
|
||||
return path.startswith("http://") or path.startswith("https://")
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
tokenizer_name: str,
|
||||
*args,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
"""
|
||||
Gets a tokenizer for the given model name via Huggingface.
|
||||
|
||||
Args:
|
||||
tokenizer_name: Name or path of the tokenizer
|
||||
tokenizer_mode: Mode for tokenizer loading ("auto", "slow")
|
||||
trust_remote_code: Whether to trust remote code
|
||||
tokenizer_revision: Specific revision to use
|
||||
**kwargs: Additional arguments passed to AutoTokenizer.from_pretrained
|
||||
|
||||
Returns:
|
||||
Loaded tokenizer instance
|
||||
"""
|
||||
if tokenizer_mode == "slow":
|
||||
if kwargs.get("use_fast", False):
|
||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
kwargs["use_fast"] = False
|
||||
|
||||
# Handle special model name mapping
|
||||
if tokenizer_name == "mistralai/Devstral-Small-2505":
|
||||
tokenizer_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
|
||||
is_gguf = check_gguf_file(tokenizer_name)
|
||||
if is_gguf:
|
||||
kwargs["gguf_file"] = tokenizer_name
|
||||
tokenizer_name = Path(tokenizer_name).parent
|
||||
|
||||
# Note: Removed remote URL handling and local directory download
|
||||
# as they depend on sglang-specific utilities
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
**kwargs,
|
||||
)
|
||||
except TypeError as e:
|
||||
# Handle specific errors
|
||||
err_msg = (
|
||||
"Failed to load the tokenizer. If you are running a model with "
|
||||
"a custom tokenizer, please set the --trust-remote-code flag."
|
||||
)
|
||||
raise RuntimeError(err_msg) from e
|
||||
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
print(
|
||||
f"Warning: Using a slow tokenizer. This might cause a performance "
|
||||
f"degradation. Consider using a fast tokenizer instead."
|
||||
)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_tokenizer_from_processor(processor):
|
||||
"""Extract tokenizer from a processor object."""
|
||||
if isinstance(processor, PreTrainedTokenizerBase):
|
||||
return processor
|
||||
return processor.tokenizer
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Utilities
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CustomTestCase(unittest.TestCase):
|
||||
"""
|
||||
Custom test case base class with retry support.
|
||||
|
||||
This provides automatic test retry functionality based on environment variables.
|
||||
"""
|
||||
|
||||
def _callTestMethod(self, method):
|
||||
"""Override to add retry logic."""
|
||||
max_retry = int(os.environ.get("SGLANG_TEST_MAX_RETRY", "0"))
|
||||
|
||||
if max_retry == 0:
|
||||
# No retry, just run once
|
||||
return super(CustomTestCase, self)._callTestMethod(method)
|
||||
|
||||
# Retry logic
|
||||
for attempt in range(max_retry + 1):
|
||||
try:
|
||||
return super(CustomTestCase, self)._callTestMethod(method)
|
||||
except Exception as e:
|
||||
if attempt < max_retry:
|
||||
print(
|
||||
f"Test failed on attempt {attempt + 1}/{max_retry + 1}, retrying..."
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Last attempt, re-raise the exception
|
||||
raise
|
||||
|
||||
def setUp(self):
|
||||
"""Print test method name at the start of each test."""
|
||||
print(f"[Test Method] {self._testMethodName}", flush=True)
|
||||
Reference in New Issue
Block a user