sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

60
scripts/ci/amd_ci_exec.sh Executable file
View File

@@ -0,0 +1,60 @@
#!/bin/bash
# Run a command inside the `ci_sglang` container with a CI environment
# tailored to the AMD GPU family detected from the runner hostname.
#
# Usage: amd_ci_exec.sh [-w|--workdir DIR] [-e KEY=VAL]... [--] CMD [ARGS...]
set -euo pipefail

# Runner host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
HOSTNAME_VALUE=$(hostname)
GPU_FAMILY=""
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
  GPU_FAMILY=${BASH_REMATCH[1]}
  echo "Detected GPU family from hostname: ${GPU_FAMILY}"
else
  echo "Warning: could not parse GPU family from '${HOSTNAME_VALUE}'"
fi

WORKDIR="/sglang-checkout/test/srt"

# Environment passed into the container; -e options may add/override entries.
declare -A ENV_MAP=(
  [SGLANG_AMD_CI]=1
  [SGLANG_IS_IN_CI]=1
  [SGLANG_USE_AITER]=1
)
# mi35x (gfx950) needs GPU_ARCHS set explicitly for kernel builds.
if [[ "${GPU_FAMILY}" == "mi35x" ]]; then
  ENV_MAP[GPU_ARCHS]="gfx950"
fi

# Parse -w/--workdir and -e ENV=VAL; everything after `--` (or the first
# unrecognised word) is the command to execute.
while [[ $# -gt 0 ]]; do
  case "$1" in
    -w|--workdir)
      WORKDIR="$2"
      shift 2
      ;;
    -e)
      # Split only on the first '='; the value may itself contain '='.
      IFS="=" read -r env_key env_val <<< "$2"
      ENV_MAP["$env_key"]="$env_val"
      shift 2
      ;;
    --)
      shift
      break
      ;;
    *)
      break
      ;;
  esac
done

# Flatten ENV_MAP into repeated `-e KEY=VAL` docker arguments.
ENV_ARGS=()
for env_key in "${!ENV_MAP[@]}"; do
  ENV_ARGS+=("-e" "${env_key}=${ENV_MAP[$env_key]}")
done

# Execute the remaining arguments inside the container.
docker exec \
  -w "$WORKDIR" \
  "${ENV_ARGS[@]}" \
  ci_sglang "$@"

View File

@@ -0,0 +1,47 @@
#!/bin/bash
# Install SGLang test dependencies inside the already-running `ci_sglang`
# container on an AMD CI runner. The GPU family parsed from the hostname
# selects the install variant (mi35x vs mi30x-class GPUs).
set -euo pipefail

HOSTNAME_VALUE=$(hostname)
GPU_ARCH="mi30x" # default
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
  GPU_ARCH="${BASH_REMATCH[1]}"
  echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
else
  echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
fi

# Install the required dependencies in CI.
docker exec ci_sglang pip install --upgrade pip
# Replace any pre-built sgl-kernel with one compiled from the ROCm sources.
docker exec ci_sglang pip uninstall sgl-kernel -y || true
docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"

case "${GPU_ARCH}" in
  mi35x)
    echo "Runner uses ${GPU_ARCH}; will fetch mi35x image."
    docker exec ci_sglang pip install -e "python[dev_hip]" --no-deps # TODO: only for mi35x
    # For lmms_evals evaluating MMMU
    docker exec -w / ci_sglang git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
    docker exec -w /lmms-eval ci_sglang pip install -e . --no-deps # TODO: only for mi35x
    ;;
  mi30x|mi300|mi325)
    echo "Runner uses ${GPU_ARCH}; will fetch mi30x image."
    docker exec ci_sglang pip install -e "python[dev_hip]"
    # For lmms_evals evaluating MMMU
    docker exec -w / ci_sglang git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
    docker exec -w /lmms-eval ci_sglang pip install -e .
    ;;
  *)
    # Unknown architectures still receive the common installs below.
    echo "Runner architecture '${GPU_ARCH}' unrecognised;" >&2
    ;;
esac

# human-eval is used by the code-generation benchmarks.
docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git
docker exec -w /human-eval ci_sglang pip install -e .

# Stage a dummy Grok model config inside the container for config-only tests.
docker exec -w / ci_sglang mkdir -p /dummy-grok
mkdir -p dummy-grok && wget https://sharkpublic.blob.core.windows.net/sharkpublic/sglang/dummy_grok.json -O dummy-grok/config.json
docker cp ./dummy-grok ci_sglang:/

docker exec ci_sglang pip install huggingface_hub[hf_xet]
docker exec ci_sglang pip install pytest

View File

@@ -0,0 +1,132 @@
#!/bin/bash
# Resolve the rocm/sgl-dev image matching this runner's GPU architecture
# and SGLang version, then start it as the `ci_sglang` container.
set -euo pipefail

# Get version from SGLang version.py file
SGLANG_VERSION_FILE="$(dirname "$0")/../../python/sglang/version.py"
SGLANG_VERSION="v0.5.0rc0" # Default version, will be overridden if version.py is found
if [ -f "$SGLANG_VERSION_FILE" ]; then
  # Extract `__version__ = "x.y.z"` (single or double quotes) and prefix "v".
  VERSION_FROM_FILE=$(python3 -c '
import re, sys
with open(sys.argv[1], "r") as f:
    content = f.read()
match = re.search(r"__version__\s*=\s*[\"'"'"'](.*?)[\"'"'"']", content)
if match:
    print("v" + match.group(1))
' "$SGLANG_VERSION_FILE" 2>/dev/null || echo "")
  if [ -n "$VERSION_FROM_FILE" ]; then
    SGLANG_VERSION="$VERSION_FROM_FILE"
    echo "Using SGLang version from version.py: $SGLANG_VERSION"
  else
    echo "Warning: Could not parse version from $SGLANG_VERSION_FILE, using default: $SGLANG_VERSION" >&2
  fi
else
  echo "Warning: version.py not found, using default version: $SGLANG_VERSION" >&2
fi

# Default base tags (can be overridden by command line arguments)
DEFAULT_MI30X_BASE_TAG="${SGLANG_VERSION}-rocm630-mi30x"
DEFAULT_MI35X_BASE_TAG="${SGLANG_VERSION}-rocm700-mi35x"

# Parse command line arguments
MI30X_BASE_TAG="${DEFAULT_MI30X_BASE_TAG}"
MI35X_BASE_TAG="${DEFAULT_MI35X_BASE_TAG}"
while [[ $# -gt 0 ]]; do
  case $1 in
    --mi30x-base-tag) MI30X_BASE_TAG="$2"; shift 2;;
    --mi35x-base-tag) MI35X_BASE_TAG="$2"; shift 2;;
    -h|--help)
      echo "Usage: $0 [--mi30x-base-tag TAG] [--mi35x-base-tag TAG]"
      exit 0
      ;;
    *) echo "Unknown option $1"; exit 1;;
  esac
done

# Detect GPU architecture from the Kubernetes runner hostname
HOSTNAME_VALUE=$(hostname)
GPU_ARCH="mi30x" # default
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
  GPU_ARCH="${BASH_REMATCH[1]}"
  echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
else
  echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
fi

# Normalise / collapse architectures we don't yet build specifically for
case "${GPU_ARCH}" in
  mi35x)
    echo "Runner uses ${GPU_ARCH}; will fetch mi35x image."
    ;;
  mi30x|mi300|mi325)
    echo "Runner uses ${GPU_ARCH}; will fetch mi30x image."
    GPU_ARCH="mi30x"
    ;;
  *)
    echo "Runner architecture '${GPU_ARCH}' unrecognised; defaulting to mi30x image." >&2
    GPU_ARCH="mi30x"
    ;;
esac

# Set up DEVICE_FLAG based on Kubernetes pod info
if [[ -f /etc/podinfo/gha-render-devices ]]; then
  # The pod exposes the exact `--device ...` flags for its assigned GPUs here.
  DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
else
  DEVICE_FLAG="--device /dev/dri"
fi
# Echo the newest available rocm/sgl-dev image for the given GPU
# architecture, trying today's tag first and walking back up to 6 days.
# Diagnostics go to stderr; the chosen image name is the only stdout
# output. Falls back to a known-good pinned tag when no dated image is
# reachable (e.g. registry offline). Returns 1 for unsupported arches.
find_latest_image() {
  local arch=$1
  local tag_prefix offset candidate
  case "${arch}" in
    mi30x) tag_prefix="${MI30X_BASE_TAG}" ;;
    mi35x) tag_prefix="${MI35X_BASE_TAG}" ;;
    *) echo "Error: unsupported GPU architecture '${arch}'" >&2; return 1 ;;
  esac
  for (( offset = 0; offset <= 6; offset++ )); do
    candidate="${tag_prefix}-$(date -d "${offset} days ago" +%Y%m%d)"
    echo "Checking for image: rocm/sgl-dev:${candidate}" >&2
    if docker manifest inspect "rocm/sgl-dev:${candidate}" >/dev/null 2>&1; then
      echo "Found available image: rocm/sgl-dev:${candidate}" >&2
      echo "rocm/sgl-dev:${candidate}"
      return 0
    fi
  done
  echo "Error: no ${arch} image found in the last 7 days for base ${tag_prefix}" >&2
  echo "Using hard-coded fallback…" >&2
  if [[ "${arch}" == "mi35x" ]]; then
    echo "rocm/sgl-dev:v0.5.0rc0-rocm700-mi35x-20250812"
  else
    echo "rocm/sgl-dev:v0.5.0rc0-rocm630-mi30x-20250812"
  fi
}
# Pull and run the latest image
IMAGE=$(find_latest_image "${GPU_ARCH}")
echo "Pulling Docker image: ${IMAGE}"
docker pull "${IMAGE}"
echo "Launching container: ci_sglang"
# NOTE: DEVICE_FLAG is intentionally unquoted — it may expand to several
# `--device ...` arguments that must be word-split.
docker run -dt --user root --device=/dev/kfd ${DEVICE_FLAG} \
  -v "${GITHUB_WORKSPACE:-$PWD}:/sglang-checkout" \
  --ipc=host --group-add video \
  --shm-size 32g \
  --cap-add=SYS_PTRACE \
  -e HF_TOKEN="${HF_TOKEN:-}" \
  --security-opt seccomp=unconfined \
  -w /sglang-checkout \
  --name ci_sglang \
  "${IMAGE}"

68
scripts/ci/ci_install_deepep.sh Executable file
View File

@@ -0,0 +1,68 @@
#!/bin/bash
# Install the dependency in CI.
# Builds the full DeepEP stack: GDRCopy (kernel module + userspace),
# NVSHMEM (IBGDA + GDRCopy enabled), then DeepEP itself.
set -euxo pipefail

# Base CI dependencies first (python packages, sgl-kernel, etc.).
bash scripts/ci/ci_install_dependency.sh

# Paths consumed by the NVSHMEM/DeepEP builds below.
export GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/
export NVSHMEM_DIR=/opt/nvshmem/install
export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
export PATH="${NVSHMEM_DIR}/bin:$PATH"
export CUDA_HOME=/usr/local/cuda

# Skip the (slow) build when deep_ep is already importable.
if python3 -c "import deep_ep" >/dev/null 2>&1; then
  echo "deep_ep is already installed or importable. Skipping installation."
  exit 0
fi

# Install system dependencies
apt install -y curl wget git sudo libibverbs-dev rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 build-essential cmake

# Install GDRCopy
rm -rf /opt/gdrcopy && mkdir -p /opt/gdrcopy
rm -rf /opt/nvshmem && mkdir -p /opt/nvshmem
cd /opt/gdrcopy
git clone https://github.com/NVIDIA/gdrcopy.git .
git checkout v2.4.4
apt update
apt install -y nvidia-dkms-535
apt install -y build-essential devscripts debhelper fakeroot pkg-config dkms
apt install -y check libsubunit0 libsubunit-dev python3-venv
# Build and install the Debian packages (dkms kernel module, lib, tests).
cd packages
CUDA=/usr/local/cuda ./build-deb-packages.sh
dpkg -i gdrdrv-dkms_*.deb
dpkg -i libgdrapi_*.deb
dpkg -i gdrcopy-tests_*.deb
dpkg -i gdrcopy_*.deb

# NVSHMEM links against libmlx5; provide the unversioned symlink if absent.
if [ ! -e "/usr/lib/x86_64-linux-gnu/libmlx5.so" ]; then
  ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so
fi
apt-get update && apt-get install -y libfabric-dev

# Install NVSHMEM
cd /opt/nvshmem
wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.3.9/source/nvshmem_src_cuda12-all-all-3.3.9.tar.gz
tar -xf nvshmem_src_cuda12-all-all-3.3.9.tar.gz
mv nvshmem_src nvshmem && cd nvshmem
# Minimal feature set: DeepEP only needs IBGDA + GDRCopy; target sm_90 only.
NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_MPI_SUPPORT=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90
cd build
make -j$(nproc) install

# Install DeepEP (pinned to a known-good commit)
rm -rf /root/.cache/deepep && git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep && cd /root/.cache/deepep && git checkout b92d0d4860ce6866cd6d31bfbae937f9a7a3772b
cd /root/.cache/deepep && python3 setup.py install

# Verify configuration
echo "=== Verify GDRCOPY ==="
gdrcopy_copybw
echo "=== Verify NVSHMEM ==="
nvshmem-info -a

View File

@@ -0,0 +1,81 @@
#!/bin/bash
# Install the dependency in CI.
# Works on both Blackwell runners (pip + --break-system-packages, cu129)
# and regular runners (uv, cu126).
set -euxo pipefail

IS_BLACKWELL=${IS_BLACKWELL:-0}
# CUDA wheel index suffix for the torch extra index below.
if [ "$IS_BLACKWELL" = "1" ]; then
  CU_VERSION="cu129"
else
  CU_VERSION="cu126"
fi

# Clear torch compilation cache
python3 -c 'import os, shutil, tempfile, getpass; cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") or os.path.join(tempfile.gettempdir(), "torchinductor_" + getpass.getuser()); shutil.rmtree(cache_dir, ignore_errors=True)'

# Kill existing processes
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
bash "${SCRIPT_DIR}/../killall_sglang.sh"

# Install apt packages
apt install -y git libnuma-dev

# Install uv
if [ "$IS_BLACKWELL" = "1" ]; then
  # The blackwell CI runner has some issues with pip and uv,
  # so we can only use pip with `--break-system-packages`
  PIP_CMD="pip"
  PIP_INSTALL_SUFFIX="--break-system-packages"
  # Clean up existing installations
  $PIP_CMD uninstall -y flashinfer_python sgl-kernel sglang vllm $PIP_INSTALL_SUFFIX || true
else
  # In normal cases, we use uv, which is much faster than pip.
  pip install --upgrade pip
  pip install uv
  export UV_SYSTEM_PYTHON=true
  PIP_CMD="uv pip"
  PIP_INSTALL_SUFFIX="--index-strategy unsafe-best-match"
  # Clean up existing installations
  $PIP_CMD uninstall flashinfer_python sgl-kernel sglang vllm || true
fi

# Install the main package
$PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX

# Install router for pd-disagg test
SGLANG_ROUTER_BUILD_NO_RUST=1 $PIP_CMD install -e "sgl-router" $PIP_INSTALL_SUFFIX

# Pinned sgl-kernel wheel: cu128 build for Blackwell, cu124 otherwise.
SGL_KERNEL_VERSION=0.3.9.post2
if [ "$IS_BLACKWELL" = "1" ]; then
  # TODO auto determine sgl-kernel version
  $PIP_CMD install https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu128-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall $PIP_INSTALL_SUFFIX
else
  $PIP_CMD install https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu124-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall $PIP_INSTALL_SUFFIX
fi

# Show current packages
$PIP_CMD list

# Install additional dependencies
$PIP_CMD install mooncake-transfer-engine==0.3.5 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX

if [ "$IS_BLACKWELL" != "1" ]; then
  # For lmms_evals evaluating MMMU
  git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
  $PIP_CMD install -e lmms-eval/ $PIP_INSTALL_SUFFIX
  # Install xformers
  $PIP_CMD install xformers --index-url https://download.pytorch.org/whl/${CU_VERSION} --no-deps $PIP_INSTALL_SUFFIX
fi

# Install FlashMLA for attention backend tests
# $PIP_CMD install git+https://github.com/deepseek-ai/FlashMLA.git $PIP_INSTALL_SUFFIX

# Show current packages
$PIP_CMD list

echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-}"

24
scripts/ci/ci_install_rust.sh Executable file
View File

@@ -0,0 +1,24 @@
#!/bin/bash
# Install the Rust toolchain and native build deps (openssl headers,
# pkg-config, protobuf compiler) needed to build sgl-router in CI.
set -euxo pipefail

# Check if sudo is available (hosted runners); otherwise assume we already
# run as root (containers) and call apt-get directly.
if command -v sudo >/dev/null 2>&1; then
  sudo apt-get update
  sudo apt-get install -y libssl-dev pkg-config protobuf-compiler
else
  apt-get update
  apt-get install -y libssl-dev pkg-config protobuf-compiler
fi

# Install rustup (Rust installer and version manager) non-interactively.
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y

# Load cargo into the current shell. (Fix: the previous version sourced
# this file twice in a row, the second time unquoted — once is enough.)
. "$HOME/.cargo/env"

# Verify installation
rustc --version
cargo --version
protoc --version

View File

@@ -0,0 +1,94 @@
#!/bin/bash
# Launch a 4-prefill / 4-decode disaggregation topology for
# Llama-3.1-8B-Instruct and wait until every server reports healthy.
MODEL_PATH="/raid/models/meta-llama/Llama-3.1-8B-Instruct"

# Print the name of the first mlx5 InfiniBand device whose port state is
# PORT_ACTIVE; warn on stderr and return 1 when none is found.
find_active_ib_device() {
  local dev port_state
  for dev in mlx5_{0..11}; do
    # Skip devices that do not exist on this host.
    ibv_devinfo "$dev" >/dev/null 2>&1 || continue
    port_state=$(ibv_devinfo "$dev" | grep "state:" | head -1 | awk '{print $2}')
    if [[ "$port_state" == "PORT_ACTIVE" ]]; then
      echo "$dev"
      return 0
    fi
  done
  echo "No active IB device found" >&2
  return 1
}
# Get the first available active IB device; abort early if none exists.
# (Fix: the failure status was previously ignored, so all eight servers
# were launched with an empty --disaggregation-ib-device and failed later.)
if ! DEVICE=$(find_active_ib_device); then
  echo "Aborting: no active IB device available" >&2
  exit 1
fi
echo "Using IB device: $DEVICE"

# Launch prefill servers on GPUs 0-3, one per loopback IP/port pair.
for i in {0..3}; do
  PORT=$((30001 + i))
  BOOTSTRAP_PORT=$((9001 + i))
  HOST="127.0.0.$((i + 1))"
  echo "Launching PREFILL server on GPU $i at $HOST:$PORT (bootstrap: $BOOTSTRAP_PORT)"
  CUDA_VISIBLE_DEVICES=$i \
  python3 -m sglang.launch_server \
    --model-path "$MODEL_PATH" \
    --disaggregation-mode prefill \
    --host "$HOST" \
    --port "$PORT" \
    --disaggregation-ib-device "$DEVICE" \
    --disaggregation-bootstrap-port "$BOOTSTRAP_PORT" &
done

# Launch decode servers on GPUs 4-7.
for i in {4..7}; do
  PORT=$((30001 + i))
  HOST="127.0.0.$((i + 1))"
  echo "Launching DECODE server on GPU $i at $HOST:$PORT"
  CUDA_VISIBLE_DEVICES=$i \
  python3 -m sglang.launch_server \
    --model-path "$MODEL_PATH" \
    --disaggregation-mode decode \
    --host "$HOST" \
    --port "$PORT" \
    --disaggregation-ib-device "$DEVICE" \
    --base-gpu-id 0 &
done

# Wait for disaggregation servers to initialize
echo "Waiting for disaggregation servers to initialize..."

# Health check with 5-minute timeout
TIMEOUT=300
START_TIME=$(date +%s)
echo "Checking health of all 8 servers..."
while true; do
  CURRENT_TIME=$(date +%s)
  ELAPSED=$((CURRENT_TIME - START_TIME))
  if [ $ELAPSED -ge $TIMEOUT ]; then
    echo "❌ Timeout: Servers did not become healthy within 5 minutes"
    exit 1
  fi
  HEALTHY_COUNT=0
  # Check all 8 servers (127.0.0.1-8:30001-30008)
  for i in {1..8}; do
    if curl -s -f "http://127.0.0.$i:$((30000 + i))/health" >/dev/null 2>&1; then
      HEALTHY_COUNT=$((HEALTHY_COUNT + 1))
    fi
  done
  echo "Healthy servers: $HEALTHY_COUNT/8 (elapsed: ${ELAPSED}s)"
  if [ $HEALTHY_COUNT -eq 8 ]; then
    echo "✅ All 8 servers are healthy!"
    break
  else
    sleep 10 # Wait 10 seconds before next check
  fi
done

# Don't launch router here - just keep servers running
echo "✅ All disaggregation servers are ready and waiting for router connections"

# Keep the script running so the background servers stay alive.
wait

View File

@@ -0,0 +1,61 @@
#!/bin/bash
# Install SGLang CI dependencies for Ascend NPU runners (aarch64):
# MemFabric, device-neutral vLLM, CPU torch + torch_npu plugin,
# Triton-Ascend, sgl-kernel-npu/DeepEP, then SGLang itself.
set -euo pipefail

PIP_INSTALL="pip install --no-cache-dir"

# Install the required dependencies in CI.
apt update -y && apt install -y \
    build-essential \
    cmake \
    wget \
    curl \
    net-tools \
    zlib1g-dev \
    lld \
    clang \
    locales \
    ccache \
    ca-certificates
update-ca-certificates
python3 -m ${PIP_INSTALL} --upgrade pip

### Download MemFabricV2
MF_WHL_NAME="mf_adapter-1.0.0-cp311-cp311-linux_aarch64.whl"
MEMFABRIC_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/${MF_WHL_NAME}"
wget -O "${MF_WHL_NAME}" "${MEMFABRIC_URL}" && ${PIP_INSTALL} "./${MF_WHL_NAME}"

### Install vLLM
# VLLM_TARGET_DEVICE=empty builds a device-neutral vLLM (no CUDA kernels).
VLLM_TAG=v0.8.5
git clone --depth 1 https://github.com/vllm-project/vllm.git --branch $VLLM_TAG
(cd vllm && VLLM_TARGET_DEVICE="empty" ${PIP_INSTALL} -v -e .)

### Install PyTorch and PTA
# CPU wheels for torch/torchvision; NPU support comes from the torch_npu plugin.
PYTORCH_VERSION=2.6.0
TORCHVISION_VERSION=0.21.0
${PIP_INSTALL} torch==$PYTORCH_VERSION torchvision==$TORCHVISION_VERSION --index-url https://download.pytorch.org/whl/cpu
PTA_VERSION="v7.1.0.1-pytorch2.6.0"
PTA_NAME="torch_npu-2.6.0.post1-cp311-cp311-manylinux_2_28_aarch64.whl"
PTA_URL="https://gitee.com/ascend/pytorch/releases/download/${PTA_VERSION}/${PTA_NAME}"
wget -O "${PTA_NAME}" "${PTA_URL}" && ${PIP_INSTALL} "./${PTA_NAME}"

### Install Triton-Ascend
TRITON_ASCEND_NAME="triton_ascend-3.2.0.dev20250729-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl"
TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/${TRITON_ASCEND_NAME}"
${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil==6.0.0 pytest==8.3.2 pytest-xdist==3.6.1 pyyaml pybind11
wget -O "${TRITON_ASCEND_NAME}" "${TRITON_ASCEND_URL}" && ${PIP_INSTALL} "./${TRITON_ASCEND_NAME}"

### Install sgl-kernel-npu
# Build the deepep wheel, install it, then symlink its C++ extension into the
# installed package location so `import deep_ep` can find it.
SGL_KERNEL_NPU_TAG="20250901"
git clone --depth 1 https://github.com/sgl-project/sgl-kernel-npu.git --branch ${SGL_KERNEL_NPU_TAG}
(cd sgl-kernel-npu && bash ./build.sh -a deepep && pip install output/deep_ep*.whl && cd "$(pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so)

### Install SGLang
${PIP_INSTALL} -v -e "python[srt_npu]"

View File

@@ -0,0 +1,293 @@
"""
Sync code from OSS repo to the local repo and open a PR if changes exist.
NOTE:
1. You need to execute this script in the git root folder.
2. A GH_TOKEN environment variable is required to create the pull request.
- see also https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
This script will:
1. Clone the sgl-project/sglang repository (or use a local copy).
2. Sync specified files and directories using rsync.
3. Check if the sync operation resulted in any changes.
4. If there are changes:
a. Create a new branch.
b. Commit and push the changes.
c. Open a pull request using the GitHub CLI (gh).
Usage:
# Run the full sync and PR creation process
python3 scripts/copy_from_oss.py
# Perform a dry run without making any actual changes
python3 scripts/copy_from_oss.py --dry-run
# Use a local directory as the source instead of cloning
python3 scripts/copy_from_oss.py --local-dir ~/projects/sglang
"""
import argparse
import datetime
import os
import shutil
import subprocess
import tempfile
# --- Configuration Begin ---
# List of folders and files to copy from the OSS repo.
# Changes outside these paths will be ignored.
# NOTE: each entry is passed to rsync as a whole directory (or file), so the
# destination is its parent path inside this repo.
folder_names = [
    "3rdparty",
    "assets",
    "benchmark",
    "docker",
    "docs",
    "examples",
    "sgl-kernel",
    "README.md",
    "python/sglang/lang",
    "python/sglang/srt",
    "python/sglang/test",
    "test/lang",
    "test/srt",
]
# Target private repository (owner/name) for the PR opened by this script.
private_repo = "your-org/sglang-private-repo"
# --- Configuration End ---
def write_github_step_summary(content):
    """Append ``content`` to the GitHub Actions step summary, if available.

    Outside GitHub Actions (GITHUB_STEP_SUMMARY unset or empty) this is a
    no-op.
    """
    summary_path = os.environ.get("GITHUB_STEP_SUMMARY")
    if not summary_path:
        return
    with open(summary_path, "a") as summary_file:
        summary_file.write(content)
def check_dependencies():
    """Ensure the required command-line tools (git and gh) are on PATH.

    Raises:
        EnvironmentError: if either tool cannot be found.
    """
    missing_msgs = {
        "git": "git is not installed or not in PATH.",
        "gh": "GitHub CLI (gh) is not installed or not in PATH.",
    }
    for tool, msg in missing_msgs.items():
        if shutil.which(tool) is None:
            raise EnvironmentError(msg)
    print("✅ All dependencies (git, gh) are available.")
def checkout_main(dry_run):
    """Switch the working tree to a clean ``main`` branch.

    Args:
        dry_run: when True, only print the git commands without running them.
    """
    for cmd in ("git checkout main", "git reset --hard"):
        print(f"Run: {cmd}")
        if dry_run:
            continue
        try:
            subprocess.run(cmd, shell=True, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            print(f"Git command failed: {e.stderr.decode()}")
            raise
    print("✅ Checkout the main branch.")
def get_source_folder(args):
    """
    Prepare the source repository, either by cloning from GitHub or using a
    local directory.

    Returns:
        (oss_root, temp_dir, commit_hash): source repo root, the temporary
        directory created for the clone (None when --local-dir is used), and
        the short 8-character hash of its HEAD commit.

    Raises:
        FileNotFoundError: when --local-dir points at a missing path.
    """
    temp_dir = None
    if args.local_dir:
        oss_root = os.path.expanduser(args.local_dir)
        if not os.path.exists(oss_root):
            raise FileNotFoundError(
                f"Specified local directory {oss_root} does not exist."
            )
        print(f"Using local directory as the source: {oss_root}")
    else:
        temp_dir = tempfile.mkdtemp()
        oss_root = temp_dir
        print(f"Created temporary directory: {oss_root}")
        clone_cmd = [
            "git",
            "clone",
            "--single-branch",
            "--branch",
            "main",
            "https://github.com/sgl-project/sglang.git",
            temp_dir,
        ]
        try:
            subprocess.run(clone_cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            print(f"Error cloning repository: {e.stderr.decode()}")
            raise
        print(f"Successfully cloned repository to {temp_dir}")
    # Short hash identifies the synced snapshot in branch names and PR titles.
    commit_hash = subprocess.run(
        ["git", "-C", oss_root, "rev-parse", "HEAD"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout.strip()[:8]
    print(f"✅ Get source OSS code at commit: {commit_hash}")
    return oss_root, temp_dir, commit_hash
def sync_directories(oss_root, folder_names, dry_run):
    """Mirror each configured folder from ``oss_root`` into the current repo.

    Uses ``rsync -r --delete`` so files removed upstream are removed locally
    as well. In dry-run mode the commands are only printed.
    """
    for folder_name in folder_names:
        source = f"{oss_root}/{folder_name}"
        # Destination is the folder's parent inside the current repo, so rsync
        # recreates the folder itself (e.g. python/sglang/srt -> ./python/sglang).
        destination = "./" + "/".join(folder_name.split("/")[:-1])
        cmd = f"rsync -r --delete {source} {destination}"
        print(f"Run: {cmd}")
        if dry_run:
            continue
        try:
            subprocess.run(cmd, shell=True, check=True)
        except subprocess.CalledProcessError as e:
            print(f"Error executing command '{cmd}': {e}")
            raise
    print(f"✅ Sync all folders.")
def check_for_changes():
    """Return True when the working tree has uncommitted modifications."""
    # `git diff --quiet` exits non-zero when differences exist, 0 when clean.
    return subprocess.run(["git", "diff", "--quiet"]).returncode != 0
def create_and_push_branch(branch_name, commit_message, dry_run):
    """Create ``branch_name``, commit all local changes, and force-push it.

    The committer identity is pinned to the github-actions bot so CI pushes
    are attributable. In dry-run mode the commands are only printed.
    """
    git_steps = (
        f"git checkout -b {branch_name}",
        "git config user.name 'github-actions[bot]'",
        "git config user.email 'github-actions[bot]@users.noreply.github.com'",
        "git add .",
        f"git commit -m '{commit_message}'",
        f"git push origin {branch_name} --force",
    )
    print("\nCreating and pushing git branch...")
    for cmd in git_steps:
        print(f"Run: {cmd}")
        if dry_run:
            continue
        try:
            subprocess.run(cmd, shell=True, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            print(f"Git command failed: {e.stderr.decode()}")
            raise
def create_pull_request(branch_name, title, body, dry_run):
    """Open a PR against ``main`` of the private repo via the GitHub CLI.

    Without GH_TOKEN the PR creation is skipped with a warning; in dry-run
    mode the command is still printed but never executed.
    """
    gh_token = os.getenv("GH_TOKEN")
    if not gh_token:
        print(
            "\n⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation."
        )
        if not dry_run:
            return
    print("\nCreating pull request...")
    command = [
        "gh",
        "pr",
        "create",
        "--base",
        "main",
        "--head",
        branch_name,
        "--repo",
        private_repo,
        "--title",
        title,
        "--body",
        body,
    ]
    print(f"Run: {' '.join(command)}")
    if dry_run:
        return
    env = os.environ.copy()
    env["GH_TOKEN"] = gh_token
    try:
        result = subprocess.run(
            command, check=True, capture_output=True, text=True, env=env
        )
        msg = f"✅ Successfully created pull request: {result.stdout.strip()}"
        print(msg)
        write_github_step_summary(msg)
    except subprocess.CalledProcessError as e:
        print(f"Error creating pull request: {e.stderr}")
        raise
def main():
    """Entry point: sync OSS code into this repo and open a PR if it differs."""
    parser = argparse.ArgumentParser(
        description="Copy code from OSS and open a PR if changes are detected."
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        help="Path to local SGLang directory to use instead of cloning from GitHub.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Dry run the script without executing git, rsync, or gh commands.",
    )
    args = parser.parse_args()

    check_dependencies()
    checkout_main(args.dry_run)
    oss_root, temp_dir, oss_commit = get_source_folder(args)
    try:
        sync_directories(oss_root, folder_names, args.dry_run)
        # Nothing to do if the sync left the working tree untouched.
        if not check_for_changes():
            msg = "😴 No changes detected. The code is already in sync."
            print(msg)
            write_github_step_summary(msg)
            return
        print("✅ Changes detected. Proceeding to create a PR.")
        today = datetime.datetime.now().strftime("%Y%m%d")
        branch_name = f"copy-from-oss-{oss_commit}-{today}"
        commit_message = f"Copy OSS code from {oss_commit} on {today}"
        pr_title = (
            f"[Automated PR] Copy OSS code from commit {oss_commit} on {today}"
        )
        pr_body = (
            f"Copy OSS code from https://github.com/sgl-project/sglang/commit/{oss_commit} on {today}."
            "\n\n---\n\n"
            "*This is an automated PR created by scripts/copy_from_oss.py.*"
        )
        create_and_push_branch(branch_name, commit_message, args.dry_run)
        create_pull_request(branch_name, pr_title, pr_body, args.dry_run)
    finally:
        # Remove the clone's temporary directory, if one was created.
        if temp_dir:
            try:
                shutil.rmtree(temp_dir)
                print(f"\nRemoved temporary directory: {temp_dir}")
            except OSError as e:
                print(f"Error removing temporary directory {temp_dir}: {e}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,425 @@
"""
Sync a specific commit from the local private repo to the OSS upstream and open a PR.
NOTE:
1. You need to execute this script in the git root folder.
2. A GH_TOKEN environment variable is required to create the pull request.
- see also https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
This script will:
1. Take a commit hash as an argument (or use the latest commit by default).
2. Create a patch for that commit.
3. Filter the patch to only include changes in specified directories.
4. Clone the sgl-project/sglang repository.
5. Create a new branch in the OSS repo.
6. Apply the filtered patch, commit, and force push.
7. Open a pull request to the OSS repo using the GitHub CLI (gh).
Usage:
# Sync the latest commit from the current branch
python3 scripts/copy_to_oss.py
# Run the full sync and PR creation process for a given commit
python3 scripts/copy_to_oss.py --commit <commit_hash>
# Perform a dry run without making any actual changes
python3 scripts/copy_to_oss.py --commit <commit_hash> --dry-run
"""
import argparse
import datetime
import os
import shutil
import subprocess
import tempfile
# --- Configuration Begin ---
# List of folders and files to copy to the OSS repo.
# Changes outside these paths will be ignored.
# NOTE: entries are matched as path *prefixes* when filtering a commit's
# changed files (see create_filtered_patch).
folder_names = [
    "3rdparty",
    "assets",
    "benchmark",
    "docker",
    "docs",
    "examples",
    "sgl-kernel",
    "README.md",
    "python/sglang/lang",
    "python/sglang/srt",
    "python/sglang/test",
    "test/lang",
    "test/srt",
]
# --- Configuration End ---
def write_github_step_summary(content):
    """Append ``content`` to the GitHub Actions step summary file, if set."""
    target = os.environ.get("GITHUB_STEP_SUMMARY")
    if not target:
        # Not running under GitHub Actions; silently skip.
        return
    with open(target, "a") as fh:
        fh.write(content)
def get_commit_info(commit_ref):
    """
    Retrieve the hash and full message of a commit.

    Args:
        commit_ref: commit hash, tag, or branch to inspect (e.g. 'HEAD').

    Returns:
        (commit_hash, commit_message) on success, or (None, None) when git is
        missing, the ref does not exist, or we are outside a repository.
    """
    # %H = full hash, %B = raw body; a NUL byte (%x00) separates them safely,
    # since commit messages may contain any other character.
    log_cmd = ["git", "log", "-1", f"--pretty=%H%x00%B", commit_ref]
    try:
        proc = subprocess.run(
            log_cmd, capture_output=True, text=True, check=True, encoding="utf-8"
        )
        sha, message = proc.stdout.strip().split("\x00", 1)
        return sha, message
    except FileNotFoundError:
        print("❌ Error: 'git' command not found. Is Git installed and in your PATH?")
    except subprocess.CalledProcessError as e:
        print(f"❌ Error getting commit info for '{commit_ref}': {e.stderr.strip()}")
        print(
            "Hint: Make sure you are running this from within a Git repository and the commit exists."
        )
    return None, None
def check_dependencies():
    """Raise EnvironmentError unless git and the GitHub CLI are on PATH."""
    required = (
        ("git", "git is not installed or not in PATH."),
        ("gh", "GitHub CLI (gh) is not installed or not in PATH."),
    )
    for tool, error_msg in required:
        if shutil.which(tool) is None:
            raise EnvironmentError(error_msg)
    print("✅ All dependencies (git, gh) are available.")
def create_filtered_patch(commit_hash, dry_run):
    """
    Create a patch file for the given commit, containing only changes
    to files and directories specified in `folder_names`.

    Returns:
        (patch_path, relevant_files): path of the temporary .patch file and
        the list of files it covers, or (None, None) when the commit touches
        nothing under the configured folders.
    """
    # NOTE(review): `dry_run` is currently unused here — the patch file is
    # always written; callers only skip *applying* it. Confirm intended.
    print(f"Creating a filtered patch for commit {commit_hash}")
    try:
        # Get the list of all files changed in the commit
        changed_files_raw = subprocess.run(
            ["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash],
            capture_output=True,
            text=True,
            check=True,
        ).stdout
        changed_files = changed_files_raw.strip().split("\n")
        # Keep only files under the synced folders (path-prefix match).
        relevant_files = [
            f for f in changed_files if any(f.startswith(path) for path in folder_names)
        ]
        if not relevant_files:
            msg = "\n😴 No relevant file changes found in this commit. Exiting."
            print(msg)
            write_github_step_summary(msg)
            return None, None
        print("Found relevant changes in the following files:")
        for f in relevant_files:
            print(f" - {f}")
        # Create a patch containing only the changes for the relevant files
        # ("--" separates the revision range from the path filters).
        patch_command = [
            "git",
            "format-patch",
            "--stdout",
            f"{commit_hash}^..{commit_hash}",
            "--",
        ] + relevant_files
        print(f"Run: {' '.join(patch_command)}")
        patch_content = subprocess.run(
            patch_command, capture_output=True, text=True, check=True
        ).stdout
        # Save the patch to a temporary file; delete=False because it is
        # consumed later by `git apply` inside the OSS clone.
        patch_file = tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".patch", encoding="utf-8"
        )
        patch_file.write(patch_content)
        patch_file.close()
        print(f"✅ Filtered patch created successfully at: {patch_file.name}")
        return patch_file.name, relevant_files
    except subprocess.CalledProcessError as e:
        print(f"Error creating patch: {e.stderr}")
        raise
def get_oss_repo(dry_run):
    """
    Clones the OSS repository into a temporary directory.
    Returns the path to the repo root and the temp directory itself.
    """
    gh_token = os.getenv("GH_TOKEN")
    if not gh_token:
        print("⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation.")
        if not dry_run:
            # NOTE(review): this bare `return` yields None instead of the
            # documented (oss_root, temp_dir) tuple; any caller that unpacks
            # the result will raise TypeError. Confirm intended behavior.
            return
    temp_dir = tempfile.mkdtemp()
    oss_root = os.path.join(temp_dir, "sglang")
    print(f"\nCreated temporary directory for OSS repo: {temp_dir}")
    # Token embedded in the clone URL so the later force-push is authenticated.
    repo_url = f"https://{gh_token}@github.com/sgl-project/sglang.git"
    command = ["git", "clone", "--branch", "main", repo_url, oss_root]
    print(f"Run: {' '.join(command)}")
    if not dry_run:
        try:
            subprocess.run(command, check=True, capture_output=True)
            print(f"✅ Successfully cloned repository to {oss_root}")
        except subprocess.CalledProcessError as e:
            print(f"Error cloning repository: {e.stderr.decode()}")
            # Nothing useful remains in the temp dir after a failed clone.
            shutil.rmtree(temp_dir)
            raise
    return oss_root, temp_dir
def apply_patch_and_push(oss_root, patch_file, branch_name, commit_message, dry_run):
    """
    Inside the OSS clone: create ``branch_name``, apply ``patch_file``,
    commit with ``commit_message``, and force-push the branch.

    In dry-run mode every command is printed, but nothing is executed and
    the working directory is never changed.
    """
    print("\nApplying patch and pushing to OSS repo...")
    original_cwd = os.getcwd()
    if not dry_run:
        os.chdir(oss_root)
    try:
        # Argument lists (not shell strings) avoid quoting/injection issues.
        setup_cmds = [
            ["git", "checkout", "-b", branch_name],
            ["git", "apply", patch_file],
            ["git", "config", "user.name", "github-actions[bot]"],
            [
                "git",
                "config",
                "user.email",
                "github-actions[bot]@users.noreply.github.com",
            ],
            ["git", "add", "."],
        ]
        for cmd in setup_cmds:
            print(f"Run: {' '.join(cmd)}")
            if not dry_run:
                subprocess.run(cmd, check=True, capture_output=True, text=True)
        # The commit message may span multiple lines, so feed it via stdin.
        commit_cmd = ["git", "commit", "-F", "-"]
        print(f"Run: {' '.join(commit_cmd)}")
        if not dry_run:
            print(f"Commit Message:\n---\n{commit_message}\n---")
            subprocess.run(
                commit_cmd,
                input=commit_message,
                text=True,
                check=True,
                capture_output=True,
            )
        # Force-push so re-runs for the same branch simply replace it.
        push_cmd = ["git", "push", "origin", branch_name, "--force"]
        print(f"Run: {' '.join(push_cmd)}")
        if not dry_run:
            subprocess.run(push_cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Git command failed: {e.stderr}")
        raise
    finally:
        if not dry_run:
            os.chdir(original_cwd)
    print("✅ Branch created, patch applied, and pushed successfully.")
def create_pull_request(oss_root, branch_name, title, body, dry_run):
    """Create a pull request in the OSS repo using the GitHub CLI.

    Args:
        oss_root: Working directory for the ``gh`` invocation (the OSS clone).
        branch_name: Head branch of the PR (must already be pushed).
        title: PR title.
        body: PR body (markdown).
        dry_run: If True, only print the command without executing it.
    """
    gh_token = os.getenv("GH_TOKEN")
    if not gh_token:
        print("⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation.")
        # NOTE(review): only a real run bails out here; a dry run falls
        # through so the would-be command is still printed — confirm intended.
        if not dry_run:
            return
    print("\nCreating pull request...")
    command = [
        "gh",
        "pr",
        "create",
        "--base",
        "main",
        "--head",
        branch_name,
        "--repo",
        "sgl-project/sglang",
        "--title",
        title,
        "--body",
        body,
    ]
    print(f"Run: {' '.join(command)}")
    if not dry_run:
        # Pass the token explicitly so gh authenticates in CI.
        env = os.environ.copy()
        env["GH_TOKEN"] = gh_token
        try:
            result = subprocess.run(
                command,
                check=True,
                capture_output=True,
                text=True,
                env=env,
                cwd=oss_root,
            )
            msg = f"✅ Successfully created pull request: {result.stdout.strip()}"
            print(msg)
            write_github_step_summary(msg)
        except subprocess.CalledProcessError as e:
            print(f"Error creating pull request: {e.stderr}")
            # Check if a PR already exists
            if "A pull request for" in e.stderr and "already exists" in e.stderr:
                print(" A PR for this branch likely already exists.")
            else:
                raise
def get_commit_author(commit_hash):
    """Return ``(author_name, author_email)`` for the given commit.

    Uses a single ``git show`` invocation with a NUL (%x00) separator so both
    fields come from the same process — one spawn instead of two, and no
    window for the repo to change between the two reads.

    Raises:
        subprocess.CalledProcessError: if ``git show`` fails (e.g. bad hash).
    """
    try:
        out = subprocess.run(
            ["git", "show", "-s", "--format=%an%x00%ae", commit_hash],
            capture_output=True,
            text=True,
            check=True,
        ).stdout.strip()
        # partition never raises; a missing NUL yields an empty email.
        author_name, _, author_email = out.partition("\x00")
        return author_name, author_email
    except subprocess.CalledProcessError as e:
        print(f"Error getting commit author for {commit_hash}: {e.stderr}")
        raise
def main():
    """CLI entry point: sync one commit from the private repo to the OSS repo.

    Pipeline: build a filtered patch from the chosen commit, clone the OSS
    repo, apply the patch on a new branch, push it, and open a PR. Temporary
    files are always cleaned up, even on failure.
    """
    parser = argparse.ArgumentParser(
        description="Copy a commit from the private repo to OSS and open a PR."
    )
    parser.add_argument(
        "--commit",
        type=str,
        default="LAST",
        help="The commit hash to sync. Defaults to 'LAST' to use the latest commit.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Dry run the script without executing git, rsync, or gh commands.",
    )
    args = parser.parse_args()
    check_dependencies()
    commit_ref = "HEAD" if args.commit == "LAST" else args.commit
    commit_hash, original_commit_message = get_commit_info(commit_ref)
    if not commit_hash:
        return  # Exit if we couldn't get commit info
    # Display the details of the commit being processed
    if args.commit == "LAST":
        summary = (
            f"\n No commit specified. Using the last commit:\n"
            f" - **Hash:** `{commit_hash}`\n"
            f" - **Message:** {original_commit_message}\n\n"
        )
    else:
        summary = (
            f"\n Using specified commit:\n"
            f" - **Hash:** `{commit_hash}`\n"
            f" - **Message:** {original_commit_message}\n\n"
        )
    print(summary)
    write_github_step_summary(summary)
    short_hash = commit_hash[:8]
    # Pre-declare so the finally block can clean up partial progress.
    patch_file = None
    temp_dir = None
    try:
        # 1. Create a filtered patch from the local repo
        patch_file, relevant_files = create_filtered_patch(commit_hash, args.dry_run)
        if not patch_file:
            return
        # 2. Get the OSS repo
        oss_root, temp_dir = get_oss_repo(args.dry_run)
        # 3. Get original commit author for the co-author line
        author_name, author_email = get_commit_author(commit_hash)
        # 4. Prepare content for the commit and PR based on changed files
        file_list_str = "\n".join([f"- {f}" for f in relevant_files])
        filename_list_str = ", ".join([f.split("/")[-1] for f in relevant_files])
        # Truncate long file lists so the PR title stays readable.
        if len(filename_list_str) > 40:
            filename_list_str = filename_list_str[:40] + "..."
        current_date = datetime.datetime.now().strftime("%Y%m%d")
        pr_title = f"[Auto Sync] Update {filename_list_str} ({current_date})"
        pr_body = (
            f"Sync changes from commit `{short_hash}`.\n\n"
            f"**Relevant Files Changed:**\n{file_list_str}"
            "\n\n---\n\n"
            "*This is an automated PR created by a script.*"
        )
        # 5. Create branch, apply patch, and push
        branch_name = f"sync-{short_hash}-{current_date}"
        co_author_line = f"Co-authored-by: {author_name} <{author_email}>"
        commit_message = f"{pr_title}\n\n{co_author_line}"
        apply_patch_and_push(
            oss_root, patch_file, branch_name, commit_message, args.dry_run
        )
        # 6. Create Pull Request
        create_pull_request(oss_root, branch_name, pr_title, pr_body, args.dry_run)
    finally:
        # Cleanup temporary files
        if patch_file and os.path.exists(patch_file):
            os.remove(patch_file)
            print(f"\nRemoved temporary patch file: {patch_file}")
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
            print(f"Removed temporary directory: {temp_dir}")
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,27 @@
### Sync Code Between OSS and Private Fork
You can use the following principles and tools to sync the code between a private fork and the OSS repo [sgl-project/sglang](https://github.com/sgl-project/sglang/tree/main).
It learns from [Copybara](https://github.com/google/copybara), a tool used at Google for maintaining open-source code synchronization.
## Principles
- The core folders (e.g., `python/sglang/srt`) are 100% mirrored between the private fork and OSS repo.
- The OSS repo is the single source of truth. If one commit changes `python/sglang/srt` in the private repo, the change should be synced to the OSS repo as soon as possible with the action B below.
- The common code (e.g., base classes, well-known techniques in the industry without private secrets) goes to `python/sglang/srt`. The private-specific code (e.g., with private-specific features, confidential info) goes to `python/sglang/private` .
- Anytime you want to make private changes to a file or class under `python/sglang/srt`, duplicate the file and move it under `python/sglang/private`. You can achieve code reuse by importing and inheriting.
## How to sync the code bidirectionally
### Action A: Copy code from OSS to private
- We can run this action: [Open A PR to Copy Code From OSS](https://github.com/sgl-project/sglang/tree/main/.github/workflows/open-pr-copy-from-oss.yml)
- It opens a PR to copy all files under certain folders (e.g., `python/sglang/srt` , `test/srt` , `sgl-kernel` ) from the OSS main branch to the private fork.
- Since the OSS repo is the single source of truth, this action copies files and overwrites any changes in the private fork. To prevent the private changes from being overwritten, you need to ensure all private changes are merged into the OSS repo before running this action.
- This action will be run automatically every day and can also be triggered manually.
### Action B: Copy diff from private to OSS
- We can run this action: [Open A PR to Copy Code To OSS](https://github.com/sgl-project/sglang/tree/main/.github/workflows/open-pr-copy-to-oss.yml)
- It opens a PR to apply the diff of one specific commit of the private fork to the OSS main branch. It will only pick the changes under certain folders (e.g., `python/sglang/srt` , `test/srt` , `sgl-kernel` ) and ignore changes under private folders (e.g., `python/sglang/private` )
- For example, you can have a PR that changes both `python/sglang/srt` and `python/sglang/private/srt`. Once you merge the PR into the private repo, `python/sglang/srt` becomes desynced between the two repos. You need to run this action on your merge commit immediately to open a PR to send your diff to the OSS repo. Then, we need to merge the OSS PR as soon as possible. Once your OSS PR is merged, we can run action A again.
- Action A copies files directly, while Action B applies a diff. This is because the OSS repo is the source of truth: Action A can simply overwrite the private fork's files with the OSS versions, whereas Action B must bring private changes into OSS without discarding anything else, so it applies only the diff of the selected commit.
- This action currently needs a manual trigger in order to prevent incidental code leaks. One can also consider making it automatic.

View File

@@ -0,0 +1,18 @@
#!/bin/bash
# Install the GitHub CLI (gh) from the official apt repository, but only if
# it is not already present. Requires privileges to write /etc/apt and run apt.
# Check if gh is installed before attempting to install it
if ! command -v gh &> /dev/null
then
    echo "GitHub CLI not found. Installing now..."
    (type -p wget >/dev/null || ( apt update && apt install wget -y)) \
    && mkdir -p -m 755 /etc/apt/keyrings \
    && out=$(mktemp) \
    && wget -nv -O "$out" https://cli.github.com/packages/githubcli-archive-keyring.gpg \
    && tee /etc/apt/keyrings/githubcli-archive-keyring.gpg < "$out" > /dev/null \
    && chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
    && mkdir -p -m 755 /etc/apt/sources.list.d \
    && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
    && apt update \
    && apt install gh -y
else
    echo "GitHub CLI is already installed. Skipping installation."
fi

View File

@@ -0,0 +1,38 @@
"""
Convert Yi-VL config into a format usable with SGLang
Usage: python3 scripts/convert_yi_vl.py --model-path <path-to-model>
"""
import argparse
import json
import os
from transformers import AutoConfig, AutoTokenizer
def add_image_token(model_path: str):
    """Register the ``<image_placeholder>`` special token in the tokenizer
    stored at ``model_path`` and save the updated tokenizer back in place."""
    tok = AutoTokenizer.from_pretrained(model_path)
    placeholder_tokens = ["<image_placeholder>"]
    tok.add_tokens(placeholder_tokens, special_tokens=True)
    print(tok)
    tok.save_pretrained(model_path)
def edit_model_config(model_path):
    """Rewrite the model config so SGLang loads it as ``YiVLForCausalLM``
    with image token id 64002, saving the config back to ``model_path``."""
    cfg = AutoConfig.from_pretrained(model_path)
    cfg.architectures = ["YiVLForCausalLM"]
    cfg.image_token_index = 64002
    print(cfg)
    cfg.save_pretrained(model_path)
if __name__ == "__main__":
    # CLI entry point: patch both the tokenizer and the config in place.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str)
    args = parser.parse_args()
    add_image_token(args.model_path)
    edit_model_config(args.model_path)

View File

@@ -0,0 +1,13 @@
# Download the Yi-VL checkpoints and patch them for SGLang.
# Fixes: `mkdir` -> `mkdir -p` (the second mkdir always failed on an existing
# directory), quoted $HOME paths, and `cd` failures no longer fall through.

# For 34B Model
mkdir -p "$HOME/model_weights"
cd "$HOME/model_weights" || exit 1
git clone https://huggingface.co/01-ai/Yi-VL-34B
cp "$HOME/model_weights/Yi-VL-34B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-34B-448/preprocessor_config.json" "$HOME/model_weights/Yi-VL-34B"
python3 convert_yi_vl.py --model-path "$HOME/model_weights/Yi-VL-34B"

# For 6B Model
mkdir -p "$HOME/model_weights"
cd "$HOME/model_weights" || exit 1
git clone https://huggingface.co/01-ai/Yi-VL-6B
cp "$HOME/model_weights/Yi-VL-6B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-6B-448/preprocessor_config.json" "$HOME/model_weights/Yi-VL-6B"
python3 convert_yi_vl.py --model-path "$HOME/model_weights/Yi-VL-6B"

View File

@@ -0,0 +1,9 @@
# Smoke-test the SGLang server's /generate endpoint: greedy (temperature 0)
# completion of up to 64 new tokens. Assumes a server on localhost:30000.
curl http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "text": "Once upon a time,",
    "sampling_params": {
      "max_new_tokens": 64,
      "temperature": 0
    }
  }'

View File

@@ -0,0 +1,217 @@
import pytest
import torch
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
from sglang.srt.layers.attention.triton_ops.extend_attention import (
extend_attention_fwd,
redundant_attention,
)
from sglang.srt.utils import should_use_tensor_core
# Global FlashInfer wrapper handles, (re)built by init_flashinfer().
flashinfer_prefill_wrapper = None
flashinfer_decode_wrapper = None
@pytest.mark.parametrize("batch_size", [12, 37, 67])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [32, 4])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_prefill_with_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Cross-check the Triton extend-attention kernel against the reference
    redundant_attention and FlashInfer's paged prefill wrapper.

    Each request has kv_len cached tokens, of which the last qo_len are the
    "extend" (newly prefilled) tokens. Requires a CUDA device (tensors are
    placed on device 0).
    """
    init_flashinfer(num_qo_heads, num_kv_heads)
    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
    qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
    total_tokens = kv_len * batch_size
    # kv_data layout: [token, 2 (k/v), head, dim]; page size is 1 token.
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
    # init args for triton kernel
    # The last qo_len tokens of each sequence form the K/V of the extend part.
    k_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 0]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    v_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 1]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    o_triton = torch.empty_like(q)
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    b_start_loc_extend = torch.arange(0, batch_size).to(0).int() * qo_len
    b_seq_len_extend = torch.full((batch_size,), qo_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    max_len_extend = qo_len
    extend_attention_fwd(
        q,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        None,  # b_start_loc = None
        b_seq_len,
        None,  # b_seq_len_prefix = None
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )
    # Reference implementation for the same extend computation.
    o_redundant = torch.empty_like(q)
    b_start_loc = torch.zeros((batch_size,), dtype=torch.int32).to(0)
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0)
    b_seq_len_prefix = b_seq_len - b_seq_len_extend
    redundant_attention(
        q,
        k_extend,
        v_extend,
        o_redundant,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        max_len_in_batch,
    )
    print("Mean: ", torch.mean(torch.abs(o_redundant - o_triton)))
    print("Max: ", torch.max(torch.abs(o_redundant - o_triton)))
    assert torch.allclose(o_redundant, o_triton, rtol=1e-2, atol=1e-3)
    # Cross-check against FlashInfer's paged prefill wrapper.
    flashinfer_prefill_wrapper.end_forward()
    flashinfer_prefill_wrapper.begin_forward(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
    )
    o = flashinfer_prefill_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )
    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3)
@pytest.mark.parametrize("batch_size", [12, 17, 37])
@pytest.mark.parametrize("kv_len", [54, 127, 537])
@pytest.mark.parametrize("num_kv_heads", [32])
@pytest.mark.parametrize("num_qo_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_decode_with_paged_kv_cache(
    batch_size,
    kv_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Cross-check the Triton decode-attention kernel against FlashInfer's
    paged decode wrapper (one new query token per request).

    Requires a CUDA device (tensors are placed on device 0).
    """
    # note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache
    # to test different shape of decode, change the parameters in the __main__, and run decode only once
    init_flashinfer(num_qo_heads, num_kv_heads)
    q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
    total_tokens = kv_len * batch_size
    # kv_data layout: [token, 2 (k/v), head, dim]; page size is 1 token.
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
    # init args for triton kernel
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    o_triton = torch.empty_like(q)
    req_to_token = (
        torch.arange(0, kv_len * batch_size).to(0).int().view(batch_size, kv_len)
    )
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_start_loc = torch.arange(0, batch_size).to(0).int() * kv_len
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    other_kv_index = 0
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o_triton,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
        other_kv_index,
        total_tokens,
    )
    # Cross-check against FlashInfer's paged decode wrapper.
    flashinfer_decode_wrapper.end_forward()
    flashinfer_decode_wrapper.begin_forward(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
        pos_encoding_mode="NONE",
        data_type="float16",
    )
    o = flashinfer_decode_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )
    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
def init_flashinfer(num_attention_heads, num_kv_heads):
    """Allocate a shared workspace buffer and (re)build the module-level
    FlashInfer prefill/decode wrappers for the given head configuration."""
    global flashinfer_prefill_wrapper, flashinfer_decode_wrapper
    tensor_cores_enabled = should_use_tensor_core(
        torch.half, num_attention_heads, num_kv_heads
    )
    # 128 MB scratch space shared by both wrappers.
    scratch = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
    flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(scratch, "NHD")
    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        scratch, "NHD", use_tensor_cores=tensor_cores_enabled
    )
if __name__ == "__main__":
    # Manual entry point for shapes pytest cannot sweep (the Triton kernel
    # caches the head configuration; see note in the decode test).
    test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128)
    test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128)
    test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)

View File

@@ -0,0 +1,56 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.
The capital of the United Kindom is London.\nThe capital of the United Kingdom is London.\nThe capital of
"""
import argparse
import asyncio
import json
import time
import aiohttp
import requests
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after ``delay`` seconds and return
    the parsed JSON response body."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as http, http.post(url, json=data) as reply:
        return await reply.json()
async def main(args):
    """Fire two /generate requests concurrently; the first is delayed by 1s
    so the second reaches the server first, exercising request interleaving."""
    url = f"{args.host}:{args.port}"
    task1 = send_request(
        url + "/generate",
        {
            "text": "The capital of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 128},
        },
        delay=1,
    )
    task2 = send_request(
        url + "/generate",
        {
            "text": "The capital of the United Kindom is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 128},
        },
    )
    rets = await asyncio.gather(task1, task2)
    print(rets)
if __name__ == "__main__":
    # CLI entry point: target server defaults to http://127.0.0.1:30000.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    asyncio.run(main(args))

View File

@@ -0,0 +1,55 @@
"""
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode.py
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import json
import requests
def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
    """Issue one non-streaming /generate request and dump the JSON reply.

    Greedy decoding when n == 1; temperature 0.5 when sampling n > 1
    completions so they differ.
    """
    payload = {
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0 if n == 1 else 0.5,
            "max_new_tokens": 32,
            "n": n,
        },
        "stream": False,
        "return_logprob": return_logprob,
        "top_logprobs_num": top_logprobs_num,
        "return_text_in_logprobs": return_text,
        "logprob_start_len": 0,
    }
    response = requests.post(url + "/generate", json=payload)
    print(json.dumps(response.json()))
    print("=" * 100)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    url = f"{args.host}:{args.port}"
    # Basic single- and multi-sample decode.
    test_decode(url)
    test_decode(url, n=3)
    # Sweep logprob options.
    for top_logprobs_num in [0, 3]:
        for return_text in [True, False]:
            test_decode(
                url,
                return_logprob=True,
                top_logprobs_num=top_logprobs_num,
                return_text=return_text,
            )

View File

@@ -0,0 +1,68 @@
"""
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode_stream.py
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import json
import requests
def test_decode_stream(url, return_logprob, top_logprobs_num):
    """Stream a /generate response (SSE-style "data:" lines) and print either
    per-token logprob rows or the incremental text delta.

    Args:
        url: Server base URL, e.g. ``http://127.0.0.1:30000``.
        return_logprob: If True, print (logprob, token_id, token_text) rows.
        top_logprobs_num: Number of top logprobs requested per position.
    """
    response = requests.post(
        url + "/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
            },
            "stream": True,
            "return_logprob": return_logprob,
            "top_logprobs_num": top_logprobs_num,
            "return_text_in_logprobs": True,
            "logprob_start_len": 0,
        },
        stream=True,
    )
    # Chunks carry the cumulative output; `prev` tracks what was printed.
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            if return_logprob:
                assert data["meta_info"]["input_token_logprobs"] is not None
                assert data["meta_info"]["output_token_logprobs"] is not None
                for logprob, token_id, token_text in data["meta_info"][
                    "output_token_logprobs"
                ][prev:]:
                    print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
                prev = len(data["meta_info"]["output_token_logprobs"])
            else:
                output = data["text"].strip()
                print(output[prev:], end="", flush=True)
                prev = len(output)
    print("=" * 100)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    url = f"{args.host}:{args.port}"
    # Exercise plain streaming, then streaming with logprobs (0 and 3 top-k).
    test_decode_stream(url, False, 0)
    test_decode_stream(url, True, 0)
    test_decode_stream(url, True, 3)

View File

@@ -0,0 +1,88 @@
"""
Usage:
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
python3 test_httpserver_llava.py
Output:
The image features a man standing on the back of a yellow taxi cab, holding
"""
import argparse
import asyncio
import json
import aiohttp
import requests
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after ``delay`` seconds and return
    the parsed JSON response body."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as http, http.post(url, json=data) as reply:
        return await reply.json()
async def test_concurrent(args):
    """Send 8 identical image+text /generate requests concurrently and print
    each completion."""
    url = f"{args.host}:{args.port}"
    response = []
    for i in range(8):
        response.append(
            send_request(
                url + "/generate",
                {
                    "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
                    "image_data": "example_image.png",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 64,
                    },
                },
            )
        )
    rets = await asyncio.gather(*response)
    for ret in rets:
        print(ret["text"])
def test_streaming(args):
    """Stream one image+text /generate request and print the text delta of
    each SSE "data:" chunk as it arrives."""
    url = f"{args.host}:{args.port}"
    response = requests.post(
        url + "/generate",
        json={
            "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
            "image_data": "example_image.png",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
            },
            "stream": True,
        },
        stream=True,
    )
    # Chunks carry the cumulative output; `prev` tracks what was printed.
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print("")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    # Run the concurrency test first, then the streaming test.
    asyncio.run(test_concurrent(args))
    test_streaming(args)

View File

@@ -0,0 +1,42 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
response = requests.post(
url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
},
)
print(response.json())
response = requests.post(
url + "/generate",
json={
"text": "The capital of France is Paris.\nThe capital of the United States is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
},
)
print(response.json())

View File

@@ -0,0 +1,138 @@
import argparse
from enum import Enum
from pydantic import BaseModel, constr
import sglang as sgl
from sglang.srt.constrained.outlines_backend import build_regex_from_object
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
# Regex for a dotted-quad IPv4 address (each octet 0-255).
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
# Answer template: fixed text with two IP slots and a domain slot, letting
# constrained decoding "jump forward" through the fixed portions.
ip_jump_forward = (
    r"The google's DNS sever address is "
    + IP_REGEX
    + r" and "
    + IP_REGEX
    + r". "
    + r"The google's website domain name is "
    + r"www\.(\w)+\.(\w)+"
    + r"."
)
# fmt: off
@sgl.function
def regex_gen(s):
    """Generate an answer constrained to the ip_jump_forward template."""
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        max_tokens=128,
        temperature=0,
        regex=ip_jump_forward,
    )
# fmt: on
# JSON-shaped regex template for a Hogwarts info object; keys and structure
# are fixed, values are constrained character classes.
json_jump_forward = (
    r"""The information about Hogwarts is in the following JSON format\.\n"""
    + r"""\n\{\n"""
    + r""" "name": "[\w\d\s]*",\n"""
    + r""" "country": "[\w\d\s]*",\n"""
    + r""" "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
    + r""" "population": [-+]?[0-9]+,\n"""
    + r""" "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
    + r"""\}\n"""
)
# fmt: off
@sgl.function
def json_gen(s):
    """Generate a JSON object constrained to the json_jump_forward template."""
    s += sgl.gen(
        "json",
        max_tokens=128,
        temperature=0,
        regex=json_jump_forward,
    )
# fmt: on
class Weapon(str, Enum):
    """Allowed weapon values for a generated Character (string-valued enum)."""
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"
class Armor(str, Enum):
    """Allowed armor values for a generated Character (string-valued enum)."""
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"
class Character(BaseModel):
    """Schema for a generated character; build_regex_from_object turns this
    pydantic model into a constrained-decoding regex."""
    name: constr(max_length=10)  # at most 10 characters
    age: int
    armor: Armor
    weapon: Weapon
    strength: int
@sgl.function
def character_gen(s):
    """Generate a character description constrained to the Character schema."""
    s += "Give me a character description who is a wizard.\n"
    s += sgl.gen(
        "character",
        max_tokens=128,
        temperature=0,
        regex=build_regex_from_object(Character),
    )
def main(args):
    """Run the three constrained-generation smoke tests (IP regex, JSON
    template, pydantic-derived character schema) on the selected backend."""
    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)
    state = regex_gen.run(temperature=0)
    print("=" * 20, "IP TEST", "=" * 20)
    print(state.text())
    state = json_gen.run(temperature=0)
    print("=" * 20, "JSON TEST", "=" * 20)
    print(state.text())
    state = character_gen.run(temperature=0)
    print("=" * 20, "CHARACTER TEST", "=" * 20)
    print(state.text())
if __name__ == "__main__":
    # Parse the shared SGLang test-harness args (backend, host, port, ...).
    parser = argparse.ArgumentParser()
    args = add_common_sglang_args_and_parse(parser)
    main(args)
# ==================== IP TEST ====================
# Q: What is the IP address of the Google DNS servers?
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
# ==================== JSON TEST ====================
# The information about Hogwarts is in the following JSON format.
# {
# "name": "Hogwarts School of Witchcraft and Wizardry",
# "country": "Scotland",
# "latitude": 55.566667,
# "population": 1000,
# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
# }
# ==================== CHARACTER TEST ====================
# Give me a character description who is a wizard.
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }

View File

@@ -0,0 +1,132 @@
import argparse
import random
import string
from vllm.transformers_utils.tokenizer import get_tokenizer
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
# Populated from CLI args in __main__ before main() runs.
TOKENIZER = None
RANDOM_PREFILL_LEN = None
RANDOM_DECODE_LEN = None
def gen_prompt(token_num):
    """Generate a random alphanumeric string that tokenizes to at least
    ``token_num`` tokens (padded one character at a time until long enough).

    If RANDOM_PREFILL_LEN is truthy, the target length is first resampled
    uniformly from [1, token_num].
    """
    if RANDOM_PREFILL_LEN:
        token_num = random.randint(1, token_num)
    cha_set = string.ascii_letters + string.digits
    ret = "".join(random.choices(cha_set, k=token_num))
    while len(TOKENIZER(ret).input_ids) < token_num:
        ret += random.choice(cha_set)
    return ret
def robust_test_dfs(s, d, args, leaf_states):
    """Depth-first fork/extend stress test.

    At each level: append a random prompt, fork ``args.num_fork`` children,
    extend each child with another prompt plus a generation, then recurse
    until depth ``d`` reaches 0. Terminal states are tagged with "END" and
    collected in ``leaf_states``.
    """
    # Base case: mark the leaf and record it.
    if d == 0:
        s += "END"
        leaf_states.append(s)
        return
    s += gen_prompt(args.len_prefill)
    forks = s.fork(args.num_fork)
    for fork_s in forks:
        fork_s += gen_prompt(args.len_prefill)
        new_tokens = (
            args.len_decode
            if not RANDOM_DECODE_LEN
            else random.randint(1, args.len_decode)
        )
        fork_s += sgl.gen(
            max_tokens=new_tokens,
            ignore_eos=True,
        )
    for fork_s in forks:
        robust_test_dfs(fork_s, d - 1, args, leaf_states)
def robust_test_bfs(s, args, leaf_states):
    """Breadth-first variant of the fork/extend stress test.

    Expands one whole level at a time: each surviving state gets a random
    prompt, is forked ``args.num_fork`` times, and every fork is extended
    with a prompt plus a generation. After ``args.depth`` levels, every
    remaining state is tagged with "END" and collected in ``leaf_states``.
    """
    old_forks = [s]
    new_forks = []
    for _ in range(args.depth):
        for old_fork in old_forks:
            old_fork += gen_prompt(args.len_prefill)
            forks = old_fork.fork(args.num_fork)
            for fork_s in forks:
                fork_s += gen_prompt(args.len_prefill)
                new_tokens = (
                    args.len_decode
                    if not RANDOM_DECODE_LEN
                    else random.randint(1, args.len_decode)
                )
                fork_s += sgl.gen(
                    max_tokens=new_tokens,
                    ignore_eos=True,
                )
            new_forks.extend(forks)
        # Advance to the next level.
        old_forks = new_forks
        new_forks = []
    for old_fork in old_forks:
        old_fork += "END"
        leaf_states.append(old_fork)
@sgl.function
def robust_test(s, args):
    """Entry sgl.function: dispatch to the BFS or DFS traversal and return
    the list of leaf states (each ending with the literal "END")."""
    leaf_states = []
    if args.mode == "bfs":
        robust_test_bfs(s, args, leaf_states)
    else:
        robust_test_dfs(s, args.depth, args, leaf_states)
    return leaf_states
def main(args):
    """Run ``args.num_req`` traversals in parallel and dump every leaf text
    (with the trailing END marker stripped) to ``tmp_robust_<mode>.txt``."""
    backend = select_sglang_backend(args)
    arguments = [{"args": args} for _ in range(args.num_req)]
    states = robust_test.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel
    )
    with open(f"tmp_robust_{args.mode}.txt", "w") as f:
        for state in states:
            leaf_states = state.ret_value
            for leaf_state in leaf_states:
                # Every leaf must terminate with the END sentinel.
                assert leaf_state.text()[-3:] == "END"
                f.write(leaf_state.text()[:-3] + "\n")
if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-req", type=int, default=2)
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--num-fork", type=int, default=2)
    parser.add_argument("--len-prefill", type=int, default=128)
    parser.add_argument("--len-decode", type=int, default=128)
    parser.add_argument("--random-prefill-len", action="store_true")
    parser.add_argument("--random-decode-len", action="store_true")
    parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
    parser.add_argument("--tokenizer", type=str, default = "meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = add_common_sglang_args_and_parse(parser)
    # fmt: on
    # Publish CLI choices to the module-level globals read by gen_prompt().
    RANDOM_PREFILL_LEN = args.random_prefill_len
    RANDOM_DECODE_LEN = args.random_decode_len
    TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
    random.seed(args.seed)
    main(args)

View File

@@ -0,0 +1,115 @@
"""
Export NextN layer for DeepSeek-V3/R1 model. The exported model can be used for speculative decoding.
Usage:
python3 export_deepseek_nextn.py --input-dir /path/to/DeepSeek-V3 --output-dir /path/to/DeepSeek-V3-NextN
"""
import argparse
import json
import os
import shutil
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import AutoConfig
def get_nextn_layer_id(config):
    """Return the layer index of the NextN layer, which sits directly after
    the last hidden layer, i.e. at index ``num_hidden_layers``."""
    if hasattr(config, "num_hidden_layers"):
        return config.num_hidden_layers
    raise ValueError("'num_hidden_layers' not found in model config.")
def update_and_save_config(config, output_dir):
    """Serialize ``config`` with the NextN overrides (a single hidden layer
    and the DeepseekV3ForCausalLMNextN architecture) to
    ``<output_dir>/config.json``, keys sorted, UTF-8 preserved."""
    cfg = dict(config.to_dict())
    cfg["num_hidden_layers"] = 1
    cfg["architectures"] = ["DeepseekV3ForCausalLMNextN"]
    target = os.path.join(output_dir, "config.json")
    with open(target, "w") as f:
        json.dump(cfg, f, indent=2, ensure_ascii=False, sort_keys=True)
def copy_non_safetensors_files(input_dir, output_dir):
    """Copy every regular file except ``*.safetensors`` from ``input_dir``
    to ``output_dir`` (tokenizer files, configs, ...), preserving metadata."""
    for entry in os.listdir(input_dir):
        src = os.path.join(input_dir, entry)
        # Skip weight shards and anything that is not a regular file.
        if entry.endswith(".safetensors") or not os.path.isfile(src):
            continue
        shutil.copy2(src, os.path.join(output_dir, entry))
    print(f"All non-safetensors files have been copied to {output_dir}")
def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
    """Extract the NextN layer weights into a standalone safetensors file.

    Scans every ``*.safetensors`` shard in ``input_dir`` for tensors under
    ``model.layers.<nextn_layer_id>.``, renames them to ``model.layers.0.``
    (the exported model has a single layer), and writes them plus a matching
    ``model.safetensors.index.json`` into ``output_dir``.

    Args:
        input_dir: source HF checkpoint directory with safetensors shards.
        output_dir: destination directory (must already exist).
        nextn_layer_id: index of the NextN layer (== num_hidden_layers).
    """
    # Match the trailing dot so e.g. layer 6 never matches "model.layers.61.".
    prefix = f"model.layers.{nextn_layer_id}"
    output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors")
    params = {}
    for filename in os.listdir(input_dir):
        if not filename.endswith(".safetensors"):
            continue
        file_path = os.path.join(input_dir, filename)
        print(f"Processing: {file_path}")
        try:
            with safe_open(file_path, framework="pt") as f:
                matching_keys = [k for k in f.keys() if k.startswith(prefix + ".")]
                if not matching_keys:
                    print(f"  No parameters starting with '{prefix}' found")
                    continue
                for key in matching_keys:
                    # Embedding and LM-head weights are shared with the base
                    # model and are not exported.
                    if "embed_tokens" in key or "shared_head.head" in key:
                        continue
                    new_key = key.replace(prefix, "model.layers.0")
                    params[new_key] = f.get_tensor(key)
        except Exception as e:
            print(f"  Error processing {file_path}: {str(e)}")
    if params:
        print(f"Saving {len(params)} parameters to {output_path}")
        save_file(params, output_path)
    else:
        print("No matching parameters found.")

    # Update safetensors index so loaders can locate every exported tensor.
    index_path = os.path.join(output_dir, "model.safetensors.index.json")
    print(f"Updating safetensors index to {index_path}")
    index_data = {"weight_map": {}}
    for key in params:
        index_data["weight_map"][key] = "nextn_layer_parameters.safetensors"
    with open(index_path, "w") as f:
        json.dump(index_data, f, indent=4)
    print("All done.")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Export NextN layer parameters for DeepSeek-V3/R1"
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="Input HF model directory.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output nextn model directory.",
    )
    args = parser.parse_args()

    config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True)
    # Only checkpoints that declare exactly one MTP/NextN layer are supported.
    assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported."
    nextn_layer_id = get_nextn_layer_id(config)
    os.makedirs(args.output_dir, exist_ok=True)
    # Carry over tokenizer/config files, rewrite config.json for a 1-layer
    # model, then extract the NextN layer weights.
    copy_non_safetensors_files(args.input_dir, args.output_dir)
    update_and_save_config(config, args.output_dir)
    export_nextn_layer_parameters(args.input_dir, args.output_dir, nextn_layer_id)

32
scripts/killall_sglang.sh Executable file
View File

@@ -0,0 +1,32 @@
#!/bin/bash
# Kill all SGLang server/benchmark processes.
# Usage: killall_sglang.sh [rocm|<any-arg>]
#   rocm      - only kill SGLang processes (no nvidia-smi available)
#   <any-arg> - additionally kill every process holding an NVIDIA GPU

if [ "$1" = "rocm" ]; then
  echo "Running in ROCm mode"

  # Clean SGLang processes
  pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9
else
  # Show current GPU status
  nvidia-smi

  # Clean SGLang processes
  pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9

  # Clean all GPU processes if any argument is provided
  if [ $# -gt 0 ]; then
    # Check if sudo is available; lsof is needed to find /dev/nvidia* holders.
    if command -v sudo >/dev/null 2>&1; then
      sudo apt-get update
      sudo apt-get install -y lsof
    else
      apt-get update
      apt-get install -y lsof
    fi
    # Kill every PID listed in the nvidia-smi process table.
    # xargs -r avoids running `kill` with no arguments when the table is empty.
    nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}' | xargs -r kill -9 2>/dev/null
    # Skip the lsof header row (NR>1) so the literal string "PID" is never
    # passed to kill; de-duplicate PIDs before killing.
    lsof /dev/nvidia* 2>/dev/null | awk 'NR>1 {print $2}' | sort -u | xargs -r kill -9 2>/dev/null
  fi

  # Show GPU status after clean up
  nvidia-smi
fi

View File

@@ -0,0 +1,308 @@
"""
Usage:
# single GPU
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B
# multiple GPU
python3 bench_speculative.py --model-path deepseek-ai/DeepSeek-V3 --speculative-draft-model-path lmsys/DeepSeek-V3-NextN --tp-size 8 --trust-remote-code --batch-size 1 4 8 16 32 --steps 0 1 2 --topk 0 1 2 4 --num_draft_tokens 0 2 4 8
"""
import argparse
import asyncio
import json
import os
import time
from types import SimpleNamespace
import numpy as np
import requests
from transformers import AutoTokenizer
from sglang.bench_serving import (
DatasetRow,
benchmark,
sample_mmmu_requests,
set_global_args,
)
from sglang.srt.server_args import ServerArgs
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
kill_process_tree,
popen_launch_server,
)
def node0_print(msg):
    """Print only on node rank 0 to avoid duplicated logs in multi-node runs."""
    if server_args.node_rank != 0:
        return
    print(msg)
# Fixed chat-style prompts used as the benchmark workload; they are cycled /
# padded to reach the requested request count. The text is runtime data and
# is kept verbatim (including original typos).
prompts = [
    "Human: Give me a fully functional FastAPI server. Show the full, long python code without stop.\n\nAssistant:",
    "Human: Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.\n\nAssistant:",
    "Human: Write a travel blog post to Hawaii.\n\nAssistant:",
    "Human: I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. My first sentence is 'istanbulu cok seviyom burada olmak cok guzel'. Answer in more than 5000 words.\n\nAssistant:",
    "Human: I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if its children then you can talk about animals; If its adults then history-based tales might engage them better etc. Answer in more than 5000 words. My first request is 'I need an interesting story on perseverance.'\n\nAssistant:",
    "Human: Solve x^2 = -1. Think step-by-step. Give me a long detailed explanation. \n\nAssistant:",
    "Human: Tell me about the president of the USA in wikipedia style.\n\nAssistant:",
    "Human: Hello? Who are you? Write code, math, and poem to explanin yourself.\n\nAssistant:",
]
class FakeTokenizer:
    """Stub tokenizer that always reports zero tokens.

    Used when exact token counts are irrelevant to the benchmark.
    """

    def encode(self, text: str, add_special_tokens: bool = False):
        # No tokenization is performed; every input maps to an empty id list.
        return []
def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
    """Send one benchmark batch to a running server and collect statistics.

    Args:
        base_url: server base URL, e.g. "http://127.0.0.1:20000".
        num_prompts: number of requests to send.
        batch_size: max concurrency; also the key into the server's
            per-batch-size step-time dict.
        tokenizer: tokenizer passed through to the benchmark harness.
        is_multimodal: if True, sample MMMU image requests and hit the
            OpenAI-compatible chat endpoint instead of /generate.

    Returns:
        Tuple of (accept_length, step_time_seconds, decode_speed_tokens_per_s,
        avg_output_tokens_per_request).
    """
    # format: (prompt, input_len, output len). We set input_len as a dummy value 0.
    if is_multimodal:
        input_requests = sample_mmmu_requests(
            num_prompts,
            tokenizer,
            512,
            apply_chat_template=False,
        )
        backend = "sglang-oai-chat"
        api_url = f"{base_url}/v1/chat/completions"
    else:
        # Cycle the fixed prompt list until num_prompts requests are built.
        padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
            :num_prompts
        ]
        # NOTE: the original annotation `List[DatasetRow]` raised NameError at
        # runtime because `typing.List` was never imported; use the builtin
        # generic instead.
        input_requests: list[DatasetRow] = [
            DatasetRow(p, 0, 512) for p in padded_prompts
        ]
        backend = "sglang"
        api_url = f"{base_url}/generate"

    # We need to set some dummy values in order to call `benchmark` below.
    args = SimpleNamespace(
        disable_ignore_eos=False,
        disable_stream=False,
        return_logprob=False,
        backend=backend,
        dataset_name="custom",
        num_prompts=None,
        sharegpt_output_len=None,
        random_input_len=None,
        random_output_len=None,
        random_range_ratio=None,
        output_file=None,
        warmup_requests=1,
        output_details=False,
    )
    set_global_args(args)

    # Run benchmark
    results = asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            base_url=base_url,
            model_id="default",
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=float("inf"),
            max_concurrency=batch_size,
            disable_tqdm=False,
            lora_names=None,
            extra_request_body={},
            profile=None,
        )
    )
    assert results["completed"] == len(input_requests)
    # accept_length is None for non-speculative runs; treat that as 1.0.
    acc_length = results["accept_length"] or 1.0
    avg_output_token = results["total_output_tokens"] / results["completed"]

    server_info = requests.get(base_url + "/get_server_info").json()
    # We use 20% percentile instead of median on purpose
    step_time = np.percentile(
        server_info["internal_states"][0]["step_time_dict"][str(batch_size)], 20
    )
    speed = 1 / step_time * acc_length
    return (
        round(acc_length, 3),
        round(step_time, 5),
        round(speed, 3),
        avg_output_token,
    )
def main(args, server_args):
    """Grid-search speculative-decoding settings and record decode speed.

    For every (batch_size, steps, topk, num_draft_tokens) combination, launch
    a fresh server, run a warmup batch plus a benchmark batch, append one
    JSON record per config to ``args.output``, then shut the server down.
    """
    base_url = "http://127.0.0.1:20000"
    configs = []
    for batch_size in args.batch_size:
        for steps in args.steps:
            for topk in args.topk:
                for num_draft_tokens in args.num_draft_tokens:
                    # A draft tree with `steps` levels of width `topk` can hold
                    # at most steps * topk + 1 tokens; skip infeasible combos.
                    if steps * topk + 1 < num_draft_tokens:
                        continue
                    if (steps == 0 or topk == 0 or num_draft_tokens == 0) and (
                        steps + topk + num_draft_tokens != 0
                    ):
                        # steps == 0 and topk == 0 and num_draft_tokens == 0 is a special case for non-speculative decoding.
                        continue
                    configs.append((batch_size, steps, topk, num_draft_tokens))

    # --start/--end allow resuming a partially-finished sweep.
    for i in range(args.start, args.end or len(configs)):
        batch_size, steps, topk, num_draft_tokens = configs[i]
        node0_print(
            f"Start {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}"
        )

        # Create an LLM.
        if steps == 0:
            # All-zero config: run without speculative decoding as a baseline.
            other_args = []
        else:
            other_args = [
                "--speculative-num-steps",
                steps,
                "--speculative-eagle-topk",
                topk,
                "--speculative-num-draft-tokens",
                num_draft_tokens,
            ]
            if server_args.speculative_draft_model_path is not None:
                other_args.extend(
                    [
                        "--speculative-draft-model-path",
                        server_args.speculative_draft_model_path,
                        "--speculative-algorithm",
                        server_args.speculative_algorithm,
                    ]
                )
        other_args.extend(
            [
                "--cuda-graph-max-bs",
                batch_size,
                "--mem-fraction-static",
                server_args.mem_fraction_static,
                "--tp-size",
                server_args.tp_size,
                "--max-running-requests",
                batch_size,
            ]
        )
        if server_args.trust_remote_code:
            other_args.extend(
                [
                    "--trust-remote-code",
                ]
            )
        if server_args.attention_backend:
            other_args.extend(
                [
                    "--attention-backend",
                    server_args.attention_backend,
                ]
            )
        if server_args.quantization:
            other_args.extend(
                [
                    "--quantization",
                    server_args.quantization,
                ]
            )
        # SGLANG_RECORD_STEP_TIME makes the server expose per-batch-size step
        # times via /get_server_info, which send_one_batch reads.
        process = popen_launch_server(
            args.model_path,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
            env={
                "SGLANG_RECORD_STEP_TIME": "1",
                **os.environ,
            },
        )
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_path, trust_remote_code=server_args.trust_remote_code
        )

        try:
            # Warmup
            send_one_batch(
                base_url, batch_size, batch_size, tokenizer, args.is_multimodal
            )

            # Benchmark
            acc_length, step_time, speed, completion_tokens = send_one_batch(
                base_url,
                max(args.num_prompts, batch_size),
                batch_size,
                tokenizer,
                args.is_multimodal,
            )
        finally:
            # Always tear the server down, even if the benchmark crashed.
            kill_process_tree(process.pid)

        node0_print(
            f"Finish {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}, {speed=:.2f} token/s, step_time={step_time * 1000:.2f} ms"
        )

        # One JSON line per config; appending keeps earlier sweep results.
        record = {
            "batch_size": batch_size,
            "steps": steps,
            "topk": topk,
            "num_draft_tokens": num_draft_tokens,
            "acc_length": acc_length,
            "step_time": step_time,
            "speed": speed,
            "completion_tokens": completion_tokens,
        }
        with open(args.output, "a") as fout:
            fout.write(json.dumps(record) + "\n")

        # Wait for the server to shutdown
        time.sleep(5)
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Reuse the full server CLI so any launch flag can be forwarded.
    ServerArgs.add_cli_args(parser)
    parser.add_argument(
        "--batch-size",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16),
    )
    parser.add_argument(
        "--steps",
        type=int,
        nargs="+",
        default=(0, 1, 3, 5, 7),  # use (0, 1, 2, 3, 4) for large batch size
    )
    parser.add_argument(
        "--topk",
        type=int,
        nargs="+",
        default=(0, 1, 2, 4, 8),
    )
    parser.add_argument(
        "--num_draft_tokens",
        type=int,
        nargs="+",
        default=(0, 2, 4, 8, 16, 32),  # use (0, 2, 4, 8) for large batch size
    )
    parser.add_argument("--num-prompts", type=int, default=16)
    # --start/--end select a slice of the config grid (resume support).
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int)
    parser.add_argument("--output", type=str, default="output.jsonl")
    parser.add_argument("--is-multimodal", action="store_true", default=False)
    args = parser.parse_args()
    server_args: ServerArgs = ServerArgs.from_cli_args(args)
    main(args, server_args)

View File

@@ -0,0 +1,22 @@
# Query a locally running SGLang server (/generate) and print the number of
# input/output token logprobs returned in meta_info.
prompt = "The capital of france is "

import json

import requests

response = requests.post(
    "http://0.0.0.0:8000/generate",
    json={
        "text": prompt,
        "sampling_params": {"temperature": 0},
        # Request per-token logprobs for both the prompt and the completion,
        # starting from the first prompt token.
        "return_logprob": True,
        "return_input_logprob": True,
        "logprob_start_len": 0,
    },
)
j = response.json()
input_logprobs = j["meta_info"]["input_token_logprobs"]
output_logprobs = j["meta_info"]["output_token_logprobs"]
print(len(input_logprobs), len(output_logprobs))

View File

@@ -0,0 +1,34 @@
# Demo of constrained generation: ask a local SGLang server to emit JSON that
# conforms to a schema via the `json_schema` sampling parameter.
import json

import requests

port = 8000

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

# JSON
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "Here is the information of the capital of France in the JSON format.\n",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "json_schema": json_schema,
        },
    },
)
print(response.json())

# Example server launch command (prefill-side disaggregation + EAGLE):
# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100

View File

@@ -0,0 +1,29 @@
# Send one long free-form prompt to a local SGLang server and print the
# generated text. The prompt body is runtime data and is kept verbatim.
import json

import requests

prompt = """
According to CNBC's Faber, the investors present on the call interpreted this statement as an indication of an upcoming funding round. While speculative, Faber believes the funding round could be as large as $25 billion, and bestow a valuation of between $150 billion and $200 billion on xAI.
For the benefit of those who might not be aware, xAI recently acquired the social media platform X in an all-stock deal that valued the former at $80 billion and the latter at $33 billion, inclusive of $12 billion in liabilities. This meant that the deal bestowed a gross valuation of $45 billion on X before factoring in its debt load of $12 billion.
Bear in mind that Elon Musk took X (then called Twitter) private back in 2022 in a $44 billion deal. Since then, Musk has managed to stem X's cash bleed, with the company reportedly generating $1.2 billion in adjusted EBITDA in 2024.
According to the investors present on the call, xAI is currently generating around $1 billion in annual revenue. This contrasts sharply with the erstwhile muted expectations of many investors, who did not expect the startup to generate any material revenue this year.
Elsewhere, Faber also alludes to the fact that xAI is already working on its next big training supercluster, officially dubbed the Colossus 2, which is expected to eventually house as many as 1 million NVIDIA GPUs at a cost of between $35 billion and $40 billion.
Even though xAI's Grok LLM is already largely comparable with OpenAI's cutting-edge models, the Colossus 2 would significantly up the ante, and could feasibly challenge OpenAI's apex position in the AI sphere.
Give your honest take on the above text:
"""

response = requests.post(
    "http://0.0.0.0:8000/generate",
    # Greedy decoding for a reproducible completion.
    json={"text": prompt, "sampling_params": {"temperature": 0}},
)
response_json = response.json()
print(response_json["text"])

View File

@@ -0,0 +1,241 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Launch A Server\n",
"\n",
"Launch the server with a reasoning model (Qwen3-4B) and a reasoning parser."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sglang import separate_reasoning, assistant_begin, assistant_end\n",
"from sglang import assistant, function, gen, system, user\n",
"from sglang import image\n",
"from sglang import RuntimeEndpoint, set_default_backend\n",
"from sglang.srt.utils import load_image\n",
"from sglang.test.test_utils import is_in_ci\n",
"from sglang.utils import print_highlight, terminate_process, wait_for_server\n",
"\n",
"\n",
"if is_in_ci():\n",
" from patch import launch_server_cmd\n",
"else:\n",
" from sglang.utils import launch_server_cmd\n",
"\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen3-4B --reasoning-parser qwen3 --host 0.0.0.0\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")\n",
"print(f\"Server started on http://localhost:{port}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the default backend. Note: you can set chat_template_name in RuntimeEndpoint. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"set_default_backend(\n",
" RuntimeEndpoint(f\"http://localhost:{port}\", chat_template_name=\"qwen\")\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start with a basic question-answering task. And see how the reasoning content is generated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def basic_qa(s, question):\n",
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
" s += user(question)\n",
" s += assistant_begin()\n",
" s += gen(\"answer\", max_tokens=512)\n",
" s += assistant_end()\n",
"\n",
"\n",
"state = basic_qa(\"List 3 countries and their capitals.\")\n",
"print_highlight(state[\"answer\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `separate_reasoning`, you can move the reasoning content to `{param_name}_reasoning_content` in the state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def basic_qa_separate_reasoning(s, question):\n",
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
" s += user(question)\n",
" s += assistant_begin()\n",
" s += separate_reasoning(gen(\"answer\", max_tokens=512), model_type=\"qwen3\")\n",
" s += assistant_end()\n",
"\n",
"\n",
"reasoning_state = basic_qa_separate_reasoning(\"List 3 countries and their capitals.\")\n",
"print_highlight(reasoning_state.stream_executor.variable_event.keys())\n",
"print_highlight(\n",
" f\"\\nSeparated Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
")\n",
"\n",
"print_highlight(f\"\\n\\nContent:\\n{reasoning_state['answer']}\")\n",
"print_highlight(f\"\\n\\nMessages:\\n{reasoning_state.messages()[-1]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`separate_reasoning` can also be used in multi-turn conversations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def multi_turn_qa(s):\n",
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
" s += user(\"Please give me a list of 3 countries and their capitals.\")\n",
" s += assistant(\n",
" separate_reasoning(gen(\"first_answer\", max_tokens=512), model_type=\"qwen3\")\n",
" )\n",
" s += user(\"Please give me another list of 3 countries and their capitals.\")\n",
" s += assistant(\n",
" separate_reasoning(gen(\"second_answer\", max_tokens=512), model_type=\"qwen3\")\n",
" )\n",
" return s\n",
"\n",
"\n",
"reasoning_state = multi_turn_qa()\n",
"print_highlight(f\"\\n\\nfirst_answer:\\n{reasoning_state['first_answer']}\")\n",
"print_highlight(\n",
" f\"\\n\\nfirst_answer_reasoning_content:\\n{reasoning_state['first_answer_reasoning_content']}\"\n",
")\n",
"print_highlight(f\"\\n\\nsecond_answer:\\n{reasoning_state['second_answer']}\")\n",
"print_highlight(\n",
" f\"\\n\\nsecond_answer_reasoning_content:\\n{reasoning_state['second_answer_reasoning_content']}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using No thinking as Qwen 3's advanced feature \n",
"\n",
"sglang separate_reasoning is particularly useful when combined with Qwen 3's advanced feature.\n",
"\n",
"[Qwen 3's advanced usages](https://qwenlm.github.io/blog/qwen3/#advanced-usages)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reasoning_state = basic_qa_separate_reasoning(\n",
" \"List 3 countries and their capitals. /no_think\"\n",
")\n",
"print_highlight(f\"Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\")\n",
"print_highlight(f\"Content:\\n{reasoning_state['answer']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`separate_reasoning` can also be used in regular expression generation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def regular_expression_gen(s):\n",
" s += user(\n",
" \"What is the IP address of the Google DNS servers? just provide the answer\"\n",
" )\n",
" s += assistant(\n",
" separate_reasoning(\n",
" gen(\n",
" \"answer\",\n",
" temperature=0,\n",
" regex=r\"((25[0-5]|2[0-4]\\d|[01]?\\d\\d?).){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\",\n",
" max_tokens=512,\n",
" ),\n",
" model_type=\"qwen3\",\n",
" ),\n",
" )\n",
"\n",
"\n",
"reasoning_state = regular_expression_gen()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_highlight(f\"Answer:\\n{reasoning_state['answer']}\")\n",
"print_highlight(\n",
" f\"\\n\\nReasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
")"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,7 @@
# Launch HuggingFace TGI serving Llama-2-7b-chat-hf (e.g. as a benchmark baseline).
# Assuming the model is downloaded at /home/ubuntu/model_weights/Llama-2-7b-chat-hf
docker run --name tgi --rm -ti --gpus all --network host \
  -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
  ghcr.io/huggingface/text-generation-inference:1.1.0 \
  --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
  --max-input-length 2048 --max-total-tokens 4096 \
  --port 24000

View File

@@ -0,0 +1,14 @@
# Interactive tokenizer playground: load a tokenizer by name, then drop into a
# Python REPL with the tokenizer bound to the local variable `t`.
import argparse
import code

from sglang.srt.hf_transformers_utils import get_tokenizer

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
    )
    args = parser.parse_args()

    t = get_tokenizer(args.name)
    code.interact(local=locals())

View File

@@ -0,0 +1,36 @@
# Long-context smoke test: stream completions for 64k–1M-token prompts against
# a local OpenAI-compatible SGLang endpoint. Each case is best-effort: prompt
# download or completion failures are reported and the loop continues.
from urllib.request import urlopen

from openai import OpenAI

test_cases = {
    "64k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt",
    "200k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt",
    "600k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
    "1m": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt",
}

client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

for name, url in test_cases.items():
    print(f"\n==== Running test case: {name} ====")
    try:
        with urlopen(url, timeout=10) as response:
            prompt = response.read().decode("utf-8")
    except Exception as e:
        print(f"Failed to load prompt for {name}: {e}")
        continue

    try:
        response = client.chat.completions.create(
            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            messages=[{"role": "user", "content": prompt}],
            stream=True,
            max_tokens=128,
            temperature=0,
        )
        # Stream tokens to stdout as they arrive.
        for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content is not None:
                print(chunk.choices[0].delta.content, end="", flush=True)
    except Exception as e:
        print(f"\nError during completion for {name}: {e}")

View File

@@ -0,0 +1,77 @@
"""Summarize per-step durations and sizes from profiler trace dumps.

Scans every ``trace*`` directory under ``dirpath`` for ``(prefill|decode)_step_N.json``
trace files, extracts the duration of the top-level profile event from each,
pairs it with the matching ``size_N.json``, and writes per-trace and total
summaries to ``analyzed_log*`` files.
"""

import glob
import json
import os
import re
import sys

from tqdm import tqdm

sys.path.append("../../")
from fix_corrupted_json import clean_json_file

dirpath = "/Users/ying"
output_file_prefix = "analyzed_log"

time = {}
tot_time = {}
size = {}

# Remove summaries from previous runs so the append-mode writes start fresh.
os.system(f"rm {output_file_prefix}*")

for dirname in glob.glob(os.path.join(dirpath, "trace*")):
    print(dirname)
    trace_name = dirname.split("/")[-1]
    time[trace_name] = {}
    size[trace_name] = {}
    total_time = 0
    for filename in tqdm(glob.glob(os.path.join(dirname, "*.json"))):
        # NOTE(review): prefill and decode files with the same step number map
        # to the same step_name key ("step_N") — confirm this aliasing is
        # intended before relying on the per-step numbers.
        step_name = filename.split("/")[-1].split(".")[0]
        step_name = "_".join(step_name.split("_")[1:])
        if "prefill" not in filename and "decode" not in filename:
            continue
        match = re.search(r"(prefill|decode)_step_(\d+)\.json", filename)
        if match:
            phase = match.group(1)
            step = match.group(2)
        else:
            # Was garbled to "(unknown)" in the original; include the filename.
            raise Exception(f"Cannot parse {filename}")
        try:
            with open(filename, "r") as f:
                trace = json.load(f)
        except Exception:
            # Some dumps are truncated mid-write; repair in place and retry.
            clean_json_file(filename, filename)
            with open(filename, "r") as f:
                trace = json.load(f)
        # Reset per file so a trace without a profile event contributes 0
        # instead of reusing the previous file's duration.
        dur = 0.0
        for event in trace["traceEvents"]:
            name = event["name"]
            if name in ["profile_prefill_step", "profile_decode_step"]:
                dur = event["dur"] / 1e3  # microseconds -> milliseconds
                time[trace_name][step_name] = dur
                break
        total_time += dur
        step = int(step_name.split("_")[-1])
        with open(os.path.join(dirname, f"size_{step}.json"), "r") as f:
            size_info = json.load(f)
        size[trace_name][step_name] = size_info["size"]
    tot_time[trace_name] = total_time

    # Sort steps numerically and emit "<step> <duration> <size>" lines.
    time[trace_name] = dict(
        sorted(time[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
    )
    size[trace_name] = dict(
        sorted(size[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
    )
    with open(f"{output_file_prefix}_{trace_name}", "a") as f:
        for k, v in time[trace_name].items():
            size_v = size[trace_name][k]
            print(f"{k:>15}{v:10.2f}\t{size_v}")
            f.write(f"{k:>15}{v:10.2f}\t{size_v}\n")

with open(f"{output_file_prefix}_total_time", "w") as f:
    print(tot_time)
    json.dump(tot_time, f)

View File

@@ -0,0 +1,62 @@
# Compare greedy generations of a base HF model vs. the same model with a
# LoRA adapter applied via PEFT, using one fixed prompt.
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# ADAPTER = "winddude/wizardLM-LlaMA-LoRA-7B"
ADAPTER = "/home/ying/test_lora"
HF_TOKEN = "..."

prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""

tokenizer = LlamaTokenizer.from_pretrained(MODEL)
base_model = LlamaForCausalLM.from_pretrained(
    MODEL,
    device_map="auto",
    # load_in_8bit=True,
    torch_dtype=torch.float16,
    # use_auth_token=HF_TOKEN,
).cuda()

# base model generate
with torch.no_grad():
    output_tensors = base_model.generate(
        input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
        max_new_tokens=32,
        do_sample=False,
    )[0]
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
print("======= base output ========")
print(output)

# peft model generate
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER,
    torch_dtype=torch.float16,
    is_trainable=False,
)
with torch.no_grad():
    output_tensors = model.generate(
        input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
        max_new_tokens=32,
        do_sample=False,
    )[0]
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
print("======= peft output ========")
print(output)

View File

@@ -0,0 +1,30 @@
# Reference run: generate with vLLM + a LoRA adapter on the same prompt used
# by the PEFT comparison script, for cross-framework output comparison.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER = "/home/ying/test_lora"

prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""

llm = LLM(model=MODEL, enable_lora=True)

# Greedy decoding for reproducibility.
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=32,
)

prompts = [prompt]
outputs = llm.generate(
    prompts, sampling_params, lora_request=LoRARequest("test_lora", 1, ADAPTER)
)
print(outputs[0].prompt)
print(outputs[0].outputs[0].text)

View File

@@ -0,0 +1,197 @@
"""
Usage: python3 scripts/playground/reference_hf.py --model-path MODEL_PATH --model-type {text,vlm} [--max-new-tokens NUM] [--dtype DTYPE]
--model-path MODEL_PATH: Path to model (default: TinyLlama/TinyLlama-1.1B-Chat-v0.4)
--model-type {text,vlm}: Model type, text or vlm (default: text)
--max-new-tokens NUM: Max new tokens to generate (default: 16)
--dtype DTYPE: Data type for computation (default: float16)
Note: '--model' is deprecated; use '--model-path'. Runs normal_text() for text, vlm_text_with_image() for vlm.
Reference output:
========== Prompt 0 ==========
prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141],
device='cuda:0')
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
========== Prompt 1 ==========
prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742],
device='cuda:0')
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of
========== Prompt 2 ==========
prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609],
device='cuda:0')
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the
"""
import argparse
import requests
import torch
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoProcessor,
)
from sglang.srt.hf_transformers_utils import get_tokenizer
@torch.no_grad()
def vlm_text_with_image(args):
    """Run a VLM reference generation over a fixed list of image URLs.

    For each image: apply the processor's chat template, generate greedily,
    and print both the final-position prefill logits and the decoded output
    so results can be compared against other runtimes.
    """
    # Load the processor and model for ImageTextToText tasks
    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )
    torch.cuda.set_device(0)

    # List of image URLs to process
    image_urls = [
        "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    ]

    # Conversation template for the processor
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]

    max_new_tokens = args.max_new_tokens

    for i, url in enumerate(image_urls):
        # Load the image from the URL
        image = Image.open(requests.get(url, stream=True).raw)

        # Apply the chat template to the text prompt
        # Notice that not all processors support chat templates.
        # LLaVA and QWen are two processors that support chat templates.
        if not hasattr(processor, "apply_chat_template"):
            raise ValueError("The processor does not support chat templates.")
        text_prompt = processor.apply_chat_template(
            conversation, add_generation_prompt=True
        )

        # Prepare inputs for the model
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(
            "cuda:0"
        )

        # Generate output from the model (greedy decoding)
        output_ids = model.generate(
            **inputs, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = processor.decode(output_ids[0])

        # Get the logits from the model's forward pass
        outputs = model.forward(**inputs)
        logits = outputs.logits[0, -1, :]

        print(f"\n========== Image {i} ==========")
        print("prefill logits (final)", logits)
        # TODO(gaocegege): The output contains numerous <|image_pad|> tokens,
        # making it cluttered and difficult to read.
        # These tokens should be removed or cleaned up for better readability.
        print(output_str)
@torch.no_grad()
def normal_text(args):
    """Generate reference outputs for a few fixed text prompts.

    Prints the final-position prefill logits and the greedy completion for
    each prompt, matching the reference output shown in the module docstring.
    """
    t = get_tokenizer(args.model_path, trust_remote_code=True)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    max_new_tokens = args.max_new_tokens

    torch.cuda.set_device(0)

    for i, p in enumerate(prompts):
        # Prompts may be raw strings or pre-tokenized id lists.
        if isinstance(p, str):
            input_ids = t.encode(p, return_tensors="pt").to("cuda:0")
        else:
            input_ids = torch.tensor([p], device="cuda:0")

        output_ids = m.generate(
            input_ids, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = t.decode(output_ids[0])

        # Logits at the last prompt position, before any decoding.
        prefill_logits = m.forward(input_ids).logits[0][-1]

        print(f"\n========== Prompt {i} ==========")
        print("prefill logits (final)", prefill_logits)
        print(output_str)
@torch.no_grad()
def synthetic_tokens(args):
    """Greedy-decode from a synthetic token-id prompt, printing each step's logits.

    Runs a fresh full forward pass per emitted token (no KV cache reuse), so
    the per-step logits are directly comparable across runtimes.
    """
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    m.cuda()
    print(m)

    input_len = 256
    output_len = 8

    # One synthetic prompt: consecutive token ids 5 .. 5+input_len-1.
    prompts = [list(range(5, 5 + input_len))]

    for p in prompts:
        input_ids = p
        for i in range(output_len + 1):
            prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[
                0
            ][-1]
            if i == 0:
                print("prefill logits", prefill_logits)
            else:
                print("decode", i - 1, prefill_logits)
            # Greedy: append the argmax token and repeat the full forward pass.
            input_ids.append(torch.argmax(prefill_logits).item())
if __name__ == "__main__":
    # CLI entry: --model-type vlm runs the image+text path, anything else
    # runs the plain-text reference path.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
    )
    parser.add_argument("--max-new-tokens", type=int, default=16)
    parser.add_argument("--dtype", type=str, default="float16")
    parser.add_argument("--model-type", type=str, default="text")
    args = parser.parse_args()

    runner = vlm_text_with_image if args.model_type == "vlm" else normal_text
    runner(args)

View File

@@ -0,0 +1,151 @@
"""
Usage:
# replay from a folder
python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/
# replay from a single file
python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl
"""
import argparse
import glob
import json
import pickle
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from datetime import datetime
import requests
from sglang.bench_serving import set_ulimit
from sglang.utils import get_exception_traceback
def read_records(files):
    """Load dumped request records from a list of pickle files.

    Each file contains either a plain list of records or a dict with a
    "requests" key holding that list; all records are concatenated in
    file order.

    Args:
        files: Iterable of pickle file paths.

    Returns:
        A flat list of all records from all files.
    """
    records = []
    for path in files:
        # Use a context manager so each file handle is closed promptly
        # (the original `pickle.load(open(f, "rb"))` leaked handles).
        with open(path, "rb") as f:
            tmp = pickle.load(f)
        if isinstance(tmp, dict) and "requests" in tmp:
            records.extend(tmp["requests"])
        else:
            records.extend(tmp)
    return records
def run_one_request_internal(record):
    """Replay one dumped request against the target /generate endpoint.

    Reads the module-level ``args`` namespace (host, port, speed, ignore_eos)
    parsed in __main__.

    Args:
        record: 6-tuple (req, output, replay_init_time, start_time, end_time,
            idx) produced by main(); start_time is already relative to the
            first request of the replay.
    """
    (req, output, replay_init_time, start_time, end_time, idx) = record
    # Wait until this request's original start offset, scaled by --speed
    # (>1 replays faster). A negative wait (we are late) is clamped to 0.
    time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))
    # If the dump recorded the completion length, replay it as max_new_tokens
    # so the response length matches the original run.
    if "completion_tokens" in output.get("meta_info", {}):
        recorded_completion_tokens = output["meta_info"]["completion_tokens"]
    else:
        recorded_completion_tokens = ""
    json_data = asdict(req)
    stream = json_data["stream"]
    if args.ignore_eos:
        json_data["sampling_params"]["ignore_eos"] = True
    if recorded_completion_tokens:
        json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens
    response = requests.post(
        f"http://{args.host}:{args.port}/generate",
        json=json_data,
        stream=stream,
    )
    if stream:
        # Consume the SSE stream, keeping only the last JSON payload before
        # the "data: [DONE]" sentinel.
        # NOTE(review): `ret` stays unbound if the stream yields no data
        # chunks -- presumably the server always sends at least one; confirm.
        for chunk in response.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                ret = json.loads(chunk[5:].strip("\n"))
    else:
        ret = response.json()
    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    print(
        f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, "
        f"{completion_tokens=}, {recorded_completion_tokens=}"
    )
def run_one_request(record):
    """Replay one record, logging (rather than propagating) any failure so a
    single bad request does not kill the whole thread pool."""
    try:
        run_one_request_internal(record)
    except Exception:
        traceback = get_exception_traceback()
        print(f"Hit an exception: {traceback}")
def main(records):
    """Normalize record timestamps relative to the first request, then replay
    all records concurrently through a thread pool of size args.parallel."""
    if not records:
        return
    base_time = records[0][-2]
    base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S")
    print(f"{base_time_str=}")
    replay_init_time = time.time()
    # Rewrite each record in place into the 6-tuple shape that
    # run_one_request expects, shifting start_time to be relative.
    for idx, (req, output, start_time, end_time) in enumerate(records):
        records[idx] = (
            req,
            output,
            replay_init_time,
            start_time - base_time,
            end_time,
            idx,
        )
    with ThreadPoolExecutor(args.parallel) as executor:
        executor.map(run_one_request, records)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=30000)
    parser.add_argument(
        "--input-folder", type=str, default=None, help="Folder containing pickle files"
    )
    parser.add_argument(
        "--input-file", type=str, default=None, help="Single pickle file to process"
    )
    parser.add_argument("--file-number", type=int, default=1)
    parser.add_argument("--req-number", type=int, default=1000000)
    parser.add_argument("--req-start", type=int, default=0)
    parser.add_argument("--parallel", type=int, default=512)
    parser.add_argument("--idx", type=int, default=None)
    parser.add_argument("--ignore-eos", action="store_true")
    parser.add_argument("--speed", type=float, default=1)
    args = parser.parse_args()
    set_ulimit()

    # Resolve the list of dump files: single file takes precedence.
    files = []
    if args.input_file:
        files = [args.input_file]
        if args.file_number > 1:
            print("Warning: --file-number is ignored when --input-file is provided.")
    elif args.input_folder:
        files = glob.glob(f"{args.input_folder}/*.pkl")
        files = files[: args.file_number]
    else:
        print("Error: Either --input-folder or --input-file must be provided.")
        exit(1)
    print(f"{files=}")

    records = read_records(files)
    # Sort by the receive time, before filtering
    records.sort(key=lambda x: x[-2])
    # BUGFIX: honor --req-number (it was parsed but never applied). The large
    # default keeps the previous "replay everything" behavior.
    records = records[args.req_start : args.req_start + args.req_number]
    # BUGFIX: `if args.idx:` silently ignored `--idx 0`; check for None.
    if args.idx is not None:
        records = [records[args.idx]]
        print(f"testing {args.idx=}")
    print(f"{records[0]}")
    print(f"{len(records)=}")
    main(records)

View File

@@ -0,0 +1,207 @@
import random
import string
import time
import unittest
from typing import Dict, List, Tuple
from tree import MultiTenantRadixTree
class TestMultiTenantRadixTree(unittest.TestCase):
    """Unit tests for the multi-tenant radix tree reference implementation.

    Covers insert/prefix_match basics, edge cases, and per-tenant LRU
    eviction at increasing levels of complexity.
    """

    def setUp(self):
        # Fresh tree per test; no state is shared between test methods.
        self.tree = MultiTenantRadixTree()

    def test_insert_exact_match(self):
        """Test 1: Basic insert and exact match operations"""
        # Insert a single string for one tenant
        self.tree.insert("hello", "tenant1")
        matched, tenant = self.tree.prefix_match("hello")
        self.assertEqual(matched, "hello")
        self.assertEqual(tenant, "tenant1")
        # Insert same string for different tenant; prefix_match may return
        # either owner (it picks an arbitrary tenant of the matched node).
        self.tree.insert("hello", "tenant2")
        matched, tenant = self.tree.prefix_match("hello")
        self.assertIn(tenant, ["tenant1", "tenant2"])
        # Insert different string for same tenant
        self.tree.insert("world", "tenant1")
        matched, tenant = self.tree.prefix_match("world")
        self.assertEqual(matched, "world")
        self.assertEqual(tenant, "tenant1")
        print(self.tree.pretty_print())

    def test_insert_partial_match(self):
        """Test 2: Insert with partial matching scenarios"""
        # Test partial matches with common prefixes: "hello" + "help" forces
        # a node split at the shared prefix "hel".
        self.tree.insert("hello", "tenant1")
        print(self.tree.pretty_print())
        self.tree.insert("help", "tenant2")
        print(self.tree.pretty_print())
        # Match exact strings
        matched, tenant = self.tree.prefix_match("hello")
        self.assertEqual(matched, "hello")
        self.assertEqual(tenant, "tenant1")
        matched, tenant = self.tree.prefix_match("help")
        self.assertEqual(matched, "help")
        self.assertEqual(tenant, "tenant2")
        # Match partial string: "hel" is shared, so either tenant may own it
        matched, tenant = self.tree.prefix_match("hel")
        self.assertEqual(matched, "hel")
        self.assertIn(tenant, ["tenant1", "tenant2"])
        # Match longer string: only the stored prefix is returned
        matched, tenant = self.tree.prefix_match("hello_world")
        self.assertEqual(matched, "hello")
        self.assertEqual(tenant, "tenant1")

    def test_insert_edge_cases(self):
        """Test 3: Edge cases for insert and match operations"""
        # Empty string
        self.tree.insert("", "tenant1")
        matched, tenant = self.tree.prefix_match("")
        self.assertEqual(matched, "")
        self.assertEqual(tenant, "tenant1")
        # Single character
        self.tree.insert("a", "tenant1")
        matched, tenant = self.tree.prefix_match("a")
        self.assertEqual(matched, "a")
        self.assertEqual(tenant, "tenant1")
        # Very long string
        long_str = "a" * 1000
        self.tree.insert(long_str, "tenant1")
        matched, tenant = self.tree.prefix_match(long_str)
        self.assertEqual(matched, long_str)
        self.assertEqual(tenant, "tenant1")
        # Unicode characters
        self.tree.insert("你好", "tenant1")
        matched, tenant = self.tree.prefix_match("你好")
        self.assertEqual(matched, "你好")
        self.assertEqual(tenant, "tenant1")

    def test_simple_eviction(self):
        """Test 4: Simple eviction scenarios

        Tenant1: limit 10 chars
        Tenant2: limit 5 chars

        Should demonstrate:
        1. Basic eviction when size limit exceeded
        2. Proper eviction based on last access time
        3. Verification that shared nodes remain intact for other tenants
        """
        # Set up size limits
        max_size = {"tenant1": 10, "tenant2": 5}
        # Insert strings for both tenants
        self.tree.insert("hello", "tenant1")  # size 5
        self.tree.insert("hello", "tenant2")  # size 5
        self.tree.insert("world", "tenant2")  # size 5, total for tenant2 = 10
        # Verify initial sizes
        sizes_before = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_before["tenant1"], 5)  # "hello" = 5
        self.assertEqual(sizes_before["tenant2"], 10)  # "hello" + "world" = 10
        # Evict - should remove "hello" from tenant2 as it's the oldest
        self.tree.evict_tenant_data(max_size)
        # Verify sizes after eviction
        sizes_after = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_after["tenant1"], 5)  # Should be unchanged
        self.assertEqual(sizes_after["tenant2"], 5)  # Only "world" remains
        # Verify "world" remains for tenant2 (was accessed more recently)
        matched, tenant = self.tree.prefix_match("world")
        self.assertEqual(matched, "world")
        self.assertEqual(tenant, "tenant2")

    def test_medium_eviction(self):
        """Test 5: Medium complexity eviction scenarios with shared prefixes

        Tenant1: limit 10 chars
        Tenant2: limit 6 chars (forces one string to be evicted; the code
            below uses 6, matching the post-eviction expectation)

        Tree structure after inserts:
        └── 'h' [t1, t2]
            ├── 'i' [t1, t2]      # Oldest for t2
            └── 'e' [t1, t2]
                ├── 'llo' [t1, t2]
                └── 'y' [t2]      # Newest for t2

        Size calculations:
        tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars
        tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars

        After eviction (tenant2 exceeds limit by 1 char):
        "hi" should be removed from tenant2 as it's the oldest access
        """
        max_size = {
            "tenant1": 10,
            "tenant2": 6,
        }  # tenant2 will need to evict one string
        # Create a tree with overlapping prefixes
        self.tree.insert("hi", "tenant1")
        self.tree.insert("hi", "tenant2")  # OLDEST for t2
        self.tree.insert("hello", "tenant1")
        self.tree.insert("hello", "tenant2")
        self.tree.insert("hey", "tenant2")  # NEWEST for t2
        # Verify initial sizes
        sizes_before = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_before["tenant1"], 6)  # h(1) + i(1) + e(1) + llo(3) = 6
        self.assertEqual(
            sizes_before["tenant2"], 7
        )  # h(1) + i(1) + e(1) + llo(3) + y(1) = 7
        print("\nTree before eviction:")
        print(self.tree.pretty_print())
        # Evict - should remove "hi" from tenant2 as it's the oldest
        self.tree.evict_tenant_data(max_size)
        print("\nTree after eviction:")
        print(self.tree.pretty_print())
        # Verify sizes after eviction
        sizes_after = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_after["tenant1"], 6)  # Should be unchanged
        self.assertEqual(sizes_after["tenant2"], 6)  # h(1) + e(1) + llo(3) + y(1) = 6

    def test_advanced_eviction(self):
        """Test 6: Stress eviction across 4 tenants with shared prefixes."""
        ...
        # Create 4 tenants
        # Each tenants keeps adding strings with shared prefixes to thousands usage
        # Set a strict limit for each tenant to only 100
        # At the end, check whether all of the tenant is under 100 after eviction
        # NOTE(review): random is unseeded, so exact tree shapes vary between
        # runs; the assertion only checks the size bound, which holds regardless.
        max_size = {"tenant1": 100, "tenant2": 100, "tenant3": 100, "tenant4": 100}
        prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]
        for i in range(100):
            for j, prefix in enumerate(prefixes):
                random_suffix = "".join(random.choices(string.ascii_letters, k=10))
                self.tree.insert(prefix + random_suffix, f"tenant{j+1}")
        sizes_before = self.tree.get_used_size_per_tenant()
        print(sizes_before)
        self.tree.evict_tenant_data(max_size)
        sizes_after = self.tree.get_used_size_per_tenant()
        print(sizes_after)
        # ensure size_after is below max_size
        for tenant, size in sizes_after.items():
            self.assertLessEqual(size, max_size[tenant])
if __name__ == "__main__":
    # Allow running this test module directly: `python <this_file>.py`.
    unittest.main()

View File

@@ -0,0 +1,292 @@
import time
from collections import defaultdict
from typing import Dict, List
class Node:
    """A single radix-tree node whose edge label may be shared by multiple
    tenants."""

    def __init__(self):
        # Child nodes keyed by the first character of each child's text.
        self.children: Dict[str, "Node"] = {}
        # Edge label. Text (not token ids) is stored because most use cases
        # are text-to-text, which saves the tokenizing overhead.
        self.text: str = ""
        # tenant_id -> last access timestamp (time.time()).
        self.tenant_last_access_time: Dict[str, float] = {}
        # Parent node; None for the root.
        self.parent = None
def shared_prefix_length(s1, s2):
    """Return the length of the longest common prefix of s1 and s2."""
    count = 0
    # zip stops at the shorter string, so the bound is implicit.
    for a, b in zip(s1, s2):
        if a != b:
            break
        count += 1
    return count
class MultiTenantRadixTree:
    """
    Python Reference of Rust implementation of MultiTenantRadixTree

    MultiTenantRadixTree is the overlap of multiple radix trees by different
    tenants. Each node in the tree can be owned by multiple tenants, allowing
    for efficient storage of common prefixes while maintaining tenant isolation.

    Key concepts:
    - Tenant: An entity that owns a subset of the stored strings
    - Each node tracks which tenants have access to it via tenant_last_access_time
    - The tree structure is shared, but queries can be filtered by tenant_id
    """

    def __init__(self):
        # Root is an empty sentinel node (text == ""); it accumulates a tenant
        # timestamp on every insert by that tenant.
        self.root = Node()

    def insert(self, s: str, tenant_id: str) -> None:
        """
        Insert string 's' and associate it with the given tenant_id.

        Every node on the inserted path gets its last-access time for
        tenant_id refreshed.

        Args:
            s: The string to insert
            tenant_id: The identifier of the tenant who owns this string
        """
        curr = self.root
        curr_idx = 0
        curr.tenant_last_access_time[tenant_id] = time.time()
        while curr_idx < len(s):
            # Children are keyed by first character, so at most one child can
            # share a prefix with the remaining suffix.
            matched_node = None
            if s[curr_idx] in curr.children:
                matched_node = curr.children[s[curr_idx]]
            if matched_node is None:
                # No match => create a new node holding the whole remaining suffix
                new_node = Node()
                new_node.text = s[curr_idx:]
                new_node.parent = curr
                curr.children[s[curr_idx]] = new_node
                curr_idx = len(s)
                curr = new_node
                curr.tenant_last_access_time[tenant_id] = time.time()
            else:
                shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
                # 1. If the matched text is shorter than the node text => split the node
                if shared_len < len(matched_node.text):
                    # Split structure: [matched_node] => [new_node] -> [contracted_matched_node]
                    matched_text = matched_node.text[:shared_len]
                    unmatched_text = matched_node.text[shared_len:]
                    new_node = Node()
                    new_node.text = matched_text
                    new_node.children = {unmatched_text[0]: matched_node}
                    new_node.parent = curr
                    new_node.parent.children[matched_text[0]] = new_node
                    # The split point inherits every tenant of the node it
                    # split, since all of them own this shared prefix.
                    new_node.tenant_last_access_time = (
                        matched_node.tenant_last_access_time.copy()
                    )
                    # Contract matched node
                    matched_node.text = unmatched_text
                    matched_node.parent = new_node
                    curr_idx += shared_len
                    curr = new_node
                    curr.tenant_last_access_time[tenant_id] = time.time()
                # 2. If the matched text is longer or equal to the node text => walk down the node
                else:
                    curr_idx += shared_len
                    curr = matched_node
                    curr.tenant_last_access_time[tenant_id] = time.time()

    def prefix_match(self, s: str) -> tuple[str, str]:
        """
        Match string 's' with multiple tenants' trees in one operation.

        Args:
            s: The string to match

        Returns:
            Tuple(str, str): The longest prefix of 's' that matches the tree
            and the first tenant_id that owns the matched prefix.

        NOTE(review): on a tree with no inserts, the root has no tenants and
        the selection below raises IndexError -- confirm callers always
        insert before matching.
        """
        curr = self.root
        curr_idx = 0
        # (unused; kept from the reference implementation)
        ret_text = ""
        ret_tenant = None
        while curr_idx < len(s):
            matched_node = None
            if s[curr_idx] in curr.children:
                matched_node = curr.children[s[curr_idx]]
            if matched_node is None:
                break
            shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
            if shared_len == len(matched_node.text):
                # Full edge consumed -- keep walking down.
                curr_idx += shared_len
                curr = matched_node
            else:
                # Partial edge match -- the walk ends inside this edge.
                curr_idx += shared_len
                curr = matched_node
                break
        # Pick an arbitrary owner of the deepest matched node (dict order).
        selected_tenant = list(curr.tenant_last_access_time.keys())[0]
        # traverse back to the root to update last access time for the selected tenant
        while curr != self.root:
            curr.tenant_last_access_time[selected_tenant] = time.time()
            curr = curr.parent
        return s[:curr_idx], selected_tenant

    def evict_tenant_data(self, max_size_per_tenant: Dict[str, int]) -> None:
        """
        Evict data for tenants that have exceeded their storage limits.

        Eviction is LRU per tenant: the tenant's least recently used leaf is
        removed repeatedly until the tenant fits its budget.

        Args:
            max_size_per_tenant: Dictionary mapping tenant_id to their maximum allowed storage size
        """

        def leaf_of(node):
            """
            If the node is a leaf for a tenant, add tenant_id to the return list
            This will return list of tenant ids
            If not a leaf for all tenants, return []
            """
            # A tenant is a "leaf tenant" of this node iff no child also
            # carries that tenant.
            candidates = dict([(k, True) for k in node.tenant_last_access_time.keys()])
            for n in node.children.values():
                for c in n.tenant_last_access_time.keys():
                    candidates[c] = False
            return [k for k, v in candidates.items() if v]

        # maintain a heap with (time, tenant, node) as the value
        # NOTE(review): if two entries tie on (time, tenant), heapq would
        # compare Node objects, which are unorderable -- presumably timestamps
        # never collide for the same tenant; confirm.
        import heapq

        # 1. traverse the tree to
        #    a. add all the leaves into a heap (a node with N tenants will be
        #       added N times into the heap)
        #    b. calculate the used size for each tenant
        # do a dfs with stack
        stack = [self.root]
        pq = []
        used_size_per_tenant = defaultdict(int)
        while stack:
            curr = stack.pop()
            for t in curr.tenant_last_access_time.keys():
                used_size_per_tenant[t] += len(curr.text)
            for c in curr.children.values():
                stack.append(c)
            # if the node is a leaf for a tenant, add the tenant to the heap
            tenants = leaf_of(curr)
            for t in tenants:
                heapq.heappush(pq, (curr.tenant_last_access_time[t], t, curr))
        # 2. pop the heap
        #    a. if the tenant's used size is less than the limit, continue
        #    b. if the tenant's used size is greater than the limit, remove the
        #       leaf, update the used size, and add its parent to the heap
        while len(pq) > 0:
            # NOTE(review): `time` here shadows the module-level `time` import
            # within this loop body; no time.time() call follows, so it is
            # harmless today, but worth renaming in a future change.
            time, tenant, node = heapq.heappop(pq)
            if used_size_per_tenant[tenant] <= max_size_per_tenant[tenant]:
                continue
            # remove the leaf
            used_size_per_tenant[tenant] -= len(node.text)
            del node.tenant_last_access_time[tenant]
            # if no children and no tenants, remove the node
            if len(node.children) == 0 and len(node.tenant_last_access_time) == 0:
                del node.parent.children[node.text[0]]
            # add its parent to the heap
            if tenant in leaf_of(node.parent):
                heapq.heappush(
                    pq,
                    (node.parent.tenant_last_access_time[tenant], tenant, node.parent),
                )

    def get_used_size_per_tenant(self) -> Dict[str, int]:
        """
        Calculate the used storage size for each tenant.

        Size is the sum of len(node.text) over every node the tenant owns,
        so shared prefixes are charged to every owning tenant.

        Returns:
            Dict[str, int]: A dictionary mapping tenant_id to their used storage size
        """
        used_size_per_tenant = defaultdict(int)
        # Iterative DFS over the whole (shared) tree.
        stack = [self.root]
        while stack:
            curr = stack.pop()
            for t in curr.tenant_last_access_time.keys():
                used_size_per_tenant[t] += len(curr.text)
            for c in curr.children.values():
                stack.append(c)
        return used_size_per_tenant

    def remove_tenant(self, tenant_id: str) -> None:
        """
        Remove all data associated with a specific tenant from the tree.

        This operation maintains the integrity of the shared tree structure while
        removing only the specified tenant's access information.

        Args:
            tenant_id: The identifier of the tenant whose data should be removed
        """
        # TODO: Implementation needed
        pass

    def pretty_print(self) -> str:
        """
        Returns a string representation of the tree showing the structure, tenant ownership,
        and leaf status for each node.

        Returns:
            str: A formatted string showing the tree hierarchy with tenant information
        """

        def _node_to_str(node: Node, prefix: str = "", is_last: bool = True) -> str:
            # Current node representation
            node_str = prefix
            node_str += "└── " if is_last else "├── "
            # Add node text
            node_str += f"'{node.text}' ["
            # Add tenant information including both timestamp and leaf status
            tenant_info = []
            for tid, ts in node.tenant_last_access_time.items():
                # Render HH:MM:SS plus milliseconds; strftime has no
                # sub-second directive, so the fraction is formatted separately.
                time_str = (
                    time.strftime("%H:%M:%S.", time.localtime(ts))
                    + f"{(ts % 1):0.3f}"[2:]
                )
                tenant_info.append(f"{tid} | {time_str}")
            node_str += ", ".join(tenant_info)
            node_str += "]\n"
            # Handle children
            children = list(node.children.items())
            for i, (char, child) in enumerate(children):
                is_last_child = i == len(children) - 1
                # Adjust prefix for children based on whether this is the last child
                new_prefix = prefix + (" " if is_last else "")
                node_str += _node_to_str(child, new_prefix, is_last_child)
            return node_str

        if not self.root.children:
            return "Empty tree"
        # Start with root's children since root itself is just an empty node
        result = ""
        children = list(self.root.children.items())
        for i, (char, child) in enumerate(children):
            is_last = i == len(children) - 1
            result += _node_to_str(child, "", is_last)
        return result

View File

@@ -0,0 +1,33 @@
# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py
import argparse
import hashlib
import pathlib
import re
def update_wheel_index(cuda_version="118"):
    """Append download links for freshly built sgl-kernel wheels to the
    per-CUDA-version wheel index page.

    Scans `sgl-kernel/dist/*.whl` (relative to the current directory),
    computes each wheel's sha256, and appends an anchor tag to
    `sgl-whl/cu{cuda_version}/sgl-kernel/index.html`.

    Args:
        cuda_version: CUDA version suffix used in the index path, e.g. "118".
    """
    index_dir = pathlib.Path(f"sgl-whl/cu{cuda_version}/sgl-kernel")
    # BUGFIX: parents=True -- the path is nested, so plain mkdir raised
    # FileNotFoundError when sgl-whl/cu{ver} did not exist yet.
    index_dir.mkdir(parents=True, exist_ok=True)
    base_url = "https://github.com/sgl-project/whl/releases/download"
    for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")):
        with open(path, "rb") as f:
            sha256 = hashlib.sha256(f.read()).hexdigest()
        matches = re.findall(
            r"sgl_kernel-([0-9.]+(?:\.post[0-9]+)?)(?:\+cu[0-9]+)?-", path.name
        )
        if not matches:
            # BUGFIX: skip wheels whose filename carries no parsable version
            # instead of raising IndexError on `[0]`.
            continue
        ver = matches[0]
        full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}"
        with (index_dir / "index.html").open("a") as f:
            f.write(f'<a href="{full_url}">{path.name}</a><br>\n')
def main():
    """CLI entry point: regenerate the wheel index for one CUDA version."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", type=str, default="118")
    cli_args = parser.parse_args()
    update_wheel_index(cli_args.cuda)
if __name__ == "__main__":
    # Script entry point; all behavior lives in main().
    main()

View File

@@ -0,0 +1,30 @@
#!/bin/bash
set -euxo pipefail

# This script is used for release.
# It tags all remote branches starting with 'v' with the same name as the
# branch, deletes the corresponding branches from the remote, and pushes the
# tags to the remote repository.

git fetch origin --prune

# List all remote branches whose name starts with 'v'.
# BUGFIX: anchor the match -- the previous unanchored `grep 'origin/v'` also
# matched refs merely containing that substring (e.g. origin/dev/v2).
# `|| true` keeps `set -e` from aborting when no such branch exists.
branches=$(git branch -r | sed 's/^[[:space:]]*//' | grep '^origin/v' | sed 's|^origin/||' || true)

# Loop through each branch (word-splitting on whitespace is intentional here;
# git branch names cannot contain spaces).
for branch in $branches; do
  echo "Processing branch: $branch"

  # Get the commit hash for the branch
  commit_hash=$(git rev-parse "origin/$branch")

  # Create a tag with the same name as the branch using the commit hash
  git tag "$branch" "$commit_hash"

  # Delete the branch from the remote
  git push origin --delete "$branch"
done

# Push all tags to the remote repository
git push --tags

echo "All branches starting with 'v' have been tagged, deleted from remote, and pushed to the remote repository."