sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
60
scripts/ci/amd_ci_exec.sh
Executable file
60
scripts/ci/amd_ci_exec.sh
Executable file
@@ -0,0 +1,60 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Detect GPU family from hostname (e.g., linux-mi35x-gpu-1-xxxxx-runner-zzzzz)
|
||||
HOSTNAME_VALUE=$(hostname)
|
||||
GPU_FAMILY=""
|
||||
|
||||
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
|
||||
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
|
||||
GPU_FAMILY="${BASH_REMATCH[1]}"
|
||||
echo "Detected GPU family from hostname: ${GPU_FAMILY}"
|
||||
else
|
||||
echo "Warning: could not parse GPU family from '${HOSTNAME_VALUE}'"
|
||||
fi
|
||||
|
||||
WORKDIR="/sglang-checkout/test/srt"
|
||||
declare -A ENV_MAP=(
|
||||
[SGLANG_AMD_CI]=1
|
||||
[SGLANG_IS_IN_CI]=1
|
||||
[SGLANG_USE_AITER]=1
|
||||
)
|
||||
|
||||
# Conditionally add GPU_ARCHS only for mi35x
|
||||
if [[ "${GPU_FAMILY}" == "mi35x" ]]; then
|
||||
ENV_MAP[GPU_ARCHS]="gfx950"
|
||||
fi
|
||||
|
||||
# Parse -w/--workdir and -e ENV=VAL
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
-w|--workdir)
|
||||
WORKDIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
-e)
|
||||
IFS="=" read -r key val <<< "$2"
|
||||
ENV_MAP["$key"]="$val"
|
||||
shift 2
|
||||
;;
|
||||
--)
|
||||
shift
|
||||
break
|
||||
;;
|
||||
*)
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Build final ENV_ARGS
|
||||
ENV_ARGS=()
|
||||
for key in "${!ENV_MAP[@]}"; do
|
||||
ENV_ARGS+=("-e" "$key=${ENV_MAP[$key]}")
|
||||
done
|
||||
|
||||
# Run docker exec
|
||||
docker exec \
|
||||
-w "$WORKDIR" \
|
||||
"${ENV_ARGS[@]}" \
|
||||
ci_sglang "$@"
|
||||
47
scripts/ci/amd_ci_install_dependency.sh
Executable file
47
scripts/ci/amd_ci_install_dependency.sh
Executable file
@@ -0,0 +1,47 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
HOSTNAME_VALUE=$(hostname)
|
||||
GPU_ARCH="mi30x" # default
|
||||
|
||||
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
|
||||
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
|
||||
GPU_ARCH="${BASH_REMATCH[1]}"
|
||||
echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
|
||||
else
|
||||
echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
|
||||
fi
|
||||
|
||||
# Install the required dependencies in CI.
|
||||
docker exec ci_sglang pip install --upgrade pip
|
||||
docker exec ci_sglang pip uninstall sgl-kernel -y || true
|
||||
docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"
|
||||
|
||||
case "${GPU_ARCH}" in
|
||||
mi35x)
|
||||
echo "Runner uses ${GPU_ARCH}; will fetch mi35x image."
|
||||
docker exec ci_sglang pip install -e "python[dev_hip]" --no-deps # TODO: only for mi35x
|
||||
# For lmms_evals evaluating MMMU
|
||||
docker exec -w / ci_sglang git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
||||
docker exec -w /lmms-eval ci_sglang pip install -e . --no-deps # TODO: only for mi35x
|
||||
;;
|
||||
mi30x|mi300|mi325)
|
||||
echo "Runner uses ${GPU_ARCH}; will fetch mi30x image."
|
||||
docker exec ci_sglang pip install -e "python[dev_hip]"
|
||||
# For lmms_evals evaluating MMMU
|
||||
docker exec -w / ci_sglang git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
||||
docker exec -w /lmms-eval ci_sglang pip install -e .
|
||||
;;
|
||||
*)
|
||||
echo "Runner architecture '${GPU_ARCH}' unrecognised;" >&2
|
||||
;;
|
||||
esac
|
||||
|
||||
docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git
|
||||
docker exec -w /human-eval ci_sglang pip install -e .
|
||||
|
||||
docker exec -w / ci_sglang mkdir -p /dummy-grok
|
||||
mkdir -p dummy-grok && wget https://sharkpublic.blob.core.windows.net/sharkpublic/sglang/dummy_grok.json -O dummy-grok/config.json
|
||||
docker cp ./dummy-grok ci_sglang:/
|
||||
|
||||
docker exec ci_sglang pip install huggingface_hub[hf_xet]
|
||||
docker exec ci_sglang pip install pytest
|
||||
132
scripts/ci/amd_ci_start_container.sh
Executable file
132
scripts/ci/amd_ci_start_container.sh
Executable file
@@ -0,0 +1,132 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Get version from SGLang version.py file
|
||||
SGLANG_VERSION_FILE="$(dirname "$0")/../../python/sglang/version.py"
|
||||
SGLANG_VERSION="v0.5.0rc0" # Default version, will be overridden if version.py is found
|
||||
|
||||
if [ -f "$SGLANG_VERSION_FILE" ]; then
|
||||
VERSION_FROM_FILE=$(python3 -c '
|
||||
import re, sys
|
||||
with open(sys.argv[1], "r") as f:
|
||||
content = f.read()
|
||||
match = re.search(r"__version__\s*=\s*[\"'"'"'](.*?)[\"'"'"']", content)
|
||||
if match:
|
||||
print("v" + match.group(1))
|
||||
' "$SGLANG_VERSION_FILE" 2>/dev/null || echo "")
|
||||
|
||||
if [ -n "$VERSION_FROM_FILE" ]; then
|
||||
SGLANG_VERSION="$VERSION_FROM_FILE"
|
||||
echo "Using SGLang version from version.py: $SGLANG_VERSION"
|
||||
else
|
||||
echo "Warning: Could not parse version from $SGLANG_VERSION_FILE, using default: $SGLANG_VERSION" >&2
|
||||
fi
|
||||
else
|
||||
echo "Warning: version.py not found, using default version: $SGLANG_VERSION" >&2
|
||||
fi
|
||||
|
||||
|
||||
# Default base tags (can be overridden by command line arguments)
|
||||
DEFAULT_MI30X_BASE_TAG="${SGLANG_VERSION}-rocm630-mi30x"
|
||||
DEFAULT_MI35X_BASE_TAG="${SGLANG_VERSION}-rocm700-mi35x"
|
||||
|
||||
# Parse command line arguments
|
||||
MI30X_BASE_TAG="${DEFAULT_MI30X_BASE_TAG}"
|
||||
MI35X_BASE_TAG="${DEFAULT_MI35X_BASE_TAG}"
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--mi30x-base-tag) MI30X_BASE_TAG="$2"; shift 2;;
|
||||
--mi35x-base-tag) MI35X_BASE_TAG="$2"; shift 2;;
|
||||
-h|--help)
|
||||
echo "Usage: $0 [--mi30x-base-tag TAG] [--mi35x-base-tag TAG]"
|
||||
exit 0
|
||||
;;
|
||||
*) echo "Unknown option $1"; exit 1;;
|
||||
esac
|
||||
done
|
||||
|
||||
|
||||
|
||||
# Detect GPU architecture from the Kubernetes runner hostname
|
||||
HOSTNAME_VALUE=$(hostname)
|
||||
GPU_ARCH="mi30x" # default
|
||||
|
||||
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
|
||||
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
|
||||
GPU_ARCH="${BASH_REMATCH[1]}"
|
||||
echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
|
||||
else
|
||||
echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
|
||||
fi
|
||||
|
||||
# Normalise / collapse architectures we don’t yet build specifically for
|
||||
case "${GPU_ARCH}" in
|
||||
mi35x)
|
||||
echo "Runner uses ${GPU_ARCH}; will fetch mi35x image."
|
||||
;;
|
||||
mi30x|mi300|mi325)
|
||||
echo "Runner uses ${GPU_ARCH}; will fetch mi30x image."
|
||||
GPU_ARCH="mi30x"
|
||||
;;
|
||||
*)
|
||||
echo "Runner architecture '${GPU_ARCH}' unrecognised; defaulting to mi30x image." >&2
|
||||
GPU_ARCH="mi30x"
|
||||
;;
|
||||
esac
|
||||
|
||||
|
||||
# Set up DEVICE_FLAG based on Kubernetes pod info
|
||||
if [[ -f /etc/podinfo/gha-render-devices ]]; then
|
||||
DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
|
||||
else
|
||||
DEVICE_FLAG="--device /dev/dri"
|
||||
fi
|
||||
|
||||
|
||||
# Find the latest image
|
||||
find_latest_image() {
|
||||
local gpu_arch=$1
|
||||
local base_tag days_back image_tag
|
||||
|
||||
case "${gpu_arch}" in
|
||||
mi30x) base_tag="${MI30X_BASE_TAG}" ;;
|
||||
mi35x) base_tag="${MI35X_BASE_TAG}" ;;
|
||||
*) echo "Error: unsupported GPU architecture '${gpu_arch}'" >&2; return 1 ;;
|
||||
esac
|
||||
|
||||
for days_back in {0..6}; do
|
||||
image_tag="${base_tag}-$(date -d "${days_back} days ago" +%Y%m%d)"
|
||||
echo "Checking for image: rocm/sgl-dev:${image_tag}" >&2
|
||||
if docker manifest inspect "rocm/sgl-dev:${image_tag}" >/dev/null 2>&1; then
|
||||
echo "Found available image: rocm/sgl-dev:${image_tag}" >&2
|
||||
echo "rocm/sgl-dev:${image_tag}"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Error: no ${gpu_arch} image found in the last 7 days for base ${base_tag}" >&2
|
||||
echo "Using hard-coded fallback…" >&2
|
||||
if [[ "${gpu_arch}" == "mi35x" ]]; then
|
||||
echo "rocm/sgl-dev:v0.5.0rc0-rocm700-mi35x-20250812"
|
||||
else
|
||||
echo "rocm/sgl-dev:v0.5.0rc0-rocm630-mi30x-20250812"
|
||||
fi
|
||||
}
|
||||
|
||||
# Pull and run the latest image
|
||||
IMAGE=$(find_latest_image "${GPU_ARCH}")
|
||||
echo "Pulling Docker image: ${IMAGE}"
|
||||
docker pull "${IMAGE}"
|
||||
|
||||
echo "Launching container: ci_sglang"
|
||||
docker run -dt --user root --device=/dev/kfd ${DEVICE_FLAG} \
|
||||
-v "${GITHUB_WORKSPACE:-$PWD}:/sglang-checkout" \
|
||||
--ipc=host --group-add video \
|
||||
--shm-size 32g \
|
||||
--cap-add=SYS_PTRACE \
|
||||
-e HF_TOKEN="${HF_TOKEN:-}" \
|
||||
--security-opt seccomp=unconfined \
|
||||
-w /sglang-checkout \
|
||||
--name ci_sglang \
|
||||
"${IMAGE}"
|
||||
68
scripts/ci/ci_install_deepep.sh
Executable file
68
scripts/ci/ci_install_deepep.sh
Executable file
@@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
# Install the dependency in CI.
|
||||
set -euxo pipefail
|
||||
|
||||
bash scripts/ci/ci_install_dependency.sh
|
||||
|
||||
export GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/
|
||||
export NVSHMEM_DIR=/opt/nvshmem/install
|
||||
export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
|
||||
export PATH="${NVSHMEM_DIR}/bin:$PATH"
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
|
||||
if python3 -c "import deep_ep" >/dev/null 2>&1; then
|
||||
echo "deep_ep is already installed or importable. Skipping installation."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Install system dependencies
|
||||
apt install -y curl wget git sudo libibverbs-dev rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 build-essential cmake
|
||||
|
||||
# Install GDRCopy
|
||||
rm -rf /opt/gdrcopy && mkdir -p /opt/gdrcopy
|
||||
rm -rf /opt/nvshmem && mkdir -p /opt/nvshmem
|
||||
cd /opt/gdrcopy
|
||||
git clone https://github.com/NVIDIA/gdrcopy.git .
|
||||
git checkout v2.4.4
|
||||
apt update
|
||||
apt install -y nvidia-dkms-535
|
||||
apt install -y build-essential devscripts debhelper fakeroot pkg-config dkms
|
||||
apt install -y check libsubunit0 libsubunit-dev python3-venv
|
||||
cd packages
|
||||
CUDA=/usr/local/cuda ./build-deb-packages.sh
|
||||
dpkg -i gdrdrv-dkms_*.deb
|
||||
dpkg -i libgdrapi_*.deb
|
||||
dpkg -i gdrcopy-tests_*.deb
|
||||
dpkg -i gdrcopy_*.deb
|
||||
|
||||
if [ ! -e "/usr/lib/x86_64-linux-gnu/libmlx5.so" ]; then
|
||||
ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so
|
||||
fi
|
||||
apt-get update && apt-get install -y libfabric-dev
|
||||
|
||||
# Install NVSHMEM
|
||||
cd /opt/nvshmem
|
||||
wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.3.9/source/nvshmem_src_cuda12-all-all-3.3.9.tar.gz
|
||||
tar -xf nvshmem_src_cuda12-all-all-3.3.9.tar.gz
|
||||
mv nvshmem_src nvshmem && cd nvshmem
|
||||
NVSHMEM_SHMEM_SUPPORT=0 \
|
||||
NVSHMEM_UCX_SUPPORT=0 \
|
||||
NVSHMEM_USE_NCCL=0 \
|
||||
NVSHMEM_MPI_SUPPORT=0 \
|
||||
NVSHMEM_IBGDA_SUPPORT=1 \
|
||||
NVSHMEM_PMIX_SUPPORT=0 \
|
||||
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
|
||||
NVSHMEM_USE_GDRCOPY=1 \
|
||||
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90
|
||||
cd build
|
||||
make -j$(nproc) install
|
||||
|
||||
# Install DeepEP
|
||||
rm -rf /root/.cache/deepep && git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep && cd /root/.cache/deepep && git checkout b92d0d4860ce6866cd6d31bfbae937f9a7a3772b
|
||||
cd /root/.cache/deepep && python3 setup.py install
|
||||
|
||||
# Verify configuration
|
||||
echo "=== Verify GDRCOPY ==="
|
||||
gdrcopy_copybw
|
||||
echo "=== Verify NVSHMEM ==="
|
||||
nvshmem-info -a
|
||||
81
scripts/ci/ci_install_dependency.sh
Executable file
81
scripts/ci/ci_install_dependency.sh
Executable file
@@ -0,0 +1,81 @@
|
||||
#!/bin/bash
|
||||
# Install the dependency in CI.
|
||||
set -euxo pipefail
|
||||
|
||||
IS_BLACKWELL=${IS_BLACKWELL:-0}
|
||||
|
||||
if [ "$IS_BLACKWELL" = "1" ]; then
|
||||
CU_VERSION="cu129"
|
||||
else
|
||||
CU_VERSION="cu126"
|
||||
fi
|
||||
|
||||
# Clear torch compilation cache
|
||||
python3 -c 'import os, shutil, tempfile, getpass; cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") or os.path.join(tempfile.gettempdir(), "torchinductor_" + getpass.getuser()); shutil.rmtree(cache_dir, ignore_errors=True)'
|
||||
|
||||
# Kill existing processes
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
bash "${SCRIPT_DIR}/../killall_sglang.sh"
|
||||
|
||||
# Install apt packages
|
||||
apt install -y git libnuma-dev
|
||||
|
||||
# Install uv
|
||||
if [ "$IS_BLACKWELL" = "1" ]; then
|
||||
# The blackwell CI runner has some issues with pip and uv,
|
||||
# so we can only use pip with `--break-system-packages`
|
||||
PIP_CMD="pip"
|
||||
PIP_INSTALL_SUFFIX="--break-system-packages"
|
||||
|
||||
# Clean up existing installations
|
||||
$PIP_CMD uninstall -y flashinfer_python sgl-kernel sglang vllm $PIP_INSTALL_SUFFIX || true
|
||||
else
|
||||
# In normal cases, we use uv, which is much faster than pip.
|
||||
pip install --upgrade pip
|
||||
pip install uv
|
||||
export UV_SYSTEM_PYTHON=true
|
||||
|
||||
PIP_CMD="uv pip"
|
||||
PIP_INSTALL_SUFFIX="--index-strategy unsafe-best-match"
|
||||
|
||||
# Clean up existing installations
|
||||
$PIP_CMD uninstall flashinfer_python sgl-kernel sglang vllm || true
|
||||
fi
|
||||
|
||||
# Install the main package
|
||||
$PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX
|
||||
|
||||
# Install router for pd-disagg test
|
||||
SGLANG_ROUTER_BUILD_NO_RUST=1 $PIP_CMD install -e "sgl-router" $PIP_INSTALL_SUFFIX
|
||||
|
||||
|
||||
SGL_KERNEL_VERSION=0.3.9.post2
|
||||
if [ "$IS_BLACKWELL" = "1" ]; then
|
||||
# TODO auto determine sgl-kernel version
|
||||
$PIP_CMD install https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu128-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall $PIP_INSTALL_SUFFIX
|
||||
else
|
||||
$PIP_CMD install https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu124-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall $PIP_INSTALL_SUFFIX
|
||||
fi
|
||||
|
||||
# Show current packages
|
||||
$PIP_CMD list
|
||||
|
||||
# Install additional dependencies
|
||||
$PIP_CMD install mooncake-transfer-engine==0.3.5 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
|
||||
|
||||
if [ "$IS_BLACKWELL" != "1" ]; then
|
||||
# For lmms_evals evaluating MMMU
|
||||
git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
||||
$PIP_CMD install -e lmms-eval/ $PIP_INSTALL_SUFFIX
|
||||
|
||||
# Install xformers
|
||||
$PIP_CMD install xformers --index-url https://download.pytorch.org/whl/${CU_VERSION} --no-deps $PIP_INSTALL_SUFFIX
|
||||
fi
|
||||
|
||||
# Install FlashMLA for attention backend tests
|
||||
# $PIP_CMD install git+https://github.com/deepseek-ai/FlashMLA.git $PIP_INSTALL_SUFFIX
|
||||
|
||||
# Show current packages
|
||||
$PIP_CMD list
|
||||
|
||||
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-}"
|
||||
24
scripts/ci/ci_install_rust.sh
Executable file
24
scripts/ci/ci_install_rust.sh
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/bin/bash
|
||||
set -euxo pipefail
|
||||
|
||||
# Check if sudo is available
|
||||
if command -v sudo >/dev/null 2>&1; then
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libssl-dev pkg-config protobuf-compiler
|
||||
else
|
||||
apt-get update
|
||||
apt-get install -y libssl-dev pkg-config protobuf-compiler
|
||||
fi
|
||||
|
||||
# Install rustup (Rust installer and version manager)
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
|
||||
|
||||
# Follow the installation prompts, then reload your shell
|
||||
. "$HOME/.cargo/env"
|
||||
source $HOME/.cargo/env
|
||||
|
||||
# Verify installation
|
||||
rustc --version
|
||||
cargo --version
|
||||
protoc --version
|
||||
94
scripts/ci/ci_start_disaggregation_servers.sh
Executable file
94
scripts/ci/ci_start_disaggregation_servers.sh
Executable file
@@ -0,0 +1,94 @@
|
||||
#!/bin/bash
|
||||
|
||||
MODEL_PATH="/raid/models/meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
# Function to find the first available active IB device
|
||||
find_active_ib_device() {
|
||||
for device in mlx5_{0..11}; do
|
||||
if ibv_devinfo $device >/dev/null 2>&1; then
|
||||
state=$(ibv_devinfo $device | grep "state:" | head -1 | awk '{print $2}')
|
||||
if [[ "$state" == "PORT_ACTIVE" ]]; then
|
||||
echo "$device"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
done
|
||||
echo "No active IB device found" >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
# Get the first available active IB device
|
||||
DEVICE=$(find_active_ib_device)
|
||||
echo "Using IB device: $DEVICE"
|
||||
|
||||
# Launch prefill servers on GPU 0–3
|
||||
for i in {0..3}; do
|
||||
PORT=$((30001 + i))
|
||||
BOOTSTRAP_PORT=$((9001 + i))
|
||||
HOST="127.0.0.$((i + 1))"
|
||||
echo "Launching PREFILL server on GPU $i at $HOST:$PORT (bootstrap: $BOOTSTRAP_PORT)"
|
||||
CUDA_VISIBLE_DEVICES=$i \
|
||||
python3 -m sglang.launch_server \
|
||||
--model-path "$MODEL_PATH" \
|
||||
--disaggregation-mode prefill \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
--disaggregation-ib-device "$DEVICE" \
|
||||
--disaggregation-bootstrap-port "$BOOTSTRAP_PORT" &
|
||||
done
|
||||
|
||||
# Launch decode servers on GPU 4–7
|
||||
for i in {4..7}; do
|
||||
PORT=$((30001 + i))
|
||||
HOST="127.0.0.$((i + 1))"
|
||||
echo "Launching DECODE server on GPU $i at $HOST:$PORT"
|
||||
CUDA_VISIBLE_DEVICES=$i \
|
||||
python3 -m sglang.launch_server \
|
||||
--model-path "$MODEL_PATH" \
|
||||
--disaggregation-mode decode \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
--disaggregation-ib-device "$DEVICE" \
|
||||
--base-gpu-id 0 &
|
||||
done
|
||||
|
||||
# Wait for disaggregation servers to initialize
|
||||
echo "Waiting for disaggregation servers to initialize..."
|
||||
|
||||
# Health check with 5-minute timeout
|
||||
TIMEOUT=300
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
echo "Checking health of all 8 servers..."
|
||||
while true; do
|
||||
CURRENT_TIME=$(date +%s)
|
||||
ELAPSED=$((CURRENT_TIME - START_TIME))
|
||||
|
||||
if [ $ELAPSED -ge $TIMEOUT ]; then
|
||||
echo "❌ Timeout: Servers did not become healthy within 5 minutes"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
HEALTHY_COUNT=0
|
||||
# Check all 8 servers (127.0.0.1-8:30001-30008)
|
||||
for i in {1..8}; do
|
||||
if curl -s -f "http://127.0.0.$i:$((30000 + i))/health" >/dev/null 2>&1; then
|
||||
HEALTHY_COUNT=$((HEALTHY_COUNT + 1))
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Healthy servers: $HEALTHY_COUNT/8 (elapsed: ${ELAPSED}s)"
|
||||
|
||||
if [ $HEALTHY_COUNT -eq 8 ]; then
|
||||
echo "✅ All 8 servers are healthy!"
|
||||
break
|
||||
else
|
||||
sleep 10 # Wait 10 seconds before next check
|
||||
fi
|
||||
done
|
||||
|
||||
# Don't launch router here - just keep servers running
|
||||
echo "✅ All disaggregation servers are ready and waiting for router connections"
|
||||
|
||||
# Keep the script running
|
||||
wait
|
||||
61
scripts/ci/npu_ci_install_dependency.sh
Executable file
61
scripts/ci/npu_ci_install_dependency.sh
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
PIP_INSTALL="pip install --no-cache-dir"
|
||||
|
||||
|
||||
# Install the required dependencies in CI.
|
||||
apt update -y && apt install -y \
|
||||
build-essential \
|
||||
cmake \
|
||||
wget \
|
||||
curl \
|
||||
net-tools \
|
||||
zlib1g-dev \
|
||||
lld \
|
||||
clang \
|
||||
locales \
|
||||
ccache \
|
||||
ca-certificates
|
||||
update-ca-certificates
|
||||
python3 -m ${PIP_INSTALL} --upgrade pip
|
||||
|
||||
|
||||
### Download MemFabricV2
|
||||
MF_WHL_NAME="mf_adapter-1.0.0-cp311-cp311-linux_aarch64.whl"
|
||||
MEMFABRIC_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/${MF_WHL_NAME}"
|
||||
wget -O "${MF_WHL_NAME}" "${MEMFABRIC_URL}" && ${PIP_INSTALL} "./${MF_WHL_NAME}"
|
||||
|
||||
|
||||
### Install vLLM
|
||||
VLLM_TAG=v0.8.5
|
||||
git clone --depth 1 https://github.com/vllm-project/vllm.git --branch $VLLM_TAG
|
||||
(cd vllm && VLLM_TARGET_DEVICE="empty" ${PIP_INSTALL} -v -e .)
|
||||
|
||||
|
||||
### Install PyTorch and PTA
|
||||
PYTORCH_VERSION=2.6.0
|
||||
TORCHVISION_VERSION=0.21.0
|
||||
${PIP_INSTALL} torch==$PYTORCH_VERSION torchvision==$TORCHVISION_VERSION --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
PTA_VERSION="v7.1.0.1-pytorch2.6.0"
|
||||
PTA_NAME="torch_npu-2.6.0.post1-cp311-cp311-manylinux_2_28_aarch64.whl"
|
||||
PTA_URL="https://gitee.com/ascend/pytorch/releases/download/${PTA_VERSION}/${PTA_NAME}"
|
||||
wget -O "${PTA_NAME}" "${PTA_URL}" && ${PIP_INSTALL} "./${PTA_NAME}"
|
||||
|
||||
|
||||
### Install Triton-Ascend
|
||||
TRITON_ASCEND_NAME="triton_ascend-3.2.0.dev20250729-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl"
|
||||
TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/${TRITON_ASCEND_NAME}"
|
||||
${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil==6.0.0 pytest==8.3.2 pytest-xdist==3.6.1 pyyaml pybind11
|
||||
wget -O "${TRITON_ASCEND_NAME}" "${TRITON_ASCEND_URL}" && ${PIP_INSTALL} "./${TRITON_ASCEND_NAME}"
|
||||
|
||||
|
||||
### Install sgl-kernel-npu
|
||||
SGL_KERNEL_NPU_TAG="20250901"
|
||||
git clone --depth 1 https://github.com/sgl-project/sgl-kernel-npu.git --branch ${SGL_KERNEL_NPU_TAG}
|
||||
(cd sgl-kernel-npu && bash ./build.sh -a deepep && pip install output/deep_ep*.whl && cd "$(pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so)
|
||||
|
||||
|
||||
### Install SGLang
|
||||
${PIP_INSTALL} -v -e "python[srt_npu]"
|
||||
293
scripts/code_sync/copy_from_oss.py
Normal file
293
scripts/code_sync/copy_from_oss.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
Sync code from OSS repo to the local repo and open a PR if changes exist.
|
||||
|
||||
NOTE:
|
||||
1. You need to execute this script in the git root folder.
|
||||
2. A GH_TOKEN environment variable is required to create the pull request.
|
||||
- see also https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
|
||||
|
||||
This script will:
|
||||
1. Clone the sgl-project/sglang repository (or use a local copy).
|
||||
2. Sync specified files and directories using rsync.
|
||||
3. Check if the sync operation resulted in any changes.
|
||||
4. If there are changes:
|
||||
a. Create a new branch.
|
||||
b. Commit and push the changes.
|
||||
c. Open a pull request using the GitHub CLI (gh).
|
||||
|
||||
Usage:
|
||||
# Run the full sync and PR creation process
|
||||
python3 scripts/copy_from_oss.py
|
||||
|
||||
# Perform a dry run without making any actual changes
|
||||
python3 scripts/copy_from_oss.py --dry-run
|
||||
|
||||
# Use a local directory as the source instead of cloning
|
||||
python3 scripts/copy_from_oss.py --local-dir ~/projects/sglang
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
# --- Configuration Begin ---
|
||||
# List of folders and files to copy from the OSS repo.
|
||||
# Changes outside these paths will be ignored.
|
||||
folder_names = [
|
||||
"3rdparty",
|
||||
"assets",
|
||||
"benchmark",
|
||||
"docker",
|
||||
"docs",
|
||||
"examples",
|
||||
"sgl-kernel",
|
||||
"README.md",
|
||||
"python/sglang/lang",
|
||||
"python/sglang/srt",
|
||||
"python/sglang/test",
|
||||
"test/lang",
|
||||
"test/srt",
|
||||
]
|
||||
|
||||
private_repo = "your-org/sglang-private-repo"
|
||||
# --- Configuration End ---
|
||||
|
||||
|
||||
def write_github_step_summary(content):
|
||||
if not os.environ.get("GITHUB_STEP_SUMMARY"):
|
||||
return
|
||||
|
||||
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def check_dependencies():
|
||||
"""Check for required command-line tools."""
|
||||
if not shutil.which("git"):
|
||||
raise EnvironmentError("git is not installed or not in PATH.")
|
||||
if not shutil.which("gh"):
|
||||
raise EnvironmentError("GitHub CLI (gh) is not installed or not in PATH.")
|
||||
print("✅ All dependencies (git, gh) are available.")
|
||||
|
||||
|
||||
def checkout_main(dry_run):
|
||||
"""Checkout to the main branch."""
|
||||
commands = [
|
||||
"git checkout main",
|
||||
"git reset --hard",
|
||||
]
|
||||
for cmd in commands:
|
||||
print(f"Run: {cmd}")
|
||||
if not dry_run:
|
||||
try:
|
||||
subprocess.run(cmd, shell=True, check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Git command failed: {e.stderr.decode()}")
|
||||
raise
|
||||
print("✅ Checkout the main branch.")
|
||||
|
||||
|
||||
def get_source_folder(args):
|
||||
"""
|
||||
Prepare the source repository, either by cloning from GitHub or using a local directory.
|
||||
Returns the path to the source repo root, a temporary directory path (if created),
|
||||
and the short commit hash.
|
||||
"""
|
||||
temp_dir = None
|
||||
if args.local_dir:
|
||||
oss_root = os.path.expanduser(args.local_dir)
|
||||
if not os.path.exists(oss_root):
|
||||
raise FileNotFoundError(
|
||||
f"Specified local directory {oss_root} does not exist."
|
||||
)
|
||||
print(f"Using local directory as the source: {oss_root}")
|
||||
else:
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
oss_root = temp_dir
|
||||
print(f"Created temporary directory: {oss_root}")
|
||||
|
||||
repo_url = "https://github.com/sgl-project/sglang.git"
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"--single-branch",
|
||||
"--branch",
|
||||
"main",
|
||||
repo_url,
|
||||
temp_dir,
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
print(f"Successfully cloned repository to {temp_dir}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error cloning repository: {e.stderr.decode()}")
|
||||
raise
|
||||
|
||||
commit_hash = subprocess.run(
|
||||
["git", "-C", oss_root, "rev-parse", "HEAD"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
).stdout.strip()[:8]
|
||||
print(f"✅ Get source OSS code at commit: {commit_hash}")
|
||||
return oss_root, temp_dir, commit_hash
|
||||
|
||||
|
||||
def sync_directories(oss_root, folder_names, dry_run):
|
||||
"""Sync specified directories from oss_root to current working directory."""
|
||||
rsync_commands = []
|
||||
for folder_name in folder_names:
|
||||
target_name = f"{oss_root}/{folder_name}"
|
||||
src_name = "./" + "/".join(folder_name.split("/")[:-1])
|
||||
cmd = f"rsync -r --delete {target_name} {src_name}"
|
||||
rsync_commands.append(cmd)
|
||||
|
||||
for cmd in rsync_commands:
|
||||
try:
|
||||
print(f"Run: {cmd}")
|
||||
if not dry_run:
|
||||
subprocess.run(cmd, shell=True, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error executing command '{cmd}': {e}")
|
||||
raise
|
||||
print(f"✅ Sync all folders.")
|
||||
|
||||
|
||||
def check_for_changes():
|
||||
"""Check if there are any uncommitted git changes."""
|
||||
# This command exits with 1 if there are changes, 0 otherwise.
|
||||
result = subprocess.run(["git", "diff", "--quiet"])
|
||||
return result.returncode != 0
|
||||
|
||||
|
||||
def create_and_push_branch(branch_name, commit_message, dry_run):
|
||||
"""Create a new branch, commit all changes, and push to origin."""
|
||||
commands = [
|
||||
f"git checkout -b {branch_name}",
|
||||
"git config user.name 'github-actions[bot]'",
|
||||
"git config user.email 'github-actions[bot]@users.noreply.github.com'",
|
||||
"git add .",
|
||||
f"git commit -m '{commit_message}'",
|
||||
f"git push origin {branch_name} --force",
|
||||
]
|
||||
print("\nCreating and pushing git branch...")
|
||||
for cmd in commands:
|
||||
print(f"Run: {cmd}")
|
||||
if not dry_run:
|
||||
try:
|
||||
subprocess.run(cmd, shell=True, check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Git command failed: {e.stderr.decode()}")
|
||||
raise
|
||||
|
||||
|
||||
def create_pull_request(branch_name, title, body, dry_run):
|
||||
"""Create a pull request using the GitHub CLI."""
|
||||
gh_token = os.getenv("GH_TOKEN")
|
||||
if not gh_token:
|
||||
print(
|
||||
"\n⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation."
|
||||
)
|
||||
if not dry_run:
|
||||
return
|
||||
|
||||
print("\nCreating pull request...")
|
||||
command = [
|
||||
"gh",
|
||||
"pr",
|
||||
"create",
|
||||
"--base",
|
||||
"main",
|
||||
"--head",
|
||||
branch_name,
|
||||
"--repo",
|
||||
private_repo,
|
||||
"--title",
|
||||
title,
|
||||
"--body",
|
||||
body,
|
||||
]
|
||||
print(f"Run: {' '.join(command)}")
|
||||
if not dry_run:
|
||||
env = os.environ.copy()
|
||||
env["GH_TOKEN"] = gh_token
|
||||
try:
|
||||
result = subprocess.run(
|
||||
command, check=True, capture_output=True, text=True, env=env
|
||||
)
|
||||
pr_url = result.stdout.strip()
|
||||
msg = f"✅ Successfully created pull request: {pr_url}"
|
||||
print(msg)
|
||||
write_github_step_summary(msg)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error creating pull request: {e.stderr}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Copy code from OSS and open a PR if changes are detected."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
help="Path to local SGLang directory to use instead of cloning from GitHub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Dry run the script without executing git, rsync, or gh commands.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
check_dependencies()
|
||||
checkout_main(args.dry_run)
|
||||
|
||||
oss_root, temp_dir, oss_commit = get_source_folder(args)
|
||||
|
||||
try:
|
||||
# Sync directories
|
||||
sync_directories(oss_root, folder_names, args.dry_run)
|
||||
|
||||
# Check for changes and create PR if necessary
|
||||
if not check_for_changes():
|
||||
msg = "😴 No changes detected. The code is already in sync."
|
||||
print(msg)
|
||||
write_github_step_summary(msg)
|
||||
return
|
||||
|
||||
print("✅ Changes detected. Proceeding to create a PR.")
|
||||
|
||||
current_date = datetime.datetime.now().strftime("%Y%m%d")
|
||||
branch_name = f"copy-from-oss-{oss_commit}-{current_date}"
|
||||
commit_message = f"Copy OSS code from {oss_commit} on {current_date}"
|
||||
pr_title = (
|
||||
f"[Automated PR] Copy OSS code from commit {oss_commit} on {current_date}"
|
||||
)
|
||||
pr_body = (
|
||||
f"Copy OSS code from https://github.com/sgl-project/sglang/commit/{oss_commit} on {current_date}."
|
||||
"\n\n---\n\n"
|
||||
"*This is an automated PR created by scripts/copy_from_oss.py.*"
|
||||
)
|
||||
|
||||
create_and_push_branch(branch_name, commit_message, args.dry_run)
|
||||
create_pull_request(branch_name, pr_title, pr_body, args.dry_run)
|
||||
|
||||
finally:
|
||||
# Remove temporary directory if it was created
|
||||
if temp_dir:
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
print(f"\nRemoved temporary directory: {temp_dir}")
|
||||
except OSError as e:
|
||||
print(f"Error removing temporary directory {temp_dir}: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
425
scripts/code_sync/copy_to_oss.py
Normal file
425
scripts/code_sync/copy_to_oss.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Sync a specific commit from the local private repo to the OSS upstream and open a PR.
|
||||
|
||||
NOTE:
|
||||
1. You need to execute this script in the git root folder.
|
||||
2. A GH_TOKEN environment variable is required to create the pull request.
|
||||
- see also https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
|
||||
|
||||
This script will:
|
||||
1. Take a commit hash as an argument (or use the latest commit by default).
|
||||
2. Create a patch for that commit.
|
||||
3. Filter the patch to only include changes in specified directories.
|
||||
4. Clone the sgl-project/sglang repository.
|
||||
5. Create a new branch in the OSS repo.
|
||||
6. Apply the filtered patch, commit, and force push.
|
||||
7. Open a pull request to the OSS repo using the GitHub CLI (gh).
|
||||
|
||||
Usage:
|
||||
# Sync the latest commit from the current branch
|
||||
python3 scripts/copy_to_oss.py
|
||||
|
||||
# Run the full sync and PR creation process for a given commit
|
||||
python3 scripts/copy_to_oss.py --commit <commit_hash>
|
||||
|
||||
# Perform a dry run without making any actual changes
|
||||
python3 scripts/copy_to_oss.py --commit <commit_hash> --dry-run
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
# --- Configuration Begin ---
|
||||
# List of folders and files to copy to the OSS repo.
|
||||
# Changes outside these paths will be ignored.
|
||||
folder_names = [
|
||||
"3rdparty",
|
||||
"assets",
|
||||
"benchmark",
|
||||
"docker",
|
||||
"docs",
|
||||
"examples",
|
||||
"sgl-kernel",
|
||||
"README.md",
|
||||
"python/sglang/lang",
|
||||
"python/sglang/srt",
|
||||
"python/sglang/test",
|
||||
"test/lang",
|
||||
"test/srt",
|
||||
]
|
||||
|
||||
# --- Configuration End ---
|
||||
|
||||
|
||||
def write_github_step_summary(content):
|
||||
if not os.environ.get("GITHUB_STEP_SUMMARY"):
|
||||
return
|
||||
|
||||
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def get_commit_info(commit_ref):
|
||||
"""
|
||||
Retrieves the hash and message of a specific commit.
|
||||
|
||||
Args:
|
||||
commit_ref (str): The commit hash, tag, or branch to inspect (e.g., 'HEAD').
|
||||
|
||||
Returns:
|
||||
A tuple containing the (commit_hash, commit_message),
|
||||
or (None, None) if an error occurs.
|
||||
"""
|
||||
try:
|
||||
# Use a custom format to get the hash (%H) and the full message (%B)
|
||||
# separated by a null character for safe parsing.
|
||||
command = ["git", "log", "-1", f"--pretty=%H%x00%B", commit_ref]
|
||||
result = subprocess.run(
|
||||
command, capture_output=True, text=True, check=True, encoding="utf-8"
|
||||
)
|
||||
|
||||
# Split the output by the null character separator
|
||||
commit_hash, commit_message = result.stdout.strip().split("\x00", 1)
|
||||
return commit_hash, commit_message
|
||||
|
||||
except FileNotFoundError:
|
||||
print("❌ Error: 'git' command not found. Is Git installed and in your PATH?")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"❌ Error getting commit info for '{commit_ref}': {e.stderr.strip()}")
|
||||
print(
|
||||
"Hint: Make sure you are running this from within a Git repository and the commit exists."
|
||||
)
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def check_dependencies():
|
||||
"""Check for required command-line tools."""
|
||||
if not shutil.which("git"):
|
||||
raise EnvironmentError("git is not installed or not in PATH.")
|
||||
if not shutil.which("gh"):
|
||||
raise EnvironmentError("GitHub CLI (gh) is not installed or not in PATH.")
|
||||
print("✅ All dependencies (git, gh) are available.")
|
||||
|
||||
|
||||
def create_filtered_patch(commit_hash, dry_run):
|
||||
"""
|
||||
Create a patch file for the given commit, containing only changes
|
||||
to files and directories specified in `folder_names`.
|
||||
"""
|
||||
print(f"Creating a filtered patch for commit {commit_hash}")
|
||||
|
||||
try:
|
||||
# Get the list of all files changed in the commit
|
||||
changed_files_raw = subprocess.run(
|
||||
["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
).stdout
|
||||
changed_files = changed_files_raw.strip().split("\n")
|
||||
|
||||
# Filter the list of files
|
||||
relevant_files = [
|
||||
f for f in changed_files if any(f.startswith(path) for path in folder_names)
|
||||
]
|
||||
|
||||
if not relevant_files:
|
||||
msg = "\n😴 No relevant file changes found in this commit. Exiting."
|
||||
print(msg)
|
||||
write_github_step_summary(msg)
|
||||
return None, None
|
||||
|
||||
print("Found relevant changes in the following files:")
|
||||
for f in relevant_files:
|
||||
print(f" - {f}")
|
||||
|
||||
# Create a patch containing only the changes for the relevant files
|
||||
patch_command = [
|
||||
"git",
|
||||
"format-patch",
|
||||
"--stdout",
|
||||
f"{commit_hash}^..{commit_hash}",
|
||||
"--",
|
||||
] + relevant_files
|
||||
|
||||
print(f"Run: {' '.join(patch_command)}")
|
||||
|
||||
patch_content = subprocess.run(
|
||||
patch_command, capture_output=True, text=True, check=True
|
||||
).stdout
|
||||
|
||||
# Save the patch to a temporary file
|
||||
patch_file = tempfile.NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".patch", encoding="utf-8"
|
||||
)
|
||||
patch_file.write(patch_content)
|
||||
patch_file.close()
|
||||
|
||||
print(f"✅ Filtered patch created successfully at: {patch_file.name}")
|
||||
return patch_file.name, relevant_files
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error creating patch: {e.stderr}")
|
||||
raise
|
||||
|
||||
|
||||
def get_oss_repo(dry_run):
|
||||
"""
|
||||
Clones the OSS repository into a temporary directory.
|
||||
Returns the path to the repo root and the temp directory itself.
|
||||
"""
|
||||
gh_token = os.getenv("GH_TOKEN")
|
||||
if not gh_token:
|
||||
print("⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation.")
|
||||
if not dry_run:
|
||||
return
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
oss_root = os.path.join(temp_dir, "sglang")
|
||||
print(f"\nCreated temporary directory for OSS repo: {temp_dir}")
|
||||
|
||||
repo_url = f"https://{gh_token}@github.com/sgl-project/sglang.git"
|
||||
command = ["git", "clone", "--branch", "main", repo_url, oss_root]
|
||||
|
||||
print(f"Run: {' '.join(command)}")
|
||||
if not dry_run:
|
||||
try:
|
||||
subprocess.run(command, check=True, capture_output=True)
|
||||
print(f"✅ Successfully cloned repository to {oss_root}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error cloning repository: {e.stderr.decode()}")
|
||||
shutil.rmtree(temp_dir)
|
||||
raise
|
||||
|
||||
return oss_root, temp_dir
|
||||
|
||||
|
||||
def apply_patch_and_push(oss_root, patch_file, branch_name, commit_message, dry_run):
|
||||
"""
|
||||
In the OSS repo, create a branch, apply the patch, commit, and push.
|
||||
"""
|
||||
print("\nApplying patch and pushing to OSS repo...")
|
||||
|
||||
original_cwd = os.getcwd()
|
||||
if not dry_run:
|
||||
os.chdir(oss_root)
|
||||
|
||||
try:
|
||||
# Define commands as lists to avoid shell injection issues
|
||||
commands_to_run = [
|
||||
["git", "checkout", "-b", branch_name],
|
||||
["git", "apply", patch_file],
|
||||
["git", "config", "user.name", "github-actions[bot]"],
|
||||
[
|
||||
"git",
|
||||
"config",
|
||||
"user.email",
|
||||
"github-actions[bot]@users.noreply.github.com",
|
||||
],
|
||||
["git", "add", "."],
|
||||
]
|
||||
|
||||
for cmd_list in commands_to_run:
|
||||
print(f"Run: {' '.join(cmd_list)}")
|
||||
if not dry_run:
|
||||
subprocess.run(cmd_list, check=True, capture_output=True, text=True)
|
||||
|
||||
# Handle commit separately to pass multi-line message safely via stdin
|
||||
commit_cmd = ["git", "commit", "-F", "-"]
|
||||
print(f"Run: {' '.join(commit_cmd)}")
|
||||
if not dry_run:
|
||||
print(f"Commit Message:\n---\n{commit_message}\n---")
|
||||
subprocess.run(
|
||||
commit_cmd,
|
||||
input=commit_message,
|
||||
text=True,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
# Push the changes
|
||||
push_cmd = ["git", "push", "origin", branch_name, "--force"]
|
||||
print(f"Run: {' '.join(push_cmd)}")
|
||||
if not dry_run:
|
||||
subprocess.run(push_cmd, check=True, capture_output=True, text=True)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Git command failed: {e.stderr}")
|
||||
raise
|
||||
finally:
|
||||
if not dry_run:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
print("✅ Branch created, patch applied, and pushed successfully.")
|
||||
|
||||
|
||||
def create_pull_request(oss_root, branch_name, title, body, dry_run):
|
||||
"""Create a pull request in the OSS repo using the GitHub CLI."""
|
||||
gh_token = os.getenv("GH_TOKEN")
|
||||
if not gh_token:
|
||||
print("⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation.")
|
||||
if not dry_run:
|
||||
return
|
||||
|
||||
print("\nCreating pull request...")
|
||||
command = [
|
||||
"gh",
|
||||
"pr",
|
||||
"create",
|
||||
"--base",
|
||||
"main",
|
||||
"--head",
|
||||
branch_name,
|
||||
"--repo",
|
||||
"sgl-project/sglang",
|
||||
"--title",
|
||||
title,
|
||||
"--body",
|
||||
body,
|
||||
]
|
||||
|
||||
print(f"Run: {' '.join(command)}")
|
||||
if not dry_run:
|
||||
env = os.environ.copy()
|
||||
env["GH_TOKEN"] = gh_token
|
||||
try:
|
||||
result = subprocess.run(
|
||||
command,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=env,
|
||||
cwd=oss_root,
|
||||
)
|
||||
msg = f"✅ Successfully created pull request: {result.stdout.strip()}"
|
||||
print(msg)
|
||||
write_github_step_summary(msg)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error creating pull request: {e.stderr}")
|
||||
# Check if a PR already exists
|
||||
if "A pull request for" in e.stderr and "already exists" in e.stderr:
|
||||
print("ℹ️ A PR for this branch likely already exists.")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def get_commit_author(commit_hash):
|
||||
"""Get the author name and email of a commit."""
|
||||
try:
|
||||
author_name = subprocess.run(
|
||||
["git", "show", "-s", "--format=%an", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
).stdout.strip()
|
||||
author_email = subprocess.run(
|
||||
["git", "show", "-s", "--format=%ae", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
).stdout.strip()
|
||||
return author_name, author_email
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error getting commit author for {commit_hash}: {e.stderr}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Copy a commit from the private repo to OSS and open a PR."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--commit",
|
||||
type=str,
|
||||
default="LAST",
|
||||
help="The commit hash to sync. Defaults to 'LAST' to use the latest commit.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Dry run the script without executing git, rsync, or gh commands.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
check_dependencies()
|
||||
|
||||
commit_ref = "HEAD" if args.commit == "LAST" else args.commit
|
||||
commit_hash, original_commit_message = get_commit_info(commit_ref)
|
||||
|
||||
if not commit_hash:
|
||||
return # Exit if we couldn't get commit info
|
||||
|
||||
# Display the details of the commit being processed
|
||||
if args.commit == "LAST":
|
||||
summary = (
|
||||
f"\nℹ️ No commit specified. Using the last commit:\n"
|
||||
f" - **Hash:** `{commit_hash}`\n"
|
||||
f" - **Message:** {original_commit_message}\n\n"
|
||||
)
|
||||
else:
|
||||
summary = (
|
||||
f"\nℹ️ Using specified commit:\n"
|
||||
f" - **Hash:** `{commit_hash}`\n"
|
||||
f" - **Message:** {original_commit_message}\n\n"
|
||||
)
|
||||
print(summary)
|
||||
write_github_step_summary(summary)
|
||||
|
||||
short_hash = commit_hash[:8]
|
||||
|
||||
patch_file = None
|
||||
temp_dir = None
|
||||
try:
|
||||
# 1. Create a filtered patch from the local repo
|
||||
patch_file, relevant_files = create_filtered_patch(commit_hash, args.dry_run)
|
||||
if not patch_file:
|
||||
return
|
||||
|
||||
# 2. Get the OSS repo
|
||||
oss_root, temp_dir = get_oss_repo(args.dry_run)
|
||||
|
||||
# 3. Get original commit author for the co-author line
|
||||
author_name, author_email = get_commit_author(commit_hash)
|
||||
|
||||
# 4. Prepare content for the commit and PR based on changed files
|
||||
file_list_str = "\n".join([f"- {f}" for f in relevant_files])
|
||||
filename_list_str = ", ".join([f.split("/")[-1] for f in relevant_files])
|
||||
if len(filename_list_str) > 40:
|
||||
filename_list_str = filename_list_str[:40] + "..."
|
||||
current_date = datetime.datetime.now().strftime("%Y%m%d")
|
||||
pr_title = f"[Auto Sync] Update {filename_list_str} ({current_date})"
|
||||
pr_body = (
|
||||
f"Sync changes from commit `{short_hash}`.\n\n"
|
||||
f"**Relevant Files Changed:**\n{file_list_str}"
|
||||
"\n\n---\n\n"
|
||||
"*This is an automated PR created by a script.*"
|
||||
)
|
||||
|
||||
# 5. Create branch, apply patch, and push
|
||||
branch_name = f"sync-{short_hash}-{current_date}"
|
||||
co_author_line = f"Co-authored-by: {author_name} <{author_email}>"
|
||||
commit_message = f"{pr_title}\n\n{co_author_line}"
|
||||
apply_patch_and_push(
|
||||
oss_root, patch_file, branch_name, commit_message, args.dry_run
|
||||
)
|
||||
|
||||
# 6. Create Pull Request
|
||||
create_pull_request(oss_root, branch_name, pr_title, pr_body, args.dry_run)
|
||||
|
||||
finally:
|
||||
# Cleanup temporary files
|
||||
if patch_file and os.path.exists(patch_file):
|
||||
os.remove(patch_file)
|
||||
print(f"\nRemoved temporary patch file: {patch_file}")
|
||||
if temp_dir and os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir)
|
||||
print(f"Removed temporary directory: {temp_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
27
scripts/code_sync/guideline.md
Normal file
27
scripts/code_sync/guideline.md
Normal file
@@ -0,0 +1,27 @@
|
||||
### Sync Code Between OSS and Private Fork
|
||||
|
||||
You can use the following principles and tools to sync the code between a private fork and the OSS repo [sgl-project/sglang](https://github.com/sgl-project/sglang/tree/main).
|
||||
It learns from [Copybara](https://github.com/google/copybara), a tool used at Google for maintaining open-source code synchronization.
|
||||
|
||||
## Principals
|
||||
|
||||
- The core folders (e.g., `python/sglang/srt`) are 100% mirrored between the private fork and OSS repo.
|
||||
- The OSS repo is the single source of truth. If one commit changes `python/sglang/srt` in the private repo, the change should be synced to the OSS repo as soon as possible with the action B below.
|
||||
- The common code (e.g., base classes, well-known techniques in the industry without private secrets) goes to `python/sglang/srt`. The private-specific code (e.g., with private-specific features, confidential info) goes to `python/sglang/private` .
|
||||
- Anytime you want to make private changes to a file or class under `python/sglang/srt`, duplicate the file and move it under `python/sglang/private`. You can achieve code reuse by importing and inheriting.
|
||||
|
||||
## How to sync the code bidirectionally
|
||||
### Action A: Copy code from OSS to private
|
||||
|
||||
- We can run this action: [Open A PR to Copy Code From OSS](https://github.com/sgl-project/sglang/tree/main/.github/workflows/open-pr-copy-from-oss.yml)
|
||||
- It opens a PR to copy all files under certain folders (e.g., `python/sglang/srt` , `test/srt` , `sgl-kernel` ) from the OSS main branch to the private fork.
|
||||
- Since the OSS repo is the single source of truth, this action copies files and overwrites any changes in the private fork. To prevent the private changes from being overwritten, you need to ensure all private changes are merged into the OSS repo before running this action.
|
||||
- This action will be run automatically every day and can also be triggered manually.
|
||||
|
||||
### Action B: Copy diff from private to OSS
|
||||
|
||||
- We can run this action: [Open A PR to Copy Code To OSS](https://github.com/sgl-project/sglang/tree/main/.github/workflows/open-pr-copy-to-oss.yml)
|
||||
- It opens a PR to apply the diff of one specific commit of the private fork to the OSS main branch. It will only pick the changes under certain folders (e.g., `python/sglang/srt` , `test/srt` , `sgl-kernel` ) and ignore changes under private folders (e.g., `python/sglang/private` )
|
||||
- For example, you can have a PR that changes both `python/sglang/srt` and `python/sglang/private/srt`. Once you merge the PR into the private repo, `python/sglang/srt` becomes desynced between the two repos. You need to run this action on your merge commit immediately to open a PR to send your diff to the OSS repo. Then, we need to merge the OSS PR as soon as possible. Once your OSS PR is merged, we can run action A again.
|
||||
- Action A copies files directly, but Action B applies diff. This is because OSS is the source of truth; action A can just copy files. Action B cannot copy, so it uses diff instead.
|
||||
- This action currently needs a manual trigger in order to prevent incidental code leaks. One can also consider making it automatic.
|
||||
18
scripts/code_sync/install_github_cli.sh
Executable file
18
scripts/code_sync/install_github_cli.sh
Executable file
@@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Check if gh is installed before attempting to install it
|
||||
if ! command -v gh &> /dev/null
|
||||
then
|
||||
echo "GitHub CLI not found. Installing now..."
|
||||
(type -p wget >/dev/null || ( apt update && apt install wget -y)) \
|
||||
&& mkdir -p -m 755 /etc/apt/keyrings \
|
||||
&& out=$(mktemp) && wget -nv -O$out https://cli.github.com/packages/githubcli-archive-keyring.gpg \
|
||||
&& cat $out | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
|
||||
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& mkdir -p -m 755 /etc/apt/sources.list.d \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
|
||||
&& apt update \
|
||||
&& apt install gh -y
|
||||
else
|
||||
echo "GitHub CLI is already installed. Skipping installation."
|
||||
fi
|
||||
38
scripts/deprecated/convert_yi_vl.py
Normal file
38
scripts/deprecated/convert_yi_vl.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Convert Yi-VL config into a format usable with SGLang
|
||||
|
||||
Usage: python3 scripts/convert_yi_vl.py --model-path <path-to-model>
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
|
||||
def add_image_token(model_path: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
tokenizer.add_tokens(["<image_placeholder>"], special_tokens=True)
|
||||
|
||||
print(tokenizer)
|
||||
tokenizer.save_pretrained(model_path)
|
||||
|
||||
|
||||
def edit_model_config(model_path):
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
|
||||
setattr(config, "architectures", ["YiVLForCausalLM"])
|
||||
setattr(config, "image_token_index", 64002)
|
||||
|
||||
print(config)
|
||||
config.save_pretrained(model_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-path", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
add_image_token(args.model_path)
|
||||
edit_model_config(args.model_path)
|
||||
13
scripts/deprecated/convert_yi_vl.sh
Normal file
13
scripts/deprecated/convert_yi_vl.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
# For 34B Model
|
||||
mkdir ~/model_weights
|
||||
cd ~/model_weights
|
||||
git clone https://huggingface.co/01-ai/Yi-VL-34B
|
||||
cp ~/model_weights/Yi-VL-34B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-34B-448/preprocessor_config.json ~/model_weights/Yi-VL-34B
|
||||
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-34B
|
||||
|
||||
# For 6B Model
|
||||
mkdir ~/model_weights
|
||||
cd ~/model_weights
|
||||
git clone https://huggingface.co/01-ai/Yi-VL-6B
|
||||
cp ~/model_weights/Yi-VL-6B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-6B-448/preprocessor_config.json ~/model_weights/Yi-VL-6B
|
||||
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-6B
|
||||
9
scripts/deprecated/test_curl.sh
Normal file
9
scripts/deprecated/test_curl.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
curl http://localhost:30000/generate \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"text": "Once upon a time,",
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 64,
|
||||
"temperature": 0
|
||||
}
|
||||
}'
|
||||
217
scripts/deprecated/test_flashinfer.py
Normal file
217
scripts/deprecated/test_flashinfer.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import pytest
|
||||
import torch
|
||||
from flashinfer import (
|
||||
BatchDecodeWithPagedKVCacheWrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
|
||||
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
||||
extend_attention_fwd,
|
||||
redundant_attention,
|
||||
)
|
||||
from sglang.srt.utils import should_use_tensor_core
|
||||
|
||||
flashinfer_prefill_wrapper = None
|
||||
flashinfer_decode_wrapper = None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [12, 37, 67])
|
||||
@pytest.mark.parametrize("kv_len", [54, 97])
|
||||
@pytest.mark.parametrize("qo_len", [37, 17])
|
||||
@pytest.mark.parametrize("num_kv_heads", [4])
|
||||
@pytest.mark.parametrize("num_qo_heads", [32, 4])
|
||||
@pytest.mark.parametrize("head_dim", [128])
|
||||
def test_batch_prefill_with_paged_kv_cache(
|
||||
batch_size,
|
||||
kv_len,
|
||||
qo_len,
|
||||
num_kv_heads,
|
||||
num_qo_heads,
|
||||
head_dim,
|
||||
):
|
||||
init_flashinfer(num_qo_heads, num_kv_heads)
|
||||
|
||||
q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
|
||||
qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
|
||||
total_tokens = kv_len * batch_size
|
||||
kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
|
||||
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
|
||||
kv_indices = torch.arange(0, total_tokens).to(0).int()
|
||||
kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
|
||||
|
||||
# init args for triton kernel
|
||||
k_extend = (
|
||||
kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 0]
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim)
|
||||
)
|
||||
v_extend = (
|
||||
kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 1]
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim)
|
||||
)
|
||||
o_triton = torch.empty_like(q)
|
||||
k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
|
||||
v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
|
||||
req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
|
||||
b_req_idx = torch.arange(0, batch_size).to(0).int()
|
||||
b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
|
||||
b_start_loc_extend = torch.arange(0, batch_size).to(0).int() * qo_len
|
||||
b_seq_len_extend = torch.full((batch_size,), qo_len, dtype=torch.int32).to(0)
|
||||
max_len_in_batch = kv_len
|
||||
max_len_extend = qo_len
|
||||
|
||||
extend_attention_fwd(
|
||||
q,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_triton,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
None, # b_start_loc = None
|
||||
b_seq_len,
|
||||
None, # b_seq_len_prefix = None
|
||||
b_start_loc_extend,
|
||||
b_seq_len_extend,
|
||||
max_len_in_batch,
|
||||
max_len_extend,
|
||||
)
|
||||
|
||||
o_redundant = torch.empty_like(q)
|
||||
b_start_loc = torch.zeros((batch_size,), dtype=torch.int32).to(0)
|
||||
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0)
|
||||
b_seq_len_prefix = b_seq_len - b_seq_len_extend
|
||||
|
||||
redundant_attention(
|
||||
q,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_redundant,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_seq_len_prefix,
|
||||
max_len_in_batch,
|
||||
)
|
||||
print("Mean: ", torch.mean(torch.abs(o_redundant - o_triton)))
|
||||
print("Max: ", torch.max(torch.abs(o_redundant - o_triton)))
|
||||
assert torch.allclose(o_redundant, o_triton, rtol=1e-2, atol=1e-3)
|
||||
|
||||
flashinfer_prefill_wrapper.end_forward()
|
||||
|
||||
flashinfer_prefill_wrapper.begin_forward(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_len,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
1,
|
||||
)
|
||||
o = flashinfer_prefill_wrapper.forward(
|
||||
q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
|
||||
)
|
||||
|
||||
print("Mean: ", torch.mean(torch.abs(o - o_triton)))
|
||||
print("Max: ", torch.max(torch.abs(o - o_triton)))
|
||||
assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [12, 17, 37])
|
||||
@pytest.mark.parametrize("kv_len", [54, 127, 537])
|
||||
@pytest.mark.parametrize("num_kv_heads", [32])
|
||||
@pytest.mark.parametrize("num_qo_heads", [32])
|
||||
@pytest.mark.parametrize("head_dim", [128])
|
||||
def test_batch_decode_with_paged_kv_cache(
|
||||
batch_size,
|
||||
kv_len,
|
||||
num_kv_heads,
|
||||
num_qo_heads,
|
||||
head_dim,
|
||||
):
|
||||
# note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache
|
||||
# to test different shape of decode, change the parameters in the __main__, and run decode only once
|
||||
init_flashinfer(num_qo_heads, num_kv_heads)
|
||||
|
||||
q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
|
||||
total_tokens = kv_len * batch_size
|
||||
kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
|
||||
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
|
||||
kv_indices = torch.arange(0, total_tokens).to(0).int()
|
||||
kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
|
||||
|
||||
# init args for triton kernel
|
||||
k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
|
||||
v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
|
||||
o_triton = torch.empty_like(q)
|
||||
req_to_token = (
|
||||
torch.arange(0, kv_len * batch_size).to(0).int().view(batch_size, kv_len)
|
||||
)
|
||||
b_req_idx = torch.arange(0, batch_size).to(0).int()
|
||||
b_start_loc = torch.arange(0, batch_size).to(0).int() * kv_len
|
||||
b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
|
||||
max_len_in_batch = kv_len
|
||||
other_kv_index = 0
|
||||
decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o_triton,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_len_in_batch,
|
||||
other_kv_index,
|
||||
total_tokens,
|
||||
)
|
||||
|
||||
flashinfer_decode_wrapper.end_forward()
|
||||
flashinfer_decode_wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_len,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
1,
|
||||
pos_encoding_mode="NONE",
|
||||
data_type="float16",
|
||||
)
|
||||
o = flashinfer_decode_wrapper.forward(
|
||||
q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
|
||||
)
|
||||
|
||||
print("Mean: ", torch.mean(torch.abs(o - o_triton)))
|
||||
print("Max: ", torch.max(torch.abs(o - o_triton)))
|
||||
assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
|
||||
|
||||
|
||||
def init_flashinfer(num_attention_heads, num_kv_heads):
|
||||
use_tensor_cores = should_use_tensor_core(
|
||||
torch.half, num_attention_heads, num_kv_heads
|
||||
)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
||||
|
||||
global flashinfer_prefill_wrapper, flashinfer_decode_wrapper
|
||||
|
||||
flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD"
|
||||
)
|
||||
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128)
|
||||
test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128)
|
||||
test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)
|
||||
56
scripts/deprecated/test_httpserver_concurrent.py
Normal file
56
scripts/deprecated/test_httpserver_concurrent.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.
|
||||
|
||||
The capital of the United Kindom is London.\nThe capital of the United Kingdom is London.\nThe capital of
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
|
||||
await asyncio.sleep(delay)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=data) as resp:
|
||||
output = await resp.json()
|
||||
return output
|
||||
|
||||
|
||||
async def main(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
task1 = send_request(
|
||||
url + "/generate",
|
||||
{
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {"temperature": 0, "max_new_tokens": 128},
|
||||
},
|
||||
delay=1,
|
||||
)
|
||||
|
||||
task2 = send_request(
|
||||
url + "/generate",
|
||||
{
|
||||
"text": "The capital of the United Kindom is",
|
||||
"sampling_params": {"temperature": 0, "max_new_tokens": 128},
|
||||
},
|
||||
)
|
||||
|
||||
rets = await asyncio.gather(task1, task2)
|
||||
print(rets)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args))
|
||||
55
scripts/deprecated/test_httpserver_decode.py
Normal file
55
scripts/deprecated/test_httpserver_decode.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
python3 test_httpserver_decode.py
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": 32,
|
||||
"n": n,
|
||||
},
|
||||
"stream": False,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"return_text_in_logprobs": return_text,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
print(json.dumps(response.json()))
|
||||
print("=" * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
test_decode(url)
|
||||
test_decode(url, n=3)
|
||||
|
||||
for top_logprobs_num in [0, 3]:
|
||||
for return_text in [True, False]:
|
||||
test_decode(
|
||||
url,
|
||||
return_logprob=True,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
return_text=return_text,
|
||||
)
|
||||
68
scripts/deprecated/test_httpserver_decode_stream.py
Normal file
68
scripts/deprecated/test_httpserver_decode_stream.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
python3 test_httpserver_decode_stream.py
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def test_decode_stream(url, return_logprob, top_logprobs_num):
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 128,
|
||||
},
|
||||
"stream": True,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"return_text_in_logprobs": True,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
prev = 0
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
data = json.loads(chunk[5:].strip("\n"))
|
||||
|
||||
if return_logprob:
|
||||
assert data["meta_info"]["input_token_logprobs"] is not None
|
||||
assert data["meta_info"]["output_token_logprobs"] is not None
|
||||
for logprob, token_id, token_text in data["meta_info"][
|
||||
"output_token_logprobs"
|
||||
][prev:]:
|
||||
print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
|
||||
prev = len(data["meta_info"]["output_token_logprobs"])
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
print(output[prev:], end="", flush=True)
|
||||
prev = len(output)
|
||||
|
||||
print("=" * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
test_decode_stream(url, False, 0)
|
||||
test_decode_stream(url, True, 0)
|
||||
test_decode_stream(url, True, 3)
|
||||
88
scripts/deprecated/test_httpserver_llava.py
Normal file
88
scripts/deprecated/test_httpserver_llava.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
|
||||
python3 test_httpserver_llava.py
|
||||
|
||||
Output:
|
||||
The image features a man standing on the back of a yellow taxi cab, holding
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
|
||||
await asyncio.sleep(delay)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=data) as resp:
|
||||
output = await resp.json()
|
||||
return output
|
||||
|
||||
|
||||
async def test_concurrent(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
response = []
|
||||
for i in range(8):
|
||||
response.append(
|
||||
send_request(
|
||||
url + "/generate",
|
||||
{
|
||||
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
|
||||
"image_data": "example_image.png",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 64,
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
rets = await asyncio.gather(*response)
|
||||
for ret in rets:
|
||||
print(ret["text"])
|
||||
|
||||
|
||||
def test_streaming(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
|
||||
"image_data": "example_image.png",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 128,
|
||||
},
|
||||
"stream": True,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
prev = 0
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
data = json.loads(chunk[5:].strip("\n"))
|
||||
output = data["text"].strip()
|
||||
print(output[prev:], end="", flush=True)
|
||||
prev = len(output)
|
||||
print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(test_concurrent(args))
|
||||
|
||||
test_streaming(args)
|
||||
42
scripts/deprecated/test_httpserver_reuse.py
Normal file
42
scripts/deprecated/test_httpserver_reuse.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 32,
|
||||
},
|
||||
},
|
||||
)
|
||||
print(response.json())
|
||||
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is Paris.\nThe capital of the United States is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 32,
|
||||
},
|
||||
},
|
||||
)
|
||||
print(response.json())
|
||||
138
scripts/deprecated/test_jump_forward.py
Normal file
138
scripts/deprecated/test_jump_forward.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import argparse
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, constr
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.constrained.outlines_backend import build_regex_from_object
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
|
||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
ip_jump_forward = (
|
||||
r"The google's DNS sever address is "
|
||||
+ IP_REGEX
|
||||
+ r" and "
|
||||
+ IP_REGEX
|
||||
+ r". "
|
||||
+ r"The google's website domain name is "
|
||||
+ r"www\.(\w)+\.(\w)+"
|
||||
+ r"."
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
@sgl.function
|
||||
def regex_gen(s):
|
||||
s += "Q: What is the IP address of the Google DNS servers?\n"
|
||||
s += "A: " + sgl.gen(
|
||||
"answer",
|
||||
max_tokens=128,
|
||||
temperature=0,
|
||||
regex=ip_jump_forward,
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
json_jump_forward = (
|
||||
r"""The information about Hogwarts is in the following JSON format\.\n"""
|
||||
+ r"""\n\{\n"""
|
||||
+ r""" "name": "[\w\d\s]*",\n"""
|
||||
+ r""" "country": "[\w\d\s]*",\n"""
|
||||
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
|
||||
+ r""" "population": [-+]?[0-9]+,\n"""
|
||||
+ r""" "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
|
||||
+ r"""\}\n"""
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
@sgl.function
|
||||
def json_gen(s):
|
||||
s += sgl.gen(
|
||||
"json",
|
||||
max_tokens=128,
|
||||
temperature=0,
|
||||
regex=json_jump_forward,
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
|
||||
class Weapon(str, Enum):
|
||||
sword = "sword"
|
||||
axe = "axe"
|
||||
mace = "mace"
|
||||
spear = "spear"
|
||||
bow = "bow"
|
||||
crossbow = "crossbow"
|
||||
|
||||
|
||||
class Armor(str, Enum):
|
||||
leather = "leather"
|
||||
chainmail = "chainmail"
|
||||
plate = "plate"
|
||||
|
||||
|
||||
class Character(BaseModel):
|
||||
name: constr(max_length=10)
|
||||
age: int
|
||||
armor: Armor
|
||||
weapon: Weapon
|
||||
strength: int
|
||||
|
||||
|
||||
@sgl.function
|
||||
def character_gen(s):
|
||||
s += "Give me a character description who is a wizard.\n"
|
||||
s += sgl.gen(
|
||||
"character",
|
||||
max_tokens=128,
|
||||
temperature=0,
|
||||
regex=build_regex_from_object(Character),
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
# Select backend
|
||||
backend = select_sglang_backend(args)
|
||||
sgl.set_default_backend(backend)
|
||||
|
||||
state = regex_gen.run(temperature=0)
|
||||
|
||||
print("=" * 20, "IP TEST", "=" * 20)
|
||||
print(state.text())
|
||||
|
||||
state = json_gen.run(temperature=0)
|
||||
|
||||
print("=" * 20, "JSON TEST", "=" * 20)
|
||||
print(state.text())
|
||||
|
||||
state = character_gen.run(temperature=0)
|
||||
|
||||
print("=" * 20, "CHARACTER TEST", "=" * 20)
|
||||
print(state.text())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
|
||||
# ==================== IP TEST ====================
|
||||
# Q: What is the IP address of the Google DNS servers?
|
||||
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
|
||||
# ==================== JSON TEST ====================
|
||||
# The information about Hogwarts is in the following JSON format.
|
||||
|
||||
# {
|
||||
# "name": "Hogwarts School of Witchcraft and Wizardry",
|
||||
# "country": "Scotland",
|
||||
# "latitude": 55.566667,
|
||||
# "population": 1000,
|
||||
# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
|
||||
# }
|
||||
|
||||
# ==================== CHARACTER TEST ====================
|
||||
# Give me a character description who is a wizard.
|
||||
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }
|
||||
132
scripts/deprecated/test_robust.py
Normal file
132
scripts/deprecated/test_robust.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import argparse
|
||||
import random
|
||||
import string
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
|
||||
TOKENIZER = None
|
||||
RANDOM_PREFILL_LEN = None
|
||||
RANDOM_DECODE_LEN = None
|
||||
|
||||
|
||||
def gen_prompt(token_num):
|
||||
if RANDOM_PREFILL_LEN:
|
||||
token_num = random.randint(1, token_num)
|
||||
|
||||
cha_set = string.ascii_letters + string.digits
|
||||
ret = "".join(random.choices(cha_set, k=token_num))
|
||||
while len(TOKENIZER(ret).input_ids) < token_num:
|
||||
ret += random.choice(cha_set)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def robust_test_dfs(s, d, args, leaf_states):
|
||||
if d == 0:
|
||||
s += "END"
|
||||
leaf_states.append(s)
|
||||
return
|
||||
|
||||
s += gen_prompt(args.len_prefill)
|
||||
forks = s.fork(args.num_fork)
|
||||
for fork_s in forks:
|
||||
fork_s += gen_prompt(args.len_prefill)
|
||||
new_tokens = (
|
||||
args.len_decode
|
||||
if not RANDOM_DECODE_LEN
|
||||
else random.randint(1, args.len_decode)
|
||||
)
|
||||
fork_s += sgl.gen(
|
||||
max_tokens=new_tokens,
|
||||
ignore_eos=True,
|
||||
)
|
||||
|
||||
for fork_s in forks:
|
||||
robust_test_dfs(fork_s, d - 1, args, leaf_states)
|
||||
|
||||
|
||||
def robust_test_bfs(s, args, leaf_states):
|
||||
old_forks = [s]
|
||||
new_forks = []
|
||||
for _ in range(args.depth):
|
||||
for old_fork in old_forks:
|
||||
old_fork += gen_prompt(args.len_prefill)
|
||||
forks = old_fork.fork(args.num_fork)
|
||||
for fork_s in forks:
|
||||
fork_s += gen_prompt(args.len_prefill)
|
||||
new_tokens = (
|
||||
args.len_decode
|
||||
if not RANDOM_DECODE_LEN
|
||||
else random.randint(1, args.len_decode)
|
||||
)
|
||||
fork_s += sgl.gen(
|
||||
max_tokens=new_tokens,
|
||||
ignore_eos=True,
|
||||
)
|
||||
new_forks.extend(forks)
|
||||
|
||||
old_forks = new_forks
|
||||
new_forks = []
|
||||
|
||||
for old_fork in old_forks:
|
||||
old_fork += "END"
|
||||
leaf_states.append(old_fork)
|
||||
|
||||
|
||||
@sgl.function
|
||||
def robust_test(s, args):
|
||||
leaf_states = []
|
||||
if args.mode == "bfs":
|
||||
robust_test_bfs(s, args, leaf_states)
|
||||
else:
|
||||
robust_test_dfs(s, args.depth, args, leaf_states)
|
||||
return leaf_states
|
||||
|
||||
|
||||
def main(args):
|
||||
backend = select_sglang_backend(args)
|
||||
|
||||
arguments = [{"args": args} for _ in range(args.num_req)]
|
||||
|
||||
states = robust_test.run_batch(
|
||||
arguments, temperature=0, backend=backend, num_threads=args.parallel
|
||||
)
|
||||
|
||||
with open(f"tmp_robust_{args.mode}.txt", "w") as f:
|
||||
for state in states:
|
||||
leaf_states = state.ret_value
|
||||
for leaf_state in leaf_states:
|
||||
assert leaf_state.text()[-3:] == "END"
|
||||
f.write(leaf_state.text()[:-3] + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-req", type=int, default=2)
|
||||
parser.add_argument("--depth", type=int, default=3)
|
||||
parser.add_argument("--num-fork", type=int, default=2)
|
||||
parser.add_argument("--len-prefill", type=int, default=128)
|
||||
parser.add_argument("--len-decode", type=int, default=128)
|
||||
parser.add_argument("--random-prefill-len", action="store_true")
|
||||
parser.add_argument("--random-decode-len", action="store_true")
|
||||
parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
|
||||
parser.add_argument("--tokenizer", type=str, default = "meta-llama/Llama-2-7b-chat-hf")
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
# fmt: on
|
||||
|
||||
RANDOM_PREFILL_LEN = args.random_prefill_len
|
||||
RANDOM_DECODE_LEN = args.random_decode_len
|
||||
TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
|
||||
random.seed(args.seed)
|
||||
|
||||
main(args)
|
||||
115
scripts/export_deepseek_nextn.py
Normal file
115
scripts/export_deepseek_nextn.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Export NextN layer for DeepSeek-V3/R1 model. The exported model can be used for speculative decoding.
|
||||
|
||||
Usage:
|
||||
python3 export_deepseek_nextn.py --input-dir /path/to/DeepSeek-V3 --output-dir /path/to/DeepSeek-V3-NextN
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
def get_nextn_layer_id(config):
|
||||
if not hasattr(config, "num_hidden_layers"):
|
||||
raise ValueError("'num_hidden_layers' not found in model config.")
|
||||
return config.num_hidden_layers
|
||||
|
||||
|
||||
def update_and_save_config(config, output_dir):
|
||||
new_config = config.to_dict()
|
||||
new_config.update(
|
||||
{
|
||||
"num_hidden_layers": 1,
|
||||
"architectures": ["DeepseekV3ForCausalLMNextN"],
|
||||
}
|
||||
)
|
||||
with open(os.path.join(output_dir, "config.json"), "w") as f:
|
||||
json.dump(new_config, f, indent=2, ensure_ascii=False, sort_keys=True)
|
||||
|
||||
|
||||
def copy_non_safetensors_files(input_dir, output_dir):
|
||||
for filename in os.listdir(input_dir):
|
||||
src_file_path = os.path.join(input_dir, filename)
|
||||
if os.path.isfile(src_file_path) and not filename.endswith(".safetensors"):
|
||||
dst_file_path = os.path.join(output_dir, filename)
|
||||
shutil.copy2(src_file_path, dst_file_path)
|
||||
print(f"All non-safetensors files have been copied to {output_dir}")
|
||||
|
||||
|
||||
def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
|
||||
prefix = f"model.layers.{nextn_layer_id}"
|
||||
output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors")
|
||||
params = {}
|
||||
for filename in os.listdir(input_dir):
|
||||
if not filename.endswith(".safetensors"):
|
||||
continue
|
||||
|
||||
file_path = os.path.join(input_dir, filename)
|
||||
print(f"Processing: {filename}")
|
||||
|
||||
try:
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
matching_keys = [k for k in f.keys() if k.startswith(prefix)]
|
||||
|
||||
if not matching_keys:
|
||||
print(f" No parameters starting with '{prefix}' found")
|
||||
continue
|
||||
|
||||
for key in matching_keys:
|
||||
if "embed_tokens" in key or "shared_head.head" in key:
|
||||
continue
|
||||
new_key = key.replace(prefix, "model.layers.0")
|
||||
params[new_key] = f.get_tensor(key)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Error processing {filename}: {str(e)}")
|
||||
|
||||
if params:
|
||||
print(f"Saving {len(params)} parameters to {output_path}")
|
||||
save_file(params, output_path)
|
||||
else:
|
||||
print("No matching parameters found.")
|
||||
|
||||
# Update safetensors index
|
||||
index_path = os.path.join(output_dir, "model.safetensors.index.json")
|
||||
print(f"Updating safetensors index to {index_path}")
|
||||
index_data = {"weight_map": {}}
|
||||
for key in params:
|
||||
index_data["weight_map"][key] = "nextn_layer_parameters.safetensors"
|
||||
with open(index_path, "w") as f:
|
||||
json.dump(index_data, f, indent=4)
|
||||
|
||||
print("All done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Export NextN layer parameters for DeepSeek-V3/R1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Input HF model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output nextn model directory.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True)
|
||||
assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported."
|
||||
nextn_layer_id = get_nextn_layer_id(config)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
copy_non_safetensors_files(args.input_dir, args.output_dir)
|
||||
update_and_save_config(config, args.output_dir)
|
||||
export_nextn_layer_parameters(args.input_dir, args.output_dir, nextn_layer_id)
|
||||
32
scripts/killall_sglang.sh
Executable file
32
scripts/killall_sglang.sh
Executable file
@@ -0,0 +1,32 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ "$1" = "rocm" ]; then
|
||||
echo "Running in ROCm mode"
|
||||
|
||||
# Clean SGLang processes
|
||||
pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9
|
||||
|
||||
else
|
||||
# Show current GPU status
|
||||
nvidia-smi
|
||||
|
||||
# Clean SGLang processes
|
||||
pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9
|
||||
|
||||
# Clean all GPU processes if any argument is provided
|
||||
if [ $# -gt 0 ]; then
|
||||
# Check if sudo is available
|
||||
if command -v sudo >/dev/null 2>&1; then
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y lsof
|
||||
else
|
||||
apt-get update
|
||||
apt-get install -y lsof
|
||||
fi
|
||||
kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null
|
||||
lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null
|
||||
fi
|
||||
|
||||
# Show GPU status after clean up
|
||||
nvidia-smi
|
||||
fi
|
||||
308
scripts/playground/bench_speculative.py
Normal file
308
scripts/playground/bench_speculative.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Usage:
|
||||
# single GPU
|
||||
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B
|
||||
|
||||
# multiple GPU
|
||||
python3 bench_speculative.py --model-path deepseek-ai/DeepSeek-V3 --speculative-draft-model-path lmsys/DeepSeek-V3-NextN --tp-size 8 --trust-remote-code --batch-size 1 4 8 16 32 --steps 0 1 2 --topk 0 1 2 4 --num_draft_tokens 0 2 4 8
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from sglang.bench_serving import (
|
||||
DatasetRow,
|
||||
benchmark,
|
||||
sample_mmmu_requests,
|
||||
set_global_args,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
kill_process_tree,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def node0_print(msg):
|
||||
if server_args.node_rank == 0:
|
||||
print(msg)
|
||||
|
||||
|
||||
prompts = [
|
||||
"Human: Give me a fully functional FastAPI server. Show the full, long python code without stop.\n\nAssistant:",
|
||||
"Human: Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.\n\nAssistant:",
|
||||
"Human: Write a travel blog post to Hawaii.\n\nAssistant:",
|
||||
"Human: I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. My first sentence is 'istanbulu cok seviyom burada olmak cok guzel'. Answer in more than 5000 words.\n\nAssistant:",
|
||||
"Human: I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it’s children then you can talk about animals; If it’s adults then history-based tales might engage them better etc. Answer in more than 5000 words. My first request is 'I need an interesting story on perseverance.'\n\nAssistant:",
|
||||
"Human: Solve x^2 = -1. Think step-by-step. Give me a long detailed explanation. \n\nAssistant:",
|
||||
"Human: Tell me about the president of the USA in wikipedia style.\n\nAssistant:",
|
||||
"Human: Hello? Who are you? Write code, math, and poem to explanin yourself.\n\nAssistant:",
|
||||
]
|
||||
|
||||
|
||||
class FakeTokenizer:
|
||||
def encode(self, text: str, add_special_tokens: bool = False):
|
||||
return []
|
||||
|
||||
|
||||
def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
|
||||
# format: (prompt, input_len, output len). We set input_len as a dummy value 0.
|
||||
if is_multimodal:
|
||||
input_requests = sample_mmmu_requests(
|
||||
num_prompts,
|
||||
tokenizer,
|
||||
512,
|
||||
apply_chat_template=False,
|
||||
)
|
||||
backend = "sglang-oai-chat"
|
||||
api_url = f"{base_url}/v1/chat/completions"
|
||||
else:
|
||||
padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
|
||||
:num_prompts
|
||||
]
|
||||
input_requests: List[DatasetRow] = [
|
||||
DatasetRow(p, 0, 512) for p in padded_prompts
|
||||
]
|
||||
backend = "sglang"
|
||||
api_url = f"{base_url}/generate"
|
||||
|
||||
# We need to set some dummy values in order to call `benchmark` below.
|
||||
args = SimpleNamespace(
|
||||
disable_ignore_eos=False,
|
||||
disable_stream=False,
|
||||
return_logprob=False,
|
||||
backend=backend,
|
||||
dataset_name="custom",
|
||||
num_prompts=None,
|
||||
sharegpt_output_len=None,
|
||||
random_input_len=None,
|
||||
random_output_len=None,
|
||||
random_range_ratio=None,
|
||||
output_file=None,
|
||||
warmup_requests=1,
|
||||
output_details=False,
|
||||
)
|
||||
set_global_args(args)
|
||||
|
||||
# Run benchmark
|
||||
results = asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id="default",
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=float("inf"),
|
||||
max_concurrency=batch_size,
|
||||
disable_tqdm=False,
|
||||
lora_names=None,
|
||||
extra_request_body={},
|
||||
profile=None,
|
||||
)
|
||||
)
|
||||
|
||||
assert results["completed"] == len(input_requests)
|
||||
acc_length = results["accept_length"] or 1.0
|
||||
avg_output_token = results["total_output_tokens"] / results["completed"]
|
||||
|
||||
server_info = requests.get(base_url + "/get_server_info").json()
|
||||
# We use 20% percentile instead of median on purpose
|
||||
step_time = np.percentile(
|
||||
server_info["internal_states"][0]["step_time_dict"][str(batch_size)], 20
|
||||
)
|
||||
speed = 1 / step_time * acc_length
|
||||
|
||||
return (
|
||||
round(acc_length, 3),
|
||||
round(step_time, 5),
|
||||
round(speed, 3),
|
||||
avg_output_token,
|
||||
)
|
||||
|
||||
|
||||
def main(args, server_args):
|
||||
base_url = "http://127.0.0.1:20000"
|
||||
|
||||
configs = []
|
||||
for batch_size in args.batch_size:
|
||||
for steps in args.steps:
|
||||
for topk in args.topk:
|
||||
for num_draft_tokens in args.num_draft_tokens:
|
||||
if steps * topk + 1 < num_draft_tokens:
|
||||
continue
|
||||
|
||||
if (steps == 0 or topk == 0 or num_draft_tokens == 0) and (
|
||||
steps + topk + num_draft_tokens != 0
|
||||
):
|
||||
# steps == 0 and topk == 0 and num_draft_tokens == 0 is a special case for non-speculative decoding.
|
||||
continue
|
||||
|
||||
configs.append((batch_size, steps, topk, num_draft_tokens))
|
||||
|
||||
for i in range(args.start, args.end or len(configs)):
|
||||
batch_size, steps, topk, num_draft_tokens = configs[i]
|
||||
|
||||
node0_print(
|
||||
f"Start {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}"
|
||||
)
|
||||
|
||||
# Create an LLM.
|
||||
if steps == 0:
|
||||
other_args = []
|
||||
else:
|
||||
other_args = [
|
||||
"--speculative-num-steps",
|
||||
steps,
|
||||
"--speculative-eagle-topk",
|
||||
topk,
|
||||
"--speculative-num-draft-tokens",
|
||||
num_draft_tokens,
|
||||
]
|
||||
if server_args.speculative_draft_model_path is not None:
|
||||
other_args.extend(
|
||||
[
|
||||
"--speculative-draft-model-path",
|
||||
server_args.speculative_draft_model_path,
|
||||
"--speculative-algorithm",
|
||||
server_args.speculative_algorithm,
|
||||
]
|
||||
)
|
||||
|
||||
other_args.extend(
|
||||
[
|
||||
"--cuda-graph-max-bs",
|
||||
batch_size,
|
||||
"--mem-fraction-static",
|
||||
server_args.mem_fraction_static,
|
||||
"--tp-size",
|
||||
server_args.tp_size,
|
||||
"--max-running-requests",
|
||||
batch_size,
|
||||
]
|
||||
)
|
||||
|
||||
if server_args.trust_remote_code:
|
||||
other_args.extend(
|
||||
[
|
||||
"--trust-remote-code",
|
||||
]
|
||||
)
|
||||
|
||||
if server_args.attention_backend:
|
||||
other_args.extend(
|
||||
[
|
||||
"--attention-backend",
|
||||
server_args.attention_backend,
|
||||
]
|
||||
)
|
||||
|
||||
if server_args.quantization:
|
||||
other_args.extend(
|
||||
[
|
||||
"--quantization",
|
||||
server_args.quantization,
|
||||
]
|
||||
)
|
||||
|
||||
process = popen_launch_server(
|
||||
args.model_path,
|
||||
base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
env={
|
||||
"SGLANG_RECORD_STEP_TIME": "1",
|
||||
**os.environ,
|
||||
},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model_path, trust_remote_code=server_args.trust_remote_code
|
||||
)
|
||||
|
||||
try:
|
||||
# Warmup
|
||||
send_one_batch(
|
||||
base_url, batch_size, batch_size, tokenizer, args.is_multimodal
|
||||
)
|
||||
|
||||
# Benchmark
|
||||
acc_length, step_time, speed, completion_tokens = send_one_batch(
|
||||
base_url,
|
||||
max(args.num_prompts, batch_size),
|
||||
batch_size,
|
||||
tokenizer,
|
||||
args.is_multimodal,
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
node0_print(
|
||||
f"Finish {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}, {speed=:.2f} token/s, step_time={step_time * 1000:.2f} ms"
|
||||
)
|
||||
|
||||
record = {
|
||||
"batch_size": batch_size,
|
||||
"steps": steps,
|
||||
"topk": topk,
|
||||
"num_draft_tokens": num_draft_tokens,
|
||||
"acc_length": acc_length,
|
||||
"step_time": step_time,
|
||||
"speed": speed,
|
||||
"completion_tokens": completion_tokens,
|
||||
}
|
||||
|
||||
with open(args.output, "a") as fout:
|
||||
fout.write(json.dumps(record) + "\n")
|
||||
|
||||
# Wait for the server to shutdown
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(1, 2, 4, 8, 16),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(0, 1, 3, 5, 7), # use (0, 1, 2, 3, 4) for large batch size
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(0, 1, 2, 4, 8),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_draft_tokens",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(0, 2, 4, 8, 16, 32), # use (0, 2, 4, 8) for large batch size
|
||||
)
|
||||
parser.add_argument("--num-prompts", type=int, default=16)
|
||||
parser.add_argument("--start", type=int, default=0)
|
||||
parser.add_argument("--end", type=int)
|
||||
parser.add_argument("--output", type=str, default="output.jsonl")
|
||||
parser.add_argument("--is-multimodal", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
server_args: ServerArgs = ServerArgs.from_cli_args(args)
|
||||
|
||||
main(args, server_args)
|
||||
22
scripts/playground/disaggregation/cli-logprob.py
Normal file
22
scripts/playground/disaggregation/cli-logprob.py
Normal file
@@ -0,0 +1,22 @@
|
||||
prompt = "The capital of france is "
|
||||
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
response = requests.post(
|
||||
"http://0.0.0.0:8000/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"sampling_params": {"temperature": 0},
|
||||
"return_logprob": True,
|
||||
"return_input_logprob": True,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
|
||||
j = response.json()
|
||||
input_logprobs = j["meta_info"]["input_token_logprobs"]
|
||||
output_logprobs = j["meta_info"]["output_token_logprobs"]
|
||||
|
||||
print(len(input_logprobs), len(output_logprobs))
|
||||
34
scripts/playground/disaggregation/cli-so.py
Normal file
34
scripts/playground/disaggregation/cli-so.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
port = 8000
|
||||
|
||||
json_schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "pattern": "^[\\w]+$"},
|
||||
"population": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
}
|
||||
)
|
||||
|
||||
# JSON
|
||||
response = requests.post(
|
||||
f"http://localhost:{port}/generate",
|
||||
json={
|
||||
"text": "Here is the information of the capital of France in the JSON format.\n",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 64,
|
||||
"json_schema": json_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
print(response.json())
|
||||
|
||||
|
||||
# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100
|
||||
29
scripts/playground/disaggregation/cli.py
Normal file
29
scripts/playground/disaggregation/cli.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
prompt = """
|
||||
According to CNBC's Faber, the investors present on the call interpreted this statement as an indication of an upcoming funding round. While speculative, Faber believes the funding round could be as large as $25 billion, and bestow a valuation of between $150 billion and $200 billion on xAI.
|
||||
|
||||
For the benefit of those who might not be aware, xAI recently acquired the social media platform X in an all-stock deal that valued the former at $80 billion and the latter at $33 billion, inclusive of $12 billion in liabilities. This meant that the deal bestowed a gross valuation of $45 billion on X before factoring in its debt load of $12 billion.
|
||||
|
||||
Bear in mind that Elon Musk took X (then called Twitter) private back in 2022 in a $44 billion deal. Since then, Musk has managed to stem X's cash bleed, with the company reportedly generating $1.2 billion in adjusted EBITDA in 2024.
|
||||
|
||||
According to the investors present on the call, xAI is currently generating around $1 billion in annual revenue. This contrasts sharply with the erstwhile muted expectations of many investors, who did not expect the startup to generate any material revenue this year.
|
||||
|
||||
Elsewhere, Faber also alludes to the fact that xAI is already working on its next big training supercluster, officially dubbed the Colossus 2, which is expected to eventually house as many as 1 million NVIDIA GPUs at a cost of between $35 billion and $40 billion.
|
||||
|
||||
|
||||
Even though xAI's Grok LLM is already largely comparable with OpenAI's cutting-edge models, the Colossus 2 would significantly up the ante, and could feasibly challenge OpenAI's apex position in the AI sphere.
|
||||
|
||||
Give your honest take on the above text:
|
||||
"""
|
||||
|
||||
response = requests.post(
|
||||
"http://0.0.0.0:8000/generate",
|
||||
json={"text": prompt, "sampling_params": {"temperature": 0}},
|
||||
)
|
||||
|
||||
|
||||
response_json = response.json()
|
||||
print(response_json["text"])
|
||||
241
scripts/playground/frontend_reasoning.ipynb
Normal file
241
scripts/playground/frontend_reasoning.ipynb
Normal file
@@ -0,0 +1,241 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Launch A Server\n",
|
||||
"\n",
|
||||
"Launch the server with a reasoning model (Qwen 3.5-4B) and reasoning parser."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sglang import separate_reasoning, assistant_begin, assistant_end\n",
|
||||
"from sglang import assistant, function, gen, system, user\n",
|
||||
"from sglang import image\n",
|
||||
"from sglang import RuntimeEndpoint, set_default_backend\n",
|
||||
"from sglang.srt.utils import load_image\n",
|
||||
"from sglang.test.test_utils import is_in_ci\n",
|
||||
"from sglang.utils import print_highlight, terminate_process, wait_for_server\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if is_in_ci():\n",
|
||||
" from patch import launch_server_cmd\n",
|
||||
"else:\n",
|
||||
" from sglang.utils import launch_server_cmd\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen3-4B --reasoning-parser qwen3 --host 0.0.0.0\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"wait_for_server(f\"http://localhost:{port}\")\n",
|
||||
"print(f\"Server started on http://localhost:{port}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Set the default backend. Note: you can set chat_template_name in RontimeEndpoint. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"set_default_backend(\n",
|
||||
" RuntimeEndpoint(f\"http://localhost:{port}\", chat_template_name=\"qwen\")\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's start with a basic question-answering task. And see how the reasoning content is generated."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def basic_qa(s, question):\n",
|
||||
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
|
||||
" s += user(question)\n",
|
||||
" s += assistant_begin()\n",
|
||||
" s += gen(\"answer\", max_tokens=512)\n",
|
||||
" s += assistant_end()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"state = basic_qa(\"List 3 countries and their capitals.\")\n",
|
||||
"print_highlight(state[\"answer\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"With `separate_reasoning`, you can move the reasoning content to `{param_name}_reasoning_content` in the state."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def basic_qa_separate_reasoning(s, question):\n",
|
||||
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
|
||||
" s += user(question)\n",
|
||||
" s += assistant_begin()\n",
|
||||
" s += separate_reasoning(gen(\"answer\", max_tokens=512), model_type=\"qwen3\")\n",
|
||||
" s += assistant_end()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"reasoning_state = basic_qa_separate_reasoning(\"List 3 countries and their capitals.\")\n",
|
||||
"print_highlight(reasoning_state.stream_executor.variable_event.keys())\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\nSeparated Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print_highlight(f\"\\n\\nContent:\\n{reasoning_state['answer']}\")\n",
|
||||
"print_highlight(f\"\\n\\nMessages:\\n{reasoning_state.messages()[-1]}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`separate_reasoning` can also be used in multi-turn conversations."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def multi_turn_qa(s):\n",
|
||||
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
|
||||
" s += user(\"Please give me a list of 3 countries and their capitals.\")\n",
|
||||
" s += assistant(\n",
|
||||
" separate_reasoning(gen(\"first_answer\", max_tokens=512), model_type=\"qwen3\")\n",
|
||||
" )\n",
|
||||
" s += user(\"Please give me another list of 3 countries and their capitals.\")\n",
|
||||
" s += assistant(\n",
|
||||
" separate_reasoning(gen(\"second_answer\", max_tokens=512), model_type=\"qwen3\")\n",
|
||||
" )\n",
|
||||
" return s\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"reasoning_state = multi_turn_qa()\n",
|
||||
"print_highlight(f\"\\n\\nfirst_answer:\\n{reasoning_state['first_answer']}\")\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\n\\nfirst_answer_reasoning_content:\\n{reasoning_state['first_answer_reasoning_content']}\"\n",
|
||||
")\n",
|
||||
"print_highlight(f\"\\n\\nsecond_answer:\\n{reasoning_state['second_answer']}\")\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\n\\nsecond_answer_reasoning_content:\\n{reasoning_state['second_answer_reasoning_content']}\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using No thinking as Qwen 3's advanced feature \n",
|
||||
"\n",
|
||||
"sglang separate_reasoning is particularly useful when combined with Qwen 3's advanced feature.\n",
|
||||
"\n",
|
||||
"[Qwen 3's advanced usages](https://qwenlm.github.io/blog/qwen3/#advanced-usages)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"reasoning_state = basic_qa_separate_reasoning(\n",
|
||||
" \"List 3 countries and their capitals. /no_think\"\n",
|
||||
")\n",
|
||||
"print_highlight(f\"Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\")\n",
|
||||
"print_highlight(f\"Content:\\n{reasoning_state['answer']}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`separate_reasoning` can also be used in regular expression generation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def regular_expression_gen(s):\n",
|
||||
" s += user(\n",
|
||||
" \"What is the IP address of the Google DNS servers? just provide the answer\"\n",
|
||||
" )\n",
|
||||
" s += assistant(\n",
|
||||
" separate_reasoning(\n",
|
||||
" gen(\n",
|
||||
" \"answer\",\n",
|
||||
" temperature=0,\n",
|
||||
" regex=r\"((25[0-5]|2[0-4]\\d|[01]?\\d\\d?).){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\",\n",
|
||||
" max_tokens=512,\n",
|
||||
" ),\n",
|
||||
" model_type=\"qwen3\",\n",
|
||||
" ),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"reasoning_state = regular_expression_gen()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_highlight(f\"Answer:\\n{reasoning_state['answer']}\")\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\n\\nReasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
7
scripts/playground/launch_tgi.sh
Normal file
7
scripts/playground/launch_tgi.sh
Normal file
@@ -0,0 +1,7 @@
|
||||
# Assuming the model is downdloaded at /home/ubuntu/model_weights/Llama-2-7b-chat-hf
|
||||
docker run --name tgi --rm -ti --gpus all --network host \
|
||||
-v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
|
||||
ghcr.io/huggingface/text-generation-inference:1.1.0 \
|
||||
--model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
|
||||
--max-input-length 2048 --max-total-tokens 4096 \
|
||||
--port 24000
|
||||
14
scripts/playground/load_tokenizer.py
Normal file
14
scripts/playground/load_tokenizer.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import argparse
|
||||
import code
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
t = get_tokenizer(args.name)
|
||||
code.interact(local=locals())
|
||||
36
scripts/playground/long_context_example.py
Normal file
36
scripts/playground/long_context_example.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from urllib.request import urlopen
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
test_cases = {
|
||||
"64k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt",
|
||||
"200k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt",
|
||||
"600k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
|
||||
"1m": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt",
|
||||
}
|
||||
|
||||
client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
|
||||
|
||||
for name, url in test_cases.items():
|
||||
print(f"\n==== Running test case: {name} ====")
|
||||
try:
|
||||
with urlopen(url, timeout=10) as response:
|
||||
prompt = response.read().decode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"Failed to load prompt for {name}: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=True,
|
||||
max_tokens=128,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content is not None:
|
||||
print(chunk.choices[0].delta.content, end="", flush=True)
|
||||
except Exception as e:
|
||||
print(f"\nError during completion for {name}: {e}")
|
||||
77
scripts/playground/lora/analyzer.py
Normal file
77
scripts/playground/lora/analyzer.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.append("../../")
|
||||
from fix_corrupted_json import clean_json_file
|
||||
|
||||
dirpath = "/Users/ying"
|
||||
output_file_prefix = "analyzed_log"
|
||||
|
||||
time = {}
|
||||
tot_time = {}
|
||||
size = {}
|
||||
|
||||
os.system(f"rm {output_file_prefix}*")
|
||||
|
||||
for dirname in glob.glob(os.path.join(dirpath, "trace*")):
|
||||
print(dirname)
|
||||
trace_name = dirname.split("/")[-1]
|
||||
time[trace_name] = {}
|
||||
size[trace_name] = {}
|
||||
total_time = 0
|
||||
for filename in tqdm(glob.glob(os.path.join(dirname, "*.json"))):
|
||||
step_name = filename.split("/")[-1].split(".")[0]
|
||||
step_name = "_".join(step_name.split("_")[1:])
|
||||
if "prefill" not in filename and "decode" not in filename:
|
||||
continue
|
||||
|
||||
match = re.search(r"(prefill|decode)_step_(\d+)\.json", filename)
|
||||
if match:
|
||||
phase = match.group(1)
|
||||
step = match.group(2)
|
||||
else:
|
||||
raise Exception(f"Cannot parse {filename}")
|
||||
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
trace = json.load(f)
|
||||
except:
|
||||
clean_json_file(filename, filename)
|
||||
with open(filename, "r") as f:
|
||||
trace = json.load(f)
|
||||
|
||||
for event in trace["traceEvents"]:
|
||||
name = event["name"]
|
||||
if name in ["profile_prefill_step", "profile_decode_step"]:
|
||||
dur = event["dur"] / 1e3
|
||||
time[trace_name][step_name] = dur
|
||||
break
|
||||
total_time += dur
|
||||
|
||||
step = int(step_name.split("_")[-1])
|
||||
with open(os.path.join(dirname, f"size_{step}.json"), "r") as f:
|
||||
size_info = json.load(f)
|
||||
size[trace_name][step_name] = size_info["size"]
|
||||
|
||||
tot_time[trace_name] = total_time
|
||||
time[trace_name] = dict(
|
||||
sorted(time[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
|
||||
)
|
||||
size[trace_name] = dict(
|
||||
sorted(size[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
|
||||
)
|
||||
|
||||
with open(f"{output_file_prefix}_{trace_name}", "a") as f:
|
||||
for k, v in time[trace_name].items():
|
||||
size_v = size[trace_name][k]
|
||||
print(f"{k:>15}{v:10.2f}\t{size_v}")
|
||||
f.write(f"{k:>15}{v:10.2f}\t{size_v}\n")
|
||||
|
||||
with open(f"{output_file_prefix}_total_time", "w") as f:
|
||||
print(tot_time)
|
||||
json.dump(tot_time, f)
|
||||
62
scripts/playground/lora/lora_hf_play.py
Normal file
62
scripts/playground/lora/lora_hf_play.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
# ADAPTER = "winddude/wizardLM-LlaMA-LoRA-7B"
|
||||
ADAPTER = "/home/ying/test_lora"
|
||||
HF_TOKEN = "..."
|
||||
|
||||
|
||||
prompt = """
|
||||
### Instruction:
|
||||
Write a poem about the transformers Python library.
|
||||
Mention the word "large language models" in that poem.
|
||||
### Response:
|
||||
The Transformers are large language models,
|
||||
They're used to make predictions on text.
|
||||
"""
|
||||
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(MODEL)
|
||||
|
||||
base_model = LlamaForCausalLM.from_pretrained(
|
||||
MODEL,
|
||||
device_map="auto",
|
||||
# load_in_8bit=True,
|
||||
torch_dtype=torch.float16,
|
||||
# use_auth_token=HF_TOKEN,
|
||||
).cuda()
|
||||
|
||||
|
||||
# base model generate
|
||||
with torch.no_grad():
|
||||
output_tensors = base_model.generate(
|
||||
input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
|
||||
max_new_tokens=32,
|
||||
do_sample=False,
|
||||
)[0]
|
||||
|
||||
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
|
||||
print("======= base output ========")
|
||||
print(output)
|
||||
|
||||
|
||||
# peft model generate
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
ADAPTER,
|
||||
torch_dtype=torch.float16,
|
||||
is_trainable=False,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensors = model.generate(
|
||||
input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
|
||||
max_new_tokens=32,
|
||||
do_sample=False,
|
||||
)[0]
|
||||
|
||||
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
|
||||
print("======= peft output ========")
|
||||
print(output)
|
||||
30
scripts/playground/lora/lora_vllm_play.py
Normal file
30
scripts/playground/lora/lora_vllm_play.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
ADAPTER = "/home/ying/test_lora"
|
||||
prompt = """
|
||||
### Instruction:
|
||||
Write a poem about the transformers Python library.
|
||||
Mention the word "large language models" in that poem.
|
||||
### Response:
|
||||
The Transformers are large language models,
|
||||
They're used to make predictions on text.
|
||||
"""
|
||||
|
||||
|
||||
llm = LLM(model=MODEL, enable_lora=True)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
)
|
||||
|
||||
prompts = [prompt]
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts, sampling_params, lora_request=LoRARequest("test_lora", 1, ADAPTER)
|
||||
)
|
||||
|
||||
print(outputs[0].prompt)
|
||||
print(outputs[0].outputs[0].text)
|
||||
197
scripts/playground/reference_hf.py
Normal file
197
scripts/playground/reference_hf.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Usage: python3 scripts/playground/reference_hf.py --model-path MODEL_PATH --model-type {text,vlm} [--max-new-tokens NUM] [--dtype DTYPE]
|
||||
--model-path MODEL_PATH: Path to model (default: TinyLlama/TinyLlama-1.1B-Chat-v0.4)
|
||||
--model-type {text,vlm}: Model type, text or vlm (default: text)
|
||||
--max-new-tokens NUM: Max new tokens to generate (default: 16)
|
||||
--dtype DTYPE: Data type for computation (default: float16)
|
||||
Note: '--model' is deprecated; use '--model-path'. Runs normal_text() for text, vlm_text_with_image() for vlm.
|
||||
|
||||
Reference output:
|
||||
========== Prompt 0 ==========
|
||||
prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141],
|
||||
device='cuda:0')
|
||||
<s> The capital of France is Paris.
|
||||
The capital of the United States is Washington, D.C.
|
||||
|
||||
========== Prompt 1 ==========
|
||||
prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742],
|
||||
device='cuda:0')
|
||||
<s> The capital of the United Kindom is London.
|
||||
The capital of the United Kingdom is London.
|
||||
The capital of
|
||||
|
||||
========== Prompt 2 ==========
|
||||
prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609],
|
||||
device='cuda:0')
|
||||
<s> Today is a sunny day and I like to go for a walk in the park.
|
||||
I'm going to the
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
)
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def vlm_text_with_image(args):
|
||||
# Load the processor and model for ImageTextToText tasks
|
||||
processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=args.dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# List of image URLs to process
|
||||
image_urls = [
|
||||
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
|
||||
]
|
||||
|
||||
# Conversation template for the processor
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
},
|
||||
{"type": "text", "text": "Describe this image."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
max_new_tokens = args.max_new_tokens
|
||||
|
||||
for i, url in enumerate(image_urls):
|
||||
# Load the image from the URL
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
# Apply the chat template to the text prompt
|
||||
# Notice that not all processors support chat templates.
|
||||
# LLaVA and QWen are two processors that support chat templates.
|
||||
if not hasattr(processor, "apply_chat_template"):
|
||||
raise ValueError("The processor does not support chat templates.")
|
||||
text_prompt = processor.apply_chat_template(
|
||||
conversation, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Prepare inputs for the model
|
||||
inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(
|
||||
"cuda:0"
|
||||
)
|
||||
|
||||
# Generate output from the model
|
||||
output_ids = model.generate(
|
||||
**inputs, do_sample=False, max_new_tokens=max_new_tokens
|
||||
)
|
||||
output_str = processor.decode(output_ids[0])
|
||||
|
||||
# Get the logits from the model's forward pass
|
||||
outputs = model.forward(**inputs)
|
||||
logits = outputs.logits[0, -1, :]
|
||||
|
||||
print(f"\n========== Image {i} ==========")
|
||||
print("prefill logits (final)", logits)
|
||||
# TODO(gaocegege): The output contains numerous <|image_pad|> tokens,
|
||||
# making it cluttered and difficult to read.
|
||||
# These tokens should be removed or cleaned up for better readability.
|
||||
print(output_str)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def normal_text(args):
|
||||
t = get_tokenizer(args.model_path, trust_remote_code=True)
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=args.dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The capital of the United Kindom is",
|
||||
"Today is a sunny day and I like",
|
||||
]
|
||||
max_new_tokens = args.max_new_tokens
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
if isinstance(p, str):
|
||||
input_ids = t.encode(p, return_tensors="pt").to("cuda:0")
|
||||
else:
|
||||
input_ids = torch.tensor([p], device="cuda:0")
|
||||
|
||||
output_ids = m.generate(
|
||||
input_ids, do_sample=False, max_new_tokens=max_new_tokens
|
||||
)
|
||||
output_str = t.decode(output_ids[0])
|
||||
|
||||
prefill_logits = m.forward(input_ids).logits[0][-1]
|
||||
|
||||
print(f"\n========== Prompt {i} ==========")
|
||||
print("prefill logits (final)", prefill_logits)
|
||||
print(output_str)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def synthetic_tokens(args):
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
)
|
||||
m.cuda()
|
||||
print(m)
|
||||
|
||||
input_len = 256
|
||||
output_len = 8
|
||||
prompts = [list(range(5, 5 + input_len))]
|
||||
|
||||
for p in prompts:
|
||||
input_ids = p
|
||||
for i in range(output_len + 1):
|
||||
prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[
|
||||
0
|
||||
][-1]
|
||||
|
||||
if i == 0:
|
||||
print("prefill logits", prefill_logits)
|
||||
else:
|
||||
print("decode", i - 1, prefill_logits)
|
||||
|
||||
input_ids.append(torch.argmax(prefill_logits).item())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
|
||||
)
|
||||
parser.add_argument("--max-new-tokens", type=int, default=16)
|
||||
|
||||
parser.add_argument("--dtype", type=str, default="float16")
|
||||
|
||||
parser.add_argument("--model-type", type=str, default="text")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == "vlm":
|
||||
vlm_text_with_image(args)
|
||||
else:
|
||||
normal_text(args)
|
||||
151
scripts/playground/replay_request_dump.py
Normal file
151
scripts/playground/replay_request_dump.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
Usage:
|
||||
# replay from a folder
|
||||
python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/
|
||||
|
||||
# replay from a single file
|
||||
python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import pickle
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.bench_serving import set_ulimit
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
|
||||
def read_records(files):
|
||||
records = []
|
||||
for f in files:
|
||||
tmp = pickle.load(open(f, "rb"))
|
||||
if isinstance(tmp, dict) and "requests" in tmp:
|
||||
records.extend(tmp["requests"])
|
||||
else:
|
||||
records.extend(tmp)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
def run_one_request_internal(record):
|
||||
(req, output, replay_init_time, start_time, end_time, idx) = record
|
||||
time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))
|
||||
|
||||
if "completion_tokens" in output.get("meta_info", {}):
|
||||
recorded_completion_tokens = output["meta_info"]["completion_tokens"]
|
||||
else:
|
||||
recorded_completion_tokens = ""
|
||||
|
||||
json_data = asdict(req)
|
||||
stream = json_data["stream"]
|
||||
|
||||
if args.ignore_eos:
|
||||
json_data["sampling_params"]["ignore_eos"] = True
|
||||
if recorded_completion_tokens:
|
||||
json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens
|
||||
|
||||
response = requests.post(
|
||||
f"http://{args.host}:{args.port}/generate",
|
||||
json=json_data,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if stream:
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
ret = json.loads(chunk[5:].strip("\n"))
|
||||
else:
|
||||
ret = response.json()
|
||||
|
||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||
print(
|
||||
f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, "
|
||||
f"{completion_tokens=}, {recorded_completion_tokens=}"
|
||||
)
|
||||
|
||||
|
||||
def run_one_request(record):
|
||||
# global success_ct, error_ct
|
||||
|
||||
try:
|
||||
run_one_request_internal(record)
|
||||
# success_ct += 1
|
||||
except Exception:
|
||||
# error_ct += 1
|
||||
traceback = get_exception_traceback()
|
||||
print(f"Hit an exception: {traceback}")
|
||||
|
||||
|
||||
def main(records):
|
||||
if len(records) == 0:
|
||||
return
|
||||
|
||||
base_time = records[0][-2]
|
||||
base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S")
|
||||
print(f"{base_time_str=}")
|
||||
replay_init_time = time.time()
|
||||
|
||||
for i in range(len(records)):
|
||||
req, output, start_time, end_time = records[i]
|
||||
start_time -= base_time
|
||||
records[i] = (req, output, replay_init_time, start_time, end_time, i)
|
||||
|
||||
with ThreadPoolExecutor(args.parallel) as executor:
|
||||
executor.map(run_one_request, records)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
parser.add_argument(
|
||||
"--input-folder", type=str, default=None, help="Folder containing pickle files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-file", type=str, default=None, help="Single pickle file to process"
|
||||
)
|
||||
parser.add_argument("--file-number", type=int, default=1)
|
||||
parser.add_argument("--req-number", type=int, default=1000000)
|
||||
parser.add_argument("--req-start", type=int, default=0)
|
||||
parser.add_argument("--parallel", type=int, default=512)
|
||||
parser.add_argument("--idx", type=int, default=None)
|
||||
parser.add_argument("--ignore-eos", action="store_true")
|
||||
parser.add_argument("--speed", type=float, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
set_ulimit()
|
||||
|
||||
files = []
|
||||
if args.input_file:
|
||||
files = [args.input_file]
|
||||
if args.file_number > 1:
|
||||
print("Warning: --file-number is ignored when --input-file is provided.")
|
||||
elif args.input_folder:
|
||||
files = glob.glob(f"{args.input_folder}/*.pkl")
|
||||
files = files[: args.file_number]
|
||||
else:
|
||||
print("Error: Either --input-folder or --input-file must be provided.")
|
||||
exit(1)
|
||||
print(f"{files=}")
|
||||
|
||||
records = read_records(files)
|
||||
# Sort by the receive time, before filtering
|
||||
records.sort(key=lambda x: x[-2])
|
||||
records = records[args.req_start :]
|
||||
if args.idx:
|
||||
records = [records[args.idx]]
|
||||
print(f"testing {args.idx=}")
|
||||
print(f"{records[0]}")
|
||||
print(f"{len(records)=}")
|
||||
main(records)
|
||||
207
scripts/playground/router/test_tree.py
Normal file
207
scripts/playground/router/test_tree.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from tree import MultiTenantRadixTree
|
||||
|
||||
|
||||
class TestMultiTenantRadixTree(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tree = MultiTenantRadixTree()
|
||||
|
||||
def test_insert_exact_match(self):
|
||||
"""Test 1: Basic insert and exact match operations"""
|
||||
# Insert a single string for one tenant
|
||||
self.tree.insert("hello", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Insert same string for different tenant
|
||||
self.tree.insert("hello", "tenant2")
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertIn(tenant, ["tenant1", "tenant2"])
|
||||
|
||||
# Insert different string for same tenant
|
||||
self.tree.insert("world", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("world")
|
||||
self.assertEqual(matched, "world")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
def test_insert_partial_match(self):
|
||||
"""Test 2: Insert with partial matching scenarios"""
|
||||
# Test partial matches with common prefixes
|
||||
self.tree.insert("hello", "tenant1")
|
||||
print(self.tree.pretty_print())
|
||||
self.tree.insert("help", "tenant2")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Match exact strings
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
matched, tenant = self.tree.prefix_match("help")
|
||||
self.assertEqual(matched, "help")
|
||||
self.assertEqual(tenant, "tenant2")
|
||||
|
||||
# Match partial string
|
||||
matched, tenant = self.tree.prefix_match("hel")
|
||||
self.assertEqual(matched, "hel")
|
||||
self.assertIn(tenant, ["tenant1", "tenant2"])
|
||||
|
||||
# Match longer string
|
||||
matched, tenant = self.tree.prefix_match("hello_world")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
def test_insert_edge_cases(self):
|
||||
"""Test 3: Edge cases for insert and match operations"""
|
||||
# Empty string
|
||||
self.tree.insert("", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("")
|
||||
self.assertEqual(matched, "")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Single character
|
||||
self.tree.insert("a", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("a")
|
||||
self.assertEqual(matched, "a")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Very long string
|
||||
long_str = "a" * 1000
|
||||
self.tree.insert(long_str, "tenant1")
|
||||
matched, tenant = self.tree.prefix_match(long_str)
|
||||
self.assertEqual(matched, long_str)
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Unicode characters
|
||||
self.tree.insert("你好", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("你好")
|
||||
self.assertEqual(matched, "你好")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
def test_simple_eviction(self):
|
||||
"""Test 4: Simple eviction scenarios
|
||||
Tenant1: limit 10 chars
|
||||
Tenant2: limit 5 chars
|
||||
|
||||
Should demonstrate:
|
||||
1. Basic eviction when size limit exceeded
|
||||
2. Proper eviction based on last access time
|
||||
3. Verification that shared nodes remain intact for other tenants
|
||||
"""
|
||||
# Set up size limits
|
||||
max_size = {"tenant1": 10, "tenant2": 5}
|
||||
|
||||
# Insert strings for both tenants
|
||||
self.tree.insert("hello", "tenant1") # size 5
|
||||
self.tree.insert("hello", "tenant2") # size 5
|
||||
self.tree.insert("world", "tenant2") # size 5, total for tenant2 = 10
|
||||
|
||||
# Verify initial sizes
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_before["tenant1"], 5) # "hello" = 5
|
||||
self.assertEqual(sizes_before["tenant2"], 10) # "hello" + "world" = 10
|
||||
|
||||
# Evict - should remove "hello" from tenant2 as it's the oldest
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
# Verify sizes after eviction
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_after["tenant1"], 5) # Should be unchanged
|
||||
self.assertEqual(sizes_after["tenant2"], 5) # Only "world" remains
|
||||
|
||||
# Verify "world" remains for tenant2 (was accessed more recently)
|
||||
matched, tenant = self.tree.prefix_match("world")
|
||||
self.assertEqual(matched, "world")
|
||||
self.assertEqual(tenant, "tenant2")
|
||||
|
||||
def test_medium_eviction(self):
|
||||
"""Test 5: Medium complexity eviction scenarios with shared prefixes
|
||||
Tenant1: limit 10 chars
|
||||
Tenant2: limit 7 chars (forces one string to be evicted)
|
||||
|
||||
Tree structure after inserts:
|
||||
└── 'h' [t1, t2]
|
||||
├── 'i' [t1, t2] # Oldest for t2
|
||||
└── 'e' [t1, t2]
|
||||
├── 'llo' [t1, t2]
|
||||
└── 'y' [t2] # Newest for t2
|
||||
|
||||
Size calculations:
|
||||
tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars
|
||||
tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars
|
||||
|
||||
After eviction (tenant2 exceeds limit by 1 char):
|
||||
"hi" should be removed from tenant2 as it's the oldest access
|
||||
"""
|
||||
max_size = {
|
||||
"tenant1": 10,
|
||||
"tenant2": 6,
|
||||
} # tenant2 will need to evict one string
|
||||
|
||||
# Create a tree with overlapping prefixes
|
||||
self.tree.insert("hi", "tenant1")
|
||||
self.tree.insert("hi", "tenant2") # OLDEST for t2
|
||||
|
||||
self.tree.insert("hello", "tenant1")
|
||||
self.tree.insert("hello", "tenant2")
|
||||
|
||||
self.tree.insert("hey", "tenant2") # NEWEST for t2
|
||||
|
||||
# Verify initial sizes
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_before["tenant1"], 6) # h(1) + i(1) + e(1) + llo(3) = 6
|
||||
self.assertEqual(
|
||||
sizes_before["tenant2"], 7
|
||||
) # h(1) + i(1) + e(1) + llo(3) + y(1) = 7
|
||||
|
||||
print("\nTree before eviction:")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Evict - should remove "hi" from tenant2 as it's the oldest
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
print("\nTree after eviction:")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Verify sizes after eviction
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_after["tenant1"], 6) # Should be unchanged
|
||||
self.assertEqual(sizes_after["tenant2"], 6) # h(1) + e(1) + llo(3) + y(1) = 6
|
||||
|
||||
def test_advanced_eviction(self):
|
||||
...
|
||||
# Create 4 tenants
|
||||
# Each tenants keeps adding strings with shared prefixes to thousands usage
|
||||
# Set a strict limit for each tenant to only 100
|
||||
# At the end, check whether all of the tenant is under 100 after eviction
|
||||
|
||||
max_size = {"tenant1": 100, "tenant2": 100, "tenant3": 100, "tenant4": 100}
|
||||
|
||||
prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]
|
||||
for i in range(100):
|
||||
for j, prefix in enumerate(prefixes):
|
||||
random_suffix = "".join(random.choices(string.ascii_letters, k=10))
|
||||
self.tree.insert(prefix + random_suffix, f"tenant{j+1}")
|
||||
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
print(sizes_before)
|
||||
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
print(sizes_after)
|
||||
# ensure size_after is below max_size
|
||||
for tenant, size in sizes_after.items():
|
||||
self.assertLessEqual(size, max_size[tenant])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
292
scripts/playground/router/tree.py
Normal file
292
scripts/playground/router/tree.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self):
|
||||
self.children: Dict[str, Node] = dict()
|
||||
# We choose to use text because most of the use cases are text-to-text,
|
||||
# so we can save the tokenizing overhead.
|
||||
self.text: str = ""
|
||||
# Maps tenant_id to their last access timestamp
|
||||
self.tenant_last_access_time: Dict[str, float] = dict()
|
||||
self.parent = None
|
||||
|
||||
|
||||
def shared_prefix_length(s1, s2):
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(min_length):
|
||||
if s1[i] != s2[i]:
|
||||
return i
|
||||
return min_length
|
||||
|
||||
|
||||
class MultiTenantRadixTree:
|
||||
"""
|
||||
Python Reference of Rust implementation of MultiTenantRadixTree
|
||||
|
||||
MultiTenantRadixTree is the overlap of multiple radix trees by different tenant
|
||||
Each node in the tree can be owned by multiple tenants, allowing for efficient storage of common prefixes
|
||||
while maintaining tenant isolation.
|
||||
|
||||
Key concepts:
|
||||
- Tenant: An entity that owns a subset of the stored strings
|
||||
- Each node tracks which tenants have access to it via tenant_last_access_time
|
||||
- The tree structure is shared, but queries can be filtered by tenant_id
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.root = Node()
|
||||
|
||||
def insert(self, s: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Insert string 's' and associate it with the given tenant_id.
|
||||
|
||||
Args:
|
||||
s: The string to insert
|
||||
tenant_id: The identifier of the tenant who owns this string
|
||||
"""
|
||||
curr = self.root
|
||||
curr_idx = 0
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
|
||||
while curr_idx < len(s):
|
||||
matched_node = None
|
||||
if s[curr_idx] in curr.children:
|
||||
matched_node = curr.children[s[curr_idx]]
|
||||
|
||||
if matched_node is None:
|
||||
# No match => create a new node
|
||||
new_node = Node()
|
||||
new_node.text = s[curr_idx:]
|
||||
new_node.parent = curr
|
||||
|
||||
curr.children[s[curr_idx]] = new_node
|
||||
curr_idx = len(s)
|
||||
curr = new_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
else:
|
||||
shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
|
||||
|
||||
# 1. If the matched text is shorter than the node text => split the node
|
||||
if shared_len < len(matched_node.text):
|
||||
# Split structure: [matched_node] => [new_node] -> [contracted_matched_node]
|
||||
|
||||
matched_text = matched_node.text[:shared_len]
|
||||
unmatched_text = matched_node.text[shared_len:]
|
||||
|
||||
new_node = Node()
|
||||
new_node.text = matched_text
|
||||
new_node.children = {unmatched_text[0]: matched_node}
|
||||
new_node.parent = curr
|
||||
new_node.parent.children[matched_text[0]] = new_node
|
||||
new_node.tenant_last_access_time = (
|
||||
matched_node.tenant_last_access_time.copy()
|
||||
)
|
||||
|
||||
# Contract matched node
|
||||
matched_node.text = unmatched_text
|
||||
matched_node.parent = new_node
|
||||
|
||||
curr_idx += shared_len
|
||||
curr = new_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
# 2. If the matched text is longer or equal to the node text => walk down the node
|
||||
else:
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
|
||||
def prefix_match(self, s: str) -> tuple[str, int]:
|
||||
"""
|
||||
Match string 's' with multiple tenants' trees in one operation.
|
||||
|
||||
Args:
|
||||
s: The string to match
|
||||
|
||||
Returns:
|
||||
Tuple(str, int): The longest prefix of 's' that matches the tree and the first tenant_id that own the matched prefix
|
||||
"""
|
||||
curr = self.root
|
||||
curr_idx = 0
|
||||
|
||||
ret_text = ""
|
||||
ret_tenant = None
|
||||
|
||||
while curr_idx < len(s):
|
||||
matched_node = None
|
||||
if s[curr_idx] in curr.children:
|
||||
matched_node = curr.children[s[curr_idx]]
|
||||
|
||||
if matched_node is None:
|
||||
break
|
||||
|
||||
shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
|
||||
if shared_len == len(matched_node.text):
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
else:
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
break
|
||||
|
||||
selected_tenant = list(curr.tenant_last_access_time.keys())[0]
|
||||
|
||||
# traverse back to the root to update last access time for the selected tenant
|
||||
while curr != self.root:
|
||||
curr.tenant_last_access_time[selected_tenant] = time.time()
|
||||
curr = curr.parent
|
||||
|
||||
return s[:curr_idx], selected_tenant
|
||||
|
||||
def evict_tenant_data(self, max_size_per_tenant: Dict[str, int]) -> None:
|
||||
"""
|
||||
Evict data for tenants that have exceeded their storage limits.
|
||||
|
||||
Args:
|
||||
max_size_per_tenant: Dictionary mapping tenant_id to their maximum allowed storage size
|
||||
"""
|
||||
|
||||
def leaf_of(node):
|
||||
"""
|
||||
If the node is a leaf for a tenant, add tenant_id to the return list
|
||||
This will return list of tenant ids
|
||||
If not a leaf for all tenants, return []
|
||||
"""
|
||||
candidates = dict([(k, True) for k in node.tenant_last_access_time.keys()])
|
||||
|
||||
for n in node.children.values():
|
||||
for c in n.tenant_last_access_time.keys():
|
||||
candidates[c] = False
|
||||
|
||||
return [k for k, v in candidates.items() if v]
|
||||
|
||||
# maintain a heap with (time, tenant, node) as the value
|
||||
import heapq
|
||||
|
||||
# 1. traverse the tree to
|
||||
# a. add all the leaves into a heap (a node with N tenants will be added N times into the heap)
|
||||
# b. calculate the used size for each tenant
|
||||
# do a dfs with stack
|
||||
stack = [self.root]
|
||||
pq = []
|
||||
used_size_per_tenant = defaultdict(int)
|
||||
|
||||
while stack:
|
||||
curr = stack.pop()
|
||||
for t in curr.tenant_last_access_time.keys():
|
||||
used_size_per_tenant[t] += len(curr.text)
|
||||
|
||||
for c in curr.children.values():
|
||||
stack.append(c)
|
||||
|
||||
# if the node is a leaf for a tenant, add the tenant to the heap
|
||||
tenants = leaf_of(curr)
|
||||
for t in tenants:
|
||||
heapq.heappush(pq, (curr.tenant_last_access_time[t], t, curr))
|
||||
|
||||
# 2. pop the heap
|
||||
# a. if the tenant's used size is less than the limit, continue
|
||||
# b. if the tenant's used size is greater than the limit, remove the leaf and update the used size, and add its parent to the heap
|
||||
while len(pq) > 0:
|
||||
time, tenant, node = heapq.heappop(pq)
|
||||
if used_size_per_tenant[tenant] <= max_size_per_tenant[tenant]:
|
||||
continue
|
||||
|
||||
# remove the leaf
|
||||
used_size_per_tenant[tenant] -= len(node.text)
|
||||
del node.tenant_last_access_time[tenant]
|
||||
# if no children and no tenants, remove the node
|
||||
if len(node.children) == 0 and len(node.tenant_last_access_time) == 0:
|
||||
del node.parent.children[node.text[0]]
|
||||
|
||||
# add its parent to the heap
|
||||
if tenant in leaf_of(node.parent):
|
||||
heapq.heappush(
|
||||
pq,
|
||||
(node.parent.tenant_last_access_time[tenant], tenant, node.parent),
|
||||
)
|
||||
|
||||
def get_used_size_per_tenant(self) -> Dict[str, int]:
|
||||
"""
|
||||
Calculate the used storage size for each tenant.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping tenant_id to their used storage size
|
||||
"""
|
||||
used_size_per_tenant = defaultdict(int)
|
||||
|
||||
stack = [self.root]
|
||||
while stack:
|
||||
curr = stack.pop()
|
||||
for t in curr.tenant_last_access_time.keys():
|
||||
used_size_per_tenant[t] += len(curr.text)
|
||||
|
||||
for c in curr.children.values():
|
||||
stack.append(c)
|
||||
|
||||
return used_size_per_tenant
|
||||
|
||||
def remove_tenant(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Remove all data associated with a specific tenant from the tree.
|
||||
This operation maintains the integrity of the shared tree structure while
|
||||
removing only the specified tenant's access information.
|
||||
|
||||
Args:
|
||||
tenant_id: The identifier of the tenant whose data should be removed
|
||||
"""
|
||||
# TODO: Implementation needed
|
||||
pass
|
||||
|
||||
def pretty_print(self) -> str:
|
||||
"""
|
||||
Returns a string representation of the tree showing the structure, tenant ownership,
|
||||
and leaf status for each node.
|
||||
|
||||
Returns:
|
||||
str: A formatted string showing the tree hierarchy with tenant information
|
||||
"""
|
||||
|
||||
def _node_to_str(node: Node, prefix: str = "", is_last: bool = True) -> str:
|
||||
# Current node representation
|
||||
node_str = prefix
|
||||
node_str += "└── " if is_last else "├── "
|
||||
|
||||
# Add node text
|
||||
node_str += f"'{node.text}' ["
|
||||
|
||||
# Add tenant information including both timestamp and leaf status
|
||||
tenant_info = []
|
||||
for tid, ts in node.tenant_last_access_time.items():
|
||||
time_str = (
|
||||
time.strftime("%H:%M:%S.", time.localtime(ts))
|
||||
+ f"{(ts % 1):0.3f}"[2:]
|
||||
)
|
||||
tenant_info.append(f"{tid} | {time_str}")
|
||||
|
||||
node_str += ", ".join(tenant_info)
|
||||
node_str += "]\n"
|
||||
|
||||
# Handle children
|
||||
children = list(node.children.items())
|
||||
for i, (char, child) in enumerate(children):
|
||||
is_last_child = i == len(children) - 1
|
||||
# Adjust prefix for children based on whether this is the last child
|
||||
new_prefix = prefix + (" " if is_last else "│ ")
|
||||
node_str += _node_to_str(child, new_prefix, is_last_child)
|
||||
|
||||
return node_str
|
||||
|
||||
if not self.root.children:
|
||||
return "Empty tree"
|
||||
|
||||
# Start with root's children since root itself is just an empty node
|
||||
result = ""
|
||||
children = list(self.root.children.items())
|
||||
for i, (char, child) in enumerate(children):
|
||||
is_last = i == len(children) - 1
|
||||
result += _node_to_str(child, "", is_last)
|
||||
|
||||
return result
|
||||
33
scripts/update_kernel_whl_index.py
Normal file
33
scripts/update_kernel_whl_index.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
|
||||
def update_wheel_index(cuda_version="118"):
|
||||
index_dir = pathlib.Path(f"sgl-whl/cu{cuda_version}/sgl-kernel")
|
||||
index_dir.mkdir(exist_ok=True)
|
||||
base_url = "https://github.com/sgl-project/whl/releases/download"
|
||||
|
||||
for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")):
|
||||
with open(path, "rb") as f:
|
||||
sha256 = hashlib.sha256(f.read()).hexdigest()
|
||||
ver = re.findall(
|
||||
r"sgl_kernel-([0-9.]+(?:\.post[0-9]+)?)(?:\+cu[0-9]+)?-", path.name
|
||||
)[0]
|
||||
full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}"
|
||||
with (index_dir / "index.html").open("a") as f:
|
||||
f.write(f'<a href="{full_url}">{path.name}</a><br>\n')
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--cuda", type=str, default="118")
|
||||
args = parser.parse_args()
|
||||
update_wheel_index(args.cuda)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
30
scripts/version_branch_to_tag.sh
Executable file
30
scripts/version_branch_to_tag.sh
Executable file
@@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
set -euxo pipefail
|
||||
|
||||
# This script is used for release.
|
||||
# It tags all remote branches starting with 'v' with the same name as the branch,
|
||||
# deletes the corresponding branches from the remote, and pushes the tags to the remote repository.
|
||||
|
||||
git fetch origin --prune
|
||||
|
||||
# List all branches starting with 'v'
|
||||
branches=$(git branch -r | grep 'origin/v' | sed 's/origin\///')
|
||||
|
||||
# Loop through each branch
|
||||
for branch in $branches; do
|
||||
echo "Processing branch: $branch"
|
||||
|
||||
# Get the commit hash for the branch
|
||||
commit_hash=$(git rev-parse origin/$branch)
|
||||
|
||||
# Create a tag with the same name as the branch using the commit hash
|
||||
git tag $branch $commit_hash
|
||||
|
||||
# Delete the branch from the remote
|
||||
git push origin --delete $branch
|
||||
done
|
||||
|
||||
# Push all tags to the remote repository
|
||||
git push --tags
|
||||
|
||||
echo "All branches starting with 'v' have been tagged, deleted from remote, and pushed to the remote repository."
|
||||
Reference in New Issue
Block a user