diff --git a/.github/workflows/pr-test-pd-router.yml b/.github/workflows/pr-test-pd-router.yml new file mode 100644 index 000000000..bb4eaf4c6 --- /dev/null +++ b/.github/workflows/pr-test-pd-router.yml @@ -0,0 +1,249 @@ +name: Test Disaggregation Mode + +on: + push: + branches: [ main ] + paths: + - 'python/sglang/srt/disaggregation/**' + - 'scripts/ci_start_disaggregation_servers.sh' + - 'sgl-router/**' + pull_request: + branches: [ main ] + paths: + - 'python/sglang/srt/disaggregation/**' + - 'scripts/ci_start_disaggregation_servers.sh' + - 'sgl-router/**' + workflow_dispatch: + +concurrency: + group: test-disaggregation-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + pull-requests: write + issues: write + +jobs: + test-disaggregation: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: [h200] + timeout-minutes: 45 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 10 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Setup Rust + run: | + bash scripts/ci_install_rust.sh + + - name: Cache Rust dependencies + uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + sgl-router/target/ + key: ${{ runner.os }}-cargo-${{ hashFiles('sgl-router/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Cache pip dependencies + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Validate environment + run: | + echo "=== System Validation ===" + nvidia-smi + echo "GPU count: $(nvidia-smi -L | wc -l)" + if [ $(nvidia-smi -L | wc -l) -lt 8 ]; then + echo "Error: This test requires at least 8 GPUs" + exit 1 + fi + + echo "=== RDMA Validation ===" + if ! command -v ibv_devices >/dev/null 2>&1; then + echo "Error: InfiniBand tools not found" + exit 1 + fi + + # Check for active IB devices + found_active_device=false + for device in mlx5_{0..11}; do + if ibv_devinfo $device >/dev/null 2>&1; then + state=$(ibv_devinfo $device | grep "state:" | head -1 | awk '{print $2}') + if [[ "$state" == "PORT_ACTIVE" ]]; then + echo "✓ Found active device: $device" + found_active_device=true + break + fi + fi + done + + if [ "$found_active_device" = false ]; then + echo "Error: No active IB devices found" + echo "Available devices:" + ibv_devices || true + exit 1 + fi + + echo "=== Model Validation ===" + if [ ! -d "/raid/models/meta-llama/Llama-3.1-8B-Instruct" ]; then + echo "Error: Model not found" + ls -la /raid/models/ || echo "No models directory" + exit 1 + fi + echo "✓ Model found" + + - name: Install SGLang dependencies + run: | + echo "Installing SGLang with all extras..." + python3 -m pip --no-cache-dir install -e "python[all]" --break-system-packages + python3 -m pip --no-cache-dir install mooncake-transfer-engine==0.3.4.post1 + + - name: Build and install sgl-router + run: | + source "$HOME/.cargo/env" + echo "Building sgl-router..." + cd sgl-router + cargo build && python3 -m build && pip install --force-reinstall dist/*.whl + + - name: Start disaggregation servers + id: start_servers + run: | + echo "Starting disaggregation servers..." + bash scripts/ci_start_disaggregation_servers.sh & + SERVER_PID=$! + echo "server_pid=$SERVER_PID" >> $GITHUB_OUTPUT + + echo "Waiting for router to become healthy..." + TIMEOUT=300 + ELAPSED=0 + while [ $ELAPSED -lt $TIMEOUT ]; do + if curl --connect-timeout 5 --silent http://127.0.0.9:8000 > /dev/null 2>&1; then + echo "✓ Router is reachable" + break + fi + if ! ps -p $SERVER_PID > /dev/null; then + echo "Error: Server processes failed to start" + exit 1 + fi + echo "Waiting for router... (${ELAPSED}s/${TIMEOUT}s)" + sleep 10 + ELAPSED=$((ELAPSED + 10)) + done + + if [ $ELAPSED -ge $TIMEOUT ]; then + echo "Error: Router health check timeout after ${TIMEOUT}s" + exit 1 + fi + + echo "✓ Servers started and healthy (PID: $SERVER_PID)" + + - name: Test API functionality + timeout-minutes: 5 + run: | + BASE_URL="http://127.0.0.9:8000" + + echo "Testing API completions..." + response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer test-token" \ + -d '{ + "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct", + "messages": [ + {"role": "user", "content": "Write a Python function to calculate fibonacci numbers recursively"} + ], + "stream": false, + "max_tokens": 100 + }') + + if echo "$response" | jq -e '.choices[0].message.content' > /dev/null 2>&1; then + echo "✓ API test passed" + else + echo "✗ API test failed: $response" + exit 1 + fi + + echo "Testing streaming API..." + stream_response=$(timeout 30 curl -s -X POST "$BASE_URL/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer test-token" \ + -d '{ + "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct", + "messages": [ + {"role": "user", "content": "Count from 1 to 5"} + ], + "stream": true, + "max_tokens": 50 + }') + + if echo "$stream_response" | grep -q "data:"; then + echo "✓ Streaming API test passed" + else + echo "✗ Streaming API test failed" + exit 1 + fi + + - name: Run benchmark test + timeout-minutes: 5 + run: | + echo "Running benchmark test..." + benchmark_output=$(python3 -m sglang.bench_one_batch_server \ + --model-path "/raid/models/meta-llama/Llama-3.1-8B-Instruct" \ + --base-url "http://127.0.0.9:8000" \ + --batch-size 8 \ + --input-len 4096 \ + --output-len 5 \ + --skip-warmup) + + echo "$benchmark_output" + + # Extract metrics from output + latency=$(echo "$benchmark_output" | grep "latency:" | awk '{print $2}' | sed 's/s//') + input_throughput=$(echo "$benchmark_output" | grep "input throughput:" | awk '{print $3}') + output_throughput=$(echo "$benchmark_output" | grep "output throughput:" | awk '{print $3}') + + # Validate performance (latency<1.5s, input>20k, output>1k) + command -v bc >/dev/null || (apt-get update && apt-get install -y bc) + + echo "Performance: ${latency}s | ${input_throughput} | ${output_throughput} tok/s" + + fail="" + (( $(echo "$latency > 1.5" | bc -l) )) && fail="Latency too high (${latency}s>1.5s) " + (( $(echo "$input_throughput < 20000" | bc -l) )) && fail="${fail}Input too low (${input_throughput}<20k) " + (( $(echo "$output_throughput < 1000" | bc -l) )) && fail="${fail}Output too low (${output_throughput}<1k) " + + if [ -n "$fail" ]; then + echo "✗ Benchmark failed: $fail" + exit 1 + else + echo "✓ Performance validation passed" + fi + + - name: Cleanup servers + if: always() + run: | + if [ -n "${{ steps.start_servers.outputs.server_pid }}" ]; then + pkill -P ${{ steps.start_servers.outputs.server_pid }} || true + kill ${{ steps.start_servers.outputs.server_pid }} || true + fi + pkill -f "sglang.launch_server" || true + sleep 5 + remaining=$(ps aux | grep -c "sglang.launch_server" || echo "0") + echo "Cleanup completed. Remaining processes: $remaining" diff --git a/python/pyproject.toml b/python/pyproject.toml index ca098f81b..b509618bb 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -18,6 +18,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle [project.optional-dependencies] runtime_common = [ "blobfile==3.0.0", + "build", "compressed-tensors", "datasets", "fastapi", diff --git a/scripts/ci_start_disaggregation_servers.sh b/scripts/ci_start_disaggregation_servers.sh new file mode 100755 index 000000000..f652a4f04 --- /dev/null +++ b/scripts/ci_start_disaggregation_servers.sh @@ -0,0 +1,106 @@ +#!/bin/bash + +MODEL_PATH="/raid/models/meta-llama/Llama-3.1-8B-Instruct" + +# Function to find the first available active IB device +find_active_ib_device() { + for device in mlx5_{0..11}; do + if ibv_devinfo $device >/dev/null 2>&1; then + state=$(ibv_devinfo $device | grep "state:" | head -1 | awk '{print $2}') + if [[ "$state" == "PORT_ACTIVE" ]]; then + echo "$device" + return 0 + fi + fi + done + echo "No active IB device found" >&2 + return 1 +} + +# Get the first available active IB device +DEVICE=$(find_active_ib_device) +echo "Using IB device: $DEVICE" + +# Launch prefill servers on GPU 0–3 +for i in {0..3}; do + PORT=$((30001 + i)) + BOOTSTRAP_PORT=$((9001 + i)) + HOST="127.0.0.$((i + 1))" + echo "Launching PREFILL server on GPU $i at $HOST:$PORT (bootstrap: $BOOTSTRAP_PORT)" + CUDA_VISIBLE_DEVICES=$i \ + python3 -m sglang.launch_server \ + --model-path "$MODEL_PATH" \ + --disaggregation-mode prefill \ + --host "$HOST" \ + --port "$PORT" \ + --disaggregation-ib-device "$DEVICE" \ + --disaggregation-bootstrap-port "$BOOTSTRAP_PORT" & +done + +# Launch decode servers on GPU 4–7 +for i in {4..7}; do + PORT=$((30001 + i)) + HOST="127.0.0.$((i + 1))" + echo "Launching DECODE server on GPU $i at $HOST:$PORT" + CUDA_VISIBLE_DEVICES=$i \ + python3 -m sglang.launch_server \ + --model-path "$MODEL_PATH" \ + --disaggregation-mode decode \ + --host "$HOST" \ + --port "$PORT" \ + --disaggregation-ib-device "$DEVICE" \ + --base-gpu-id 0 & +done + +# Wait for disaggregation servers to initialize +echo "Waiting for disaggregation servers to initialize..." + +# Health check with 5-minute timeout +TIMEOUT=300 +START_TIME=$(date +%s) + +echo "Checking health of all 8 servers..." +while true; do + CURRENT_TIME=$(date +%s) + ELAPSED=$((CURRENT_TIME - START_TIME)) + + if [ $ELAPSED -ge $TIMEOUT ]; then + echo "❌ Timeout: Servers did not become healthy within 5 minutes" + exit 1 + fi + + HEALTHY_COUNT=0 + # Check all 8 servers (127.0.0.1-8:30001-30008) + for i in {1..8}; do + if curl -s -f "http://127.0.0.$i:$((30000 + i))/health" >/dev/null 2>&1; then + HEALTHY_COUNT=$((HEALTHY_COUNT + 1)) + fi + done + + echo "Healthy servers: $HEALTHY_COUNT/8 (elapsed: ${ELAPSED}s)" + + if [ $HEALTHY_COUNT -eq 8 ]; then + echo "✅ All 8 servers are healthy!" + break + else + sleep 10 # Wait 10 seconds before next check + fi +done + +# Launch the router +echo "Launching router at 127.0.0.9:8000..." +python3 -m sglang_router.launch_router \ + --pd-disaggregation \ + --policy power_of_two \ + --prefill http://127.0.0.1:30001 9001 \ + --prefill http://127.0.0.2:30002 9002 \ + --prefill http://127.0.0.3:30003 9003 \ + --prefill http://127.0.0.4:30004 9004 \ + --decode http://127.0.0.5:30005 \ + --decode http://127.0.0.6:30006 \ + --decode http://127.0.0.7:30007 \ + --decode http://127.0.0.8:30008 \ + --host 127.0.0.9 \ + --port 8000 & + +wait # Wait for all background jobs to finish