diff --git a/.github/workflows/execute-notebook.yml b/.github/workflows/execute-notebook.yml
index 49d649797..674efe86d 100644
--- a/.github/workflows/execute-notebook.yml
+++ b/.github/workflows/execute-notebook.yml
@@ -36,6 +36,9 @@ jobs:
run: |
bash scripts/ci_install_dependency.sh
pip install -r docs/requirements.txt
+ apt-get update
+ apt-get install -y pandoc
+          apt-get install -y parallel
- name: Setup Jupyter Kernel
run: |
diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 81d4349ee..69aa1ee7b 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -8,6 +8,7 @@ on:
- "python/sglang/**"
- "test/**"
- "docs/**"
+ - "scripts/**"
pull_request:
branches: [ main ]
paths:
@@ -15,6 +16,7 @@ on:
- "python/sglang/**"
- "test/**"
- "docs/**"
+ - "scripts/**"
workflow_dispatch:
inputs:
version:
@@ -45,6 +47,8 @@ jobs:
filters: |
docs:
- 'docs/**'
+ scripts:
+ - 'scripts/**'
sglang:
- 'python/sglang/**'
test:
diff --git a/.github/workflows/release-docs.yml b/.github/workflows/release-docs.yml
index 37db70c7c..23d551460 100644
--- a/.github/workflows/release-docs.yml
+++ b/.github/workflows/release-docs.yml
@@ -32,6 +32,7 @@ jobs:
pip install -r docs/requirements.txt
apt-get update
apt-get install -y pandoc
+          apt-get install -y parallel
- name: Setup Jupyter Kernel
run: |
diff --git a/docs/Makefile b/docs/Makefile
index 13d81f4f8..36ad781d0 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -1,34 +1,42 @@
-# Minimal makefile for Sphinx documentation
-#
-
-# You can set these variables from the terminal, and also
-# from the environment for the first two.
+# Minimal Makefile for Sphinx documentation
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
-# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-# New target to compile Markdown and Jupyter Notebook files
+# Compile Notebook files and record execution time
compile:
- find $(SOURCEDIR) -path "*/_build/*" -prune -o -name "*.ipynb" -print | while read nb; do \
- if [ -f "$$nb" ]; then \
- echo "Executing $$nb"; \
- jupyter nbconvert --to notebook --execute --inplace "$$nb" \
- --ExecutePreprocessor.timeout=600 \
- --ExecutePreprocessor.kernel_name=python3 || exit 1; \
- fi; \
- done
+ @set -e; \
+ echo "Starting Notebook compilation..."; \
+ mkdir -p logs; \
+ echo "Notebook execution timings:" > logs/timing.log; \
+ START_TOTAL=$$(date +%s); \
+ find $(SOURCEDIR) -path "*/_build/*" -prune -o -name "*.ipynb" -print0 | \
+ parallel -0 -j3 --halt soon,fail=1 ' \
+ NB_NAME=$$(basename {}); \
+ START_TIME=$$(date +%s); \
+ jupyter nbconvert --to notebook --execute --inplace "{}" \
+ --ExecutePreprocessor.timeout=600 \
+ --ExecutePreprocessor.kernel_name=python3; \
+ RET_CODE=$$?; \
+ END_TIME=$$(date +%s); \
+ ELAPSED_TIME=$$((END_TIME - START_TIME)); \
+ echo "$${NB_NAME}: $${ELAPSED_TIME}s" >> logs/timing.log; \
+ exit $$RET_CODE' || exit 1; \
+ END_TOTAL=$$(date +%s); \
+ TOTAL_ELAPSED=$$((END_TOTAL - START_TOTAL)); \
+ echo "---------------------------------" >> logs/timing.log; \
+ echo "Total execution time: $${TOTAL_ELAPSED}s" >> logs/timing.log; \
+ echo "All Notebook execution timings:" && cat logs/timing.log
-.PHONY: help Makefile compile
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+.PHONY: help Makefile compile clean
+
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
clean:
- rm -rf $(BUILDDIR)/*
+ rm -rf $(BUILDDIR)/* logs/timing.log
diff --git a/docs/backend/__init__.py b/docs/backend/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb
index 05e7108e6..e80c91d54 100644
--- a/docs/backend/function_calling.ipynb
+++ b/docs/backend/function_calling.ipynb
@@ -31,17 +31,19 @@
"source": [
"from openai import OpenAI\n",
"import json\n",
- "from sglang.utils import (\n",
- " execute_shell_command,\n",
- " wait_for_server,\n",
- " terminate_process,\n",
- " print_highlight,\n",
- ")\n",
+ "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
+ "from sglang.test.test_utils import is_in_ci\n",
"\n",
- "server_process = execute_shell_command(\n",
- " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --port 30333 --host 0.0.0.0\" # llama3\n",
+ "if is_in_ci():\n",
+ " from patch import launch_server_cmd\n",
+ "else:\n",
+ " from sglang.utils import launch_server_cmd\n",
+ "\n",
+ "\n",
+ "server_process, port = launch_server_cmd(\n",
+ " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --host 0.0.0.0\" # llama3\n",
")\n",
- "wait_for_server(\"http://localhost:30333\")"
+ "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
@@ -141,7 +143,7 @@
"outputs": [],
"source": [
"# Initialize OpenAI-like client\n",
- "client = OpenAI(api_key=\"None\", base_url=\"http://0.0.0.0:30333/v1\")\n",
+ "client = OpenAI(api_key=\"None\", base_url=f\"http://0.0.0.0:{port}/v1\")\n",
"model_name = client.models.list().data[0].id"
]
},
@@ -377,13 +379,13 @@
" tools=tools,\n",
")\n",
"\n",
- "gen_url = \"http://localhost:30333/generate\"\n",
+ "gen_url = f\"http://localhost:{port}/generate\"\n",
"gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n",
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
"print(gen_response)\n",
"\n",
"# parse the response\n",
- "parse_url = \"http://localhost:30333/function_call\"\n",
+ "parse_url = f\"http://localhost:{port}/function_call\"\n",
"\n",
"function_call_input = {\n",
" \"text\": gen_response,\n",
@@ -403,7 +405,7 @@
"metadata": {},
"outputs": [],
"source": [
- "terminate_process(server_process)"
+ "terminate_process(server_process, port)"
]
},
{
diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb
index f6c10d745..2208bab35 100644
--- a/docs/backend/native_api.ipynb
+++ b/docs/backend/native_api.ipynb
@@ -34,22 +34,22 @@
"metadata": {},
"outputs": [],
"source": [
- "from sglang.utils import (\n",
- " execute_shell_command,\n",
- " wait_for_server,\n",
- " terminate_process,\n",
- " print_highlight,\n",
- ")\n",
- "\n",
"import requests\n",
+ "from sglang.test.test_utils import is_in_ci\n",
"\n",
- "server_process = execute_shell_command(\n",
- " \"\"\"\n",
- "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010\n",
- "\"\"\"\n",
+ "if is_in_ci():\n",
+ " from patch import launch_server_cmd\n",
+ "else:\n",
+ " from sglang.utils import launch_server_cmd\n",
+ "\n",
+ "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
+ "\n",
+ "\n",
+ "server_process, port = launch_server_cmd(\n",
+ " \"python -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --host 0.0.0.0\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30010\")"
+ "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
@@ -66,7 +66,7 @@
"metadata": {},
"outputs": [],
"source": [
- "url = \"http://localhost:30010/generate\"\n",
+ "url = f\"http://localhost:{port}/generate\"\n",
"data = {\"text\": \"What is the capital of France?\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
@@ -92,7 +92,7 @@
"metadata": {},
"outputs": [],
"source": [
- "url = \"http://localhost:30010/get_model_info\"\n",
+ "url = f\"http://localhost:{port}/get_model_info\"\n",
"\n",
"response = requests.get(url)\n",
"response_json = response.json()\n",
@@ -123,7 +123,7 @@
"source": [
"# get_server_info\n",
"\n",
- "url = \"http://localhost:30010/get_server_info\"\n",
+ "url = f\"http://localhost:{port}/get_server_info\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
@@ -144,7 +144,7 @@
"metadata": {},
"outputs": [],
"source": [
- "url = \"http://localhost:30010/health_generate\"\n",
+ "url = f\"http://localhost:{port}/health_generate\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
@@ -156,7 +156,7 @@
"metadata": {},
"outputs": [],
"source": [
- "url = \"http://localhost:30010/health\"\n",
+ "url = f\"http://localhost:{port}/health\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
@@ -179,7 +179,7 @@
"source": [
"# flush cache\n",
"\n",
- "url = \"http://localhost:30010/flush_cache\"\n",
+ "url = f\"http://localhost:{port}/flush_cache\"\n",
"\n",
"response = requests.post(url)\n",
"print_highlight(response.text)"
@@ -204,7 +204,7 @@
"source": [
"# successful update with same architecture and size\n",
"\n",
- "url = \"http://localhost:30010/update_weights_from_disk\"\n",
+ "url = f\"http://localhost:{port}/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
@@ -222,7 +222,7 @@
"source": [
"# failed update with different parameter size or wrong name\n",
"\n",
- "url = \"http://localhost:30010/update_weights_from_disk\"\n",
+ "url = f\"http://localhost:{port}/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B-wrong\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
@@ -252,16 +252,16 @@
"metadata": {},
"outputs": [],
"source": [
- "terminate_process(server_process)\n",
+ "terminate_process(server_process, port)\n",
"\n",
- "embedding_process = execute_shell_command(\n",
+ "embedding_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \\\n",
- " --port 30020 --host 0.0.0.0 --is-embedding\n",
+ " --host 0.0.0.0 --is-embedding\n",
"\"\"\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30020\")"
+ "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
@@ -272,7 +272,7 @@
"source": [
"# successful encode for embedding model\n",
"\n",
- "url = \"http://localhost:30020/encode\"\n",
+ "url = f\"http://localhost:{port}/encode\"\n",
"data = {\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"text\": \"Once upon a time\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
@@ -280,6 +280,15 @@
"print_highlight(f\"Text embedding (first 10): {response_json['embedding'][:10]}\")"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "terminate_process(embedding_process, port)"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -295,18 +304,18 @@
"metadata": {},
"outputs": [],
"source": [
- "terminate_process(embedding_process)\n",
+ "terminate_process(embedding_process, port)\n",
"\n",
"# Note that SGLang now treats embedding models and reward models as the same type of models.\n",
"# This will be updated in the future.\n",
"\n",
- "reward_process = execute_shell_command(\n",
+ "reward_process, port = launch_server_cmd(\n",
" \"\"\"\n",
- "python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --port 30030 --host 0.0.0.0 --is-embedding\n",
+ "python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --host 0.0.0.0 --is-embedding\n",
"\"\"\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30030\")"
+ "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
@@ -332,7 +341,7 @@
"tokenizer = AutoTokenizer.from_pretrained(\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\")\n",
"prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)\n",
"\n",
- "url = \"http://localhost:30030/classify\"\n",
+ "url = f\"http://localhost:{port}/classify\"\n",
"data = {\"model\": \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\", \"text\": prompts}\n",
"\n",
"responses = requests.post(url, json=data).json()\n",
@@ -346,7 +355,7 @@
"metadata": {},
"outputs": [],
"source": [
- "terminate_process(reward_process)"
+ "terminate_process(reward_process, port)"
]
},
{
@@ -364,13 +373,13 @@
"metadata": {},
"outputs": [],
"source": [
- "tokenizer_free_server_process = execute_shell_command(\n",
+ "tokenizer_free_server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
- "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010 --skip-tokenizer-init\n",
+ "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --skip-tokenizer-init\n",
"\"\"\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30010\")"
+ "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
@@ -390,7 +399,7 @@
"print_highlight(f\"Tokenized Input: {input_tokens}\")\n",
"\n",
"response = requests.post(\n",
- " \"http://localhost:30010/generate\",\n",
+ " f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"input_ids\": input_tokens,\n",
" \"sampling_params\": {\n",
@@ -416,7 +425,7 @@
"metadata": {},
"outputs": [],
"source": [
- "terminate_process(tokenizer_free_server_process)"
+ "terminate_process(tokenizer_free_server_process, port)"
]
}
],
diff --git a/docs/backend/offline_engine_api.ipynb b/docs/backend/offline_engine_api.ipynb
index 3a2700856..aa2159f3c 100644
--- a/docs/backend/offline_engine_api.ipynb
+++ b/docs/backend/offline_engine_api.ipynb
@@ -40,6 +40,11 @@
"from sglang.utils import stream_and_merge, async_stream_and_merge\n",
"import sglang as sgl\n",
"import asyncio\n",
+ "from sglang.test.test_utils import is_in_ci\n",
+ "\n",
+ "if is_in_ci():\n",
+ " import patch\n",
+ "\n",
"\n",
"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")"
]
@@ -201,8 +206,6 @@
"metadata": {},
"outputs": [],
"source": [
- "import sglang as sgl\n",
- "\n",
"llm = sgl.Engine(\n",
" model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", return_hidden_states=True\n",
")"
diff --git a/docs/backend/openai_api_completions.ipynb b/docs/backend/openai_api_completions.ipynb
index 58b524108..e9c3f360a 100644
--- a/docs/backend/openai_api_completions.ipynb
+++ b/docs/backend/openai_api_completions.ipynb
@@ -33,18 +33,22 @@
"metadata": {},
"outputs": [],
"source": [
- "from sglang.utils import (\n",
- " execute_shell_command,\n",
- " wait_for_server,\n",
- " terminate_process,\n",
- " print_highlight,\n",
+ "from sglang.test.test_utils import is_in_ci\n",
+ "\n",
+ "if is_in_ci():\n",
+ " from patch import launch_server_cmd\n",
+ "else:\n",
+ " from sglang.utils import launch_server_cmd\n",
+ "\n",
+ "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
+ "\n",
+ "\n",
+ "server_process, port = launch_server_cmd(\n",
+ " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0\"\n",
")\n",
"\n",
- "server_process = execute_shell_command(\n",
- " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30020 --host 0.0.0.0\"\n",
- ")\n",
- "\n",
- "wait_for_server(\"http://localhost:30020\")"
+ "wait_for_server(f\"http://localhost:{port}\")\n",
+ "print(f\"Server started on http://localhost:{port}\")"
]
},
{
@@ -68,7 +72,7 @@
"source": [
"import openai\n",
"\n",
- "client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
+ "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
@@ -245,7 +249,7 @@
"import time\n",
"from openai import OpenAI\n",
"\n",
- "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
+ "client = OpenAI(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"requests = [\n",
" {\n",
@@ -348,10 +352,10 @@
"import time\n",
"from openai import OpenAI\n",
"\n",
- "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
+ "client = OpenAI(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"requests = []\n",
- "for i in range(100):\n",
+ "for i in range(20):\n",
" requests.append(\n",
" {\n",
" \"custom_id\": f\"request-{i}\",\n",
@@ -369,7 +373,7 @@
" \"content\": \"Write a detailed story about topic. Make it very long.\",\n",
" },\n",
" ],\n",
- " \"max_tokens\": 500,\n",
+ " \"max_tokens\": 64,\n",
" },\n",
" }\n",
" )\n",
@@ -425,10 +429,10 @@
"from openai import OpenAI\n",
"import os\n",
"\n",
- "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
+ "client = OpenAI(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"requests = []\n",
- "for i in range(500):\n",
+ "for i in range(5000):\n",
" requests.append(\n",
" {\n",
" \"custom_id\": f\"request-{i}\",\n",
@@ -446,7 +450,7 @@
" \"content\": \"Write a detailed story about topic. Make it very long.\",\n",
" },\n",
" ],\n",
- " \"max_tokens\": 500,\n",
+ " \"max_tokens\": 128,\n",
" },\n",
" }\n",
" )\n",
@@ -508,7 +512,7 @@
"metadata": {},
"outputs": [],
"source": [
- "terminate_process(server_process)"
+ "terminate_process(server_process, port)"
]
}
],
diff --git a/docs/backend/openai_api_embeddings.ipynb b/docs/backend/openai_api_embeddings.ipynb
index 67ce68bcf..5a86ef18e 100644
--- a/docs/backend/openai_api_embeddings.ipynb
+++ b/docs/backend/openai_api_embeddings.ipynb
@@ -29,21 +29,23 @@
"metadata": {},
"outputs": [],
"source": [
- "from sglang.utils import (\n",
- " execute_shell_command,\n",
- " wait_for_server,\n",
- " terminate_process,\n",
- " print_highlight,\n",
- ")\n",
+ "from sglang.test.test_utils import is_in_ci\n",
"\n",
- "embedding_process = execute_shell_command(\n",
+ "if is_in_ci():\n",
+ " from patch import launch_server_cmd\n",
+ "else:\n",
+ " from sglang.utils import launch_server_cmd\n",
+ "\n",
+ "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
+ "\n",
+ "embedding_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \\\n",
- " --port 30000 --host 0.0.0.0 --is-embedding\n",
+ " --host 0.0.0.0 --is-embedding\n",
"\"\"\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30000\")"
+ "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
@@ -63,7 +65,7 @@
"\n",
"text = \"Once upon a time\"\n",
"\n",
- "curl_text = f\"\"\"curl -s http://localhost:30000/v1/embeddings \\\n",
+ "curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n",
"\n",
"text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n",
@@ -91,7 +93,7 @@
"text = \"Once upon a time\"\n",
"\n",
"response = requests.post(\n",
- " \"http://localhost:30000/v1/embeddings\",\n",
+ " f\"http://localhost:{port}/v1/embeddings\",\n",
" json={\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"input\": text},\n",
")\n",
"\n",
@@ -115,7 +117,7 @@
"source": [
"import openai\n",
"\n",
- "client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n",
+ "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"# Text embedding example\n",
"response = client.embeddings.create(\n",
@@ -151,7 +153,7 @@
"tokenizer = AutoTokenizer.from_pretrained(\"Alibaba-NLP/gte-Qwen2-7B-instruct\")\n",
"input_ids = tokenizer.encode(text)\n",
"\n",
- "curl_ids = f\"\"\"curl -s http://localhost:30000/v1/embeddings \\\n",
+ "curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n",
"\n",
"input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n",
@@ -167,7 +169,7 @@
"metadata": {},
"outputs": [],
"source": [
- "terminate_process(embedding_process)"
+ "terminate_process(embedding_process, port)"
]
}
],
diff --git a/docs/backend/openai_api_vision.ipynb b/docs/backend/openai_api_vision.ipynb
index da8864c24..2e8fdb5af 100644
--- a/docs/backend/openai_api_vision.ipynb
+++ b/docs/backend/openai_api_vision.ipynb
@@ -34,21 +34,23 @@
"metadata": {},
"outputs": [],
"source": [
- "from sglang.utils import (\n",
- " execute_shell_command,\n",
- " wait_for_server,\n",
- " terminate_process,\n",
- " print_highlight,\n",
- ")\n",
+ "from sglang.test.test_utils import is_in_ci\n",
"\n",
- "embedding_process = execute_shell_command(\n",
+ "if is_in_ci():\n",
+ " from patch import launch_server_cmd\n",
+ "else:\n",
+ " from sglang.utils import launch_server_cmd\n",
+ "\n",
+ "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
+ "\n",
+ "embedding_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \\\n",
- " --port=30000 --chat-template=llama_3_vision\n",
+ " --chat-template=llama_3_vision\n",
"\"\"\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30000\")"
+ "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
@@ -68,32 +70,36 @@
"source": [
"import subprocess\n",
"\n",
- "curl_command = \"\"\"\n",
- "curl -s http://localhost:30000/v1/chat/completions \\\n",
- " -d '{\n",
+ "curl_command = f\"\"\"\n",
+ "curl -s http://localhost:{port}/v1/chat/completions \\\\\n",
+ " -d '{{\n",
" \"model\": \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n",
" \"messages\": [\n",
- " {\n",
+ " {{\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
- " {\n",
+ " {{\n",
" \"type\": \"text\",\n",
" \"text\": \"What’s in this image?\"\n",
- " },\n",
- " {\n",
+ " }},\n",
+ " {{\n",
" \"type\": \"image_url\",\n",
- " \"image_url\": {\n",
+ " \"image_url\": {{\n",
" \"url\": \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
- " }\n",
- " }\n",
+ " }}\n",
+ " }}\n",
" ]\n",
- " }\n",
+ " }}\n",
" ],\n",
" \"max_tokens\": 300\n",
- " }'\n",
+ " }}'\n",
"\"\"\"\n",
"\n",
"response = subprocess.check_output(curl_command, shell=True).decode()\n",
+ "print_highlight(response)\n",
+ "\n",
+ "\n",
+ "response = subprocess.check_output(curl_command, shell=True).decode()\n",
"print_highlight(response)"
]
},
@@ -112,7 +118,7 @@
"source": [
"import requests\n",
"\n",
- "url = \"http://localhost:30000/v1/chat/completions\"\n",
+ "url = f\"http://localhost:{port}/v1/chat/completions\"\n",
"\n",
"data = {\n",
" \"model\": \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n",
@@ -152,7 +158,7 @@
"source": [
"from openai import OpenAI\n",
"\n",
- "client = OpenAI(base_url=\"http://localhost:30000/v1\", api_key=\"None\")\n",
+ "client = OpenAI(base_url=f\"http://localhost:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n",
@@ -196,7 +202,7 @@
"source": [
"from openai import OpenAI\n",
"\n",
- "client = OpenAI(base_url=\"http://localhost:30000/v1\", api_key=\"None\")\n",
+ "client = OpenAI(base_url=f\"http://localhost:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n",
@@ -236,7 +242,7 @@
"metadata": {},
"outputs": [],
"source": [
- "terminate_process(embedding_process)"
+ "terminate_process(embedding_process, port)"
]
},
{
diff --git a/docs/backend/patch.py b/docs/backend/patch.py
new file mode 100644
index 000000000..1623c1b1f
--- /dev/null
+++ b/docs/backend/patch.py
@@ -0,0 +1,45 @@
+"""CI-only patches that cap server resource usage so several documentation
+notebooks can run concurrently on a single machine."""
+
+from typing import Optional
+
+import sglang.srt.server_args as server_args_mod
+from sglang.utils import execute_shell_command, reserve_port
+
+DEFAULT_MAX_RUNNING_REQUESTS = 200
+DEFAULT_MAX_TOTAL_TOKENS = 20480
+
+_original_post_init = server_args_mod.ServerArgs.__post_init__
+
+
+def patched_post_init(self):
+    # Run the original validation first, then tighten the CI defaults.
+    _original_post_init(self)
+    if self.max_running_requests is None:
+        self.max_running_requests = DEFAULT_MAX_RUNNING_REQUESTS
+    if self.max_total_tokens is None:
+        self.max_total_tokens = DEFAULT_MAX_TOTAL_TOKENS
+    self.disable_cuda_graph = True
+
+
+# In-process patch: applies to engines created inside this interpreter
+# (e.g. sgl.Engine in the offline notebooks).
+server_args_mod.ServerArgs.__post_init__ = patched_post_init
+
+
+def launch_server_cmd(command: str, host: str = "0.0.0.0", port: Optional[int] = None):
+    """Launch a server subprocess on a reserved port; return (process, port).
+
+    The subprocess does not inherit the monkey-patch above, so the same
+    resource limits are passed explicitly on the command line.
+    """
+    if port is None:
+        port = reserve_port()
+    extra_flags = (
+        f"--max-running-requests {DEFAULT_MAX_RUNNING_REQUESTS} "
+        f"--max-total-tokens {DEFAULT_MAX_TOTAL_TOKENS} "
+        f"--disable-cuda-graph"
+    )
+    full_command = f"{command} --port {port} {extra_flags}"
+    process = execute_shell_command(full_command)
+    return process, port
diff --git a/docs/start/send_request.ipynb b/docs/backend/send_request.ipynb
similarity index 82%
rename from docs/start/send_request.ipynb
rename to docs/backend/send_request.ipynb
index 4cb46f1ed..ea6398da1 100644
--- a/docs/start/send_request.ipynb
+++ b/docs/backend/send_request.ipynb
@@ -22,7 +22,7 @@
"\n",
"```bash\n",
"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
- "--port 30000 --host 0.0.0.0\n",
+ " --host 0.0.0.0\n",
"```\n",
"\n",
"in your terminal and wait for the server to be ready. Once the server is running, you can send test requests using curl or requests. The server implements the [OpenAI-compatible APIs](https://platform.openai.com/docs/api-reference/chat)."
@@ -34,21 +34,23 @@
"metadata": {},
"outputs": [],
"source": [
- "from sglang.utils import (\n",
- " execute_shell_command,\n",
- " wait_for_server,\n",
- " terminate_process,\n",
- " print_highlight,\n",
- ")\n",
+ "from sglang.test.test_utils import is_in_ci\n",
+ "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
- "server_process = execute_shell_command(\n",
+ "if is_in_ci():\n",
+ " from patch import launch_server_cmd\n",
+ "else:\n",
+ " from sglang.utils import launch_server_cmd\n",
+ "\n",
+ "\n",
+ "server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
- "--port 30000 --host 0.0.0.0\n",
+ " --host 0.0.0.0\n",
"\"\"\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30000\")"
+ "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
@@ -66,9 +68,10 @@
"source": [
"import subprocess, json\n",
"\n",
- "curl_command = \"\"\"\n",
- "curl -s http://localhost:30000/v1/chat/completions \\\n",
- " -d '{\"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\", \"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}]}'\n",
+ "curl_command = f\"\"\"\n",
+ "curl -s http://localhost:{port}/v1/chat/completions \\\n",
+ " -H \"Content-Type: application/json\" \\\n",
+ " -d '{{\"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\", \"messages\": [{{\"role\": \"user\", \"content\": \"What is the capital of France?\"}}]}}'\n",
"\"\"\"\n",
"\n",
"response = json.loads(subprocess.check_output(curl_command, shell=True))\n",
@@ -90,7 +93,7 @@
"source": [
"import requests\n",
"\n",
- "url = \"http://localhost:30000/v1/chat/completions\"\n",
+ "url = f\"http://localhost:{port}/v1/chat/completions\"\n",
"\n",
"data = {\n",
" \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
@@ -116,7 +119,7 @@
"source": [
"import openai\n",
"\n",
- "client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n",
+ "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
@@ -144,7 +147,7 @@
"source": [
"import openai\n",
"\n",
- "client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n",
+ "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"# Use stream=True for streaming responses\n",
"response = client.chat.completions.create(\n",
@@ -181,7 +184,7 @@
"import requests\n",
"\n",
"response = requests.post(\n",
- " \"http://localhost:30000/generate\",\n",
+ " f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": \"The capital of France is\",\n",
" \"sampling_params\": {\n",
@@ -210,7 +213,7 @@
"import requests, json\n",
"\n",
"response = requests.post(\n",
- " \"http://localhost:30000/generate\",\n",
+ " f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": \"The capital of France is\",\n",
" \"sampling_params\": {\n",
@@ -240,8 +243,15 @@
"metadata": {},
"outputs": [],
"source": [
- "terminate_process(server_process)"
+ "terminate_process(server_process, port)"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb
index fb748901b..c7bd7fb31 100644
--- a/docs/backend/speculative_decoding.ipynb
+++ b/docs/backend/speculative_decoding.ipynb
@@ -35,23 +35,24 @@
"metadata": {},
"outputs": [],
"source": [
- "# EAGLE decoding\n",
- "from sglang.utils import (\n",
- " execute_shell_command,\n",
- " wait_for_server,\n",
- " terminate_process,\n",
- " print_highlight,\n",
- ")\n",
+ "from sglang.test.test_utils import is_in_ci\n",
"\n",
- "server_process = execute_shell_command(\n",
+ "if is_in_ci():\n",
+ " from patch import launch_server_cmd\n",
+ "else:\n",
+ " from sglang.utils import launch_server_cmd\n",
+ "\n",
+ "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
+ "\n",
+ "server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
- " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 --port=30020 --cuda-graph-max-bs 32\n",
+ " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n",
"\"\"\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30020\")"
+ "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
@@ -62,7 +63,7 @@
"source": [
"import openai\n",
"\n",
- "client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
+ "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
@@ -100,25 +101,16 @@
"metadata": {},
"outputs": [],
"source": [
- "server_process = execute_shell_command(\n",
+ "server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
- " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 \\\n",
- " --enable-torch-compile --cuda-graph-max-bs 2 --port=30020\n",
+ " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n",
+ " --enable-torch-compile --cuda-graph-max-bs 2\n",
"\"\"\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30020\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Benchmark Script\n",
- "\n",
- "The following code example shows how to measure the decoding speed when generating tokens:\n"
+ "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
@@ -127,27 +119,20 @@
"metadata": {},
"outputs": [],
"source": [
- "import time\n",
- "import requests\n",
+ "import openai\n",
"\n",
- "tic = time.time()\n",
- "response = requests.post(\n",
- " \"http://localhost:30020/generate\",\n",
- " json={\n",
- " \"text\": \"[INST] Give me a simple FastAPI server. Show the python code. [/INST]\",\n",
- " \"sampling_params\": {\n",
- " \"temperature\": 0,\n",
- " \"max_new_tokens\": 256,\n",
- " },\n",
- " },\n",
+ "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
+ "\n",
+ "response = client.chat.completions.create(\n",
+ " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
+ " messages=[\n",
+ " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
+ " ],\n",
+ " temperature=0,\n",
+ " max_tokens=64,\n",
")\n",
- "latency = time.time() - tic\n",
- "ret = response.json()\n",
- "completion_text = ret[\"text\"]\n",
- "speed = ret[\"meta_info\"][\"completion_tokens\"] / latency\n",
"\n",
- "print_highlight(completion_text)\n",
- "print_highlight(f\"speed: {speed:.2f} token/s\")"
+ "print_highlight(f\"Response: {response}\")"
]
},
{
diff --git a/docs/backend/structured_outputs.ipynb b/docs/backend/structured_outputs.ipynb
index e413743cc..31d59adf9 100644
--- a/docs/backend/structured_outputs.ipynb
+++ b/docs/backend/structured_outputs.ipynb
@@ -38,24 +38,26 @@
"metadata": {},
"outputs": [],
"source": [
- "from sglang.utils import (\n",
- " execute_shell_command,\n",
- " wait_for_server,\n",
- " terminate_process,\n",
- " print_highlight,\n",
- ")\n",
"import openai\n",
"import os\n",
+ "from sglang.test.test_utils import is_in_ci\n",
+ "\n",
+ "if is_in_ci():\n",
+ " from patch import launch_server_cmd\n",
+ "else:\n",
+ " from sglang.utils import launch_server_cmd\n",
+ "\n",
+ "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n",
"\n",
- "server_process = execute_shell_command(\n",
- " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --grammar-backend xgrammar\"\n",
+ "server_process, port = launch_server_cmd(\n",
+ " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0 --grammar-backend xgrammar\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30000\")\n",
- "client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")"
+ "wait_for_server(f\"http://localhost:{port}\")\n",
+ "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")"
]
},
{
@@ -264,7 +266,7 @@
"\n",
"# Make API request\n",
"response = requests.post(\n",
- " \"http://localhost:30000/generate\",\n",
+ " f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n",
" \"sampling_params\": {\n",
@@ -309,7 +311,7 @@
"\n",
"# JSON\n",
"response = requests.post(\n",
- " \"http://localhost:30000/generate\",\n",
+ " f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n",
" \"sampling_params\": {\n",
@@ -339,7 +341,7 @@
"import requests\n",
"\n",
"response = requests.post(\n",
- " \"http://localhost:30000/generate\",\n",
+ " f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": \"Give me the information of the capital of France.\",\n",
" \"sampling_params\": {\n",
@@ -376,7 +378,7 @@
"outputs": [],
"source": [
"response = requests.post(\n",
- " \"http://localhost:30000/generate\",\n",
+ " f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": \"Paris is the capital of\",\n",
" \"sampling_params\": {\n",
@@ -395,7 +397,7 @@
"metadata": {},
"outputs": [],
"source": [
- "terminate_process(server_process)"
+ "terminate_process(server_process, port)"
]
},
{
diff --git a/docs/index.rst b/docs/index.rst
index 58f9731cb..59c317c20 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -16,13 +16,12 @@ The core features include:
:caption: Getting Started
start/install.md
- start/send_request.ipynb
-
.. toctree::
:maxdepth: 1
:caption: Backend Tutorial
+ backend/send_request.ipynb
backend/openai_api_completions.ipynb
backend/openai_api_vision.ipynb
backend/openai_api_embeddings.ipynb
@@ -33,7 +32,6 @@ The core features include:
backend/function_calling.ipynb
backend/server_arguments.md
-
.. toctree::
:maxdepth: 1
:caption: Frontend Tutorial
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 399427ef3..c747eb5b2 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -306,22 +306,94 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
return filename
+import fcntl
+
+LOCKFILE = "/tmp/port_lock"
+PORT_REGISTRY = "/tmp/port_registry.json"
+
+
+def print_highlight(html_content: str):
+    html_content = str(html_content).replace("\n", "<br>")
+    display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
+
+
+def init_port_registry():
+ """Initialize the port registry file if it doesn't exist."""
+ if not os.path.exists(PORT_REGISTRY):
+ with open(PORT_REGISTRY, "w") as f:
+ json.dump([], f)
+
+
+def reserve_port(start=30000, end=40000):
+ """
+ Reserve an available port using a file lock and a registry.
+ Returns the allocated port.
+ """
+ init_port_registry()
+ with open(LOCKFILE, "w") as lock:
+ fcntl.flock(lock, fcntl.LOCK_EX)
+ try:
+ with open(PORT_REGISTRY, "r") as f:
+ used = json.load(f)
+ except Exception:
+ used = []
+ for port in range(start, end):
+ if port not in used:
+ used.append(port)
+ with open(PORT_REGISTRY, "w") as f:
+ json.dump(used, f)
+ return port
+ raise RuntimeError("No free port available")
+
+
+def release_port(port):
+ """Release the reserved port by removing it from the registry."""
+ with open(LOCKFILE, "w") as lock:
+ fcntl.flock(lock, fcntl.LOCK_EX)
+ try:
+ with open(PORT_REGISTRY, "r") as f:
+ used = json.load(f)
+ except Exception:
+ used = []
+ if port in used:
+ used.remove(port)
+ with open(PORT_REGISTRY, "w") as f:
+ json.dump(used, f)
+
+
def execute_shell_command(command: str) -> subprocess.Popen:
"""
- Execute a shell command and return the process handle
-
- Args:
- command: Shell command as a string (can include \\ line continuations)
- Returns:
- subprocess.Popen: Process handle
+ Execute a shell command and return its process handle.
"""
- # Replace \ newline with space and split
+ # Replace newline continuations and split the command string.
command = command.replace("\\\n", " ").replace("\\", " ")
parts = command.split()
-
return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT)
+def launch_server_cmd(command: str, host: str = "0.0.0.0", port: int = None):
+ """
+ Launch the server using the given command.
+ If no port is specified, a free port is reserved.
+ """
+ if port is None:
+ port = reserve_port()
+ full_command = f"{command} --port {port}"
+ process = execute_shell_command(full_command)
+ return process, port
+
+
+def terminate_process(process, port=None):
+ """
+ Terminate the process and, if a port was reserved, release it.
+ """
+ from sglang.srt.utils import kill_process_tree
+
+ kill_process_tree(process.pid)
+ if port is not None:
+ release_port(port)
+
+
def wait_for_server(base_url: str, timeout: int = None) -> None:
"""Wait for the server to be ready by polling the /v1/models endpoint.
@@ -343,6 +415,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
+ We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
"""
)
break
@@ -353,17 +426,6 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
time.sleep(1)
-def terminate_process(process):
- from sglang.srt.utils import kill_process_tree
-
- kill_process_tree(process.pid)
-
-
-def print_highlight(html_content: str):
-    html_content = str(html_content).replace("\n", "<br>")
-    display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
-
-
class TypeBasedDispatcher:
def __init__(self, mapping: List[Tuple[Type, Callable]]):
self._mapping = mapping
diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh
index 7cd6acd74..38440c848 100755
--- a/scripts/ci_install_dependency.sh
+++ b/scripts/ci_install_dependency.sh
@@ -3,6 +3,7 @@ set -euxo pipefail
# Install the dependency in CI.
+
# Use repo from environment variable, passed from GitHub Actions
FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
diff --git a/scripts/export_deepseek_nextn.py b/scripts/export_deepseek_nextn.py
index 3e72fee4f..ad6bb0406 100644
--- a/scripts/export_deepseek_nextn.py
+++ b/scripts/export_deepseek_nextn.py
@@ -15,7 +15,7 @@ from safetensors.torch import save_file
from transformers import AutoConfig
-def get_nexn_layer_id(config):
+def get_nextn_layer_id(config):
if not hasattr(config, "num_hidden_layers"):
raise ValueError("'num_hidden_layers' not found in model config.")
return config.num_hidden_layers
@@ -25,7 +25,7 @@ def update_and_save_config(config, output_dir):
new_config = config.to_dict()
new_config.update(
{
- "num_hidden_layers": 0,
+ "num_hidden_layers": 1,
"architectures": ["DeepseekV3ForCausalLMNextN"],
}
)
@@ -42,8 +42,8 @@ def copy_non_safetensors_files(input_dir, output_dir):
print(f"All non-safetensors files have been copied to {output_dir}")
-def export_nextn_layer_parameters(input_dir, output_dir, nexn_layer_id):
- prefix = f"model.layers.{nexn_layer_id}"
+def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
+ prefix = f"model.layers.{nextn_layer_id}"
output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors")
params = {}
for filename in os.listdir(input_dir):
@@ -106,7 +106,7 @@ if __name__ == "__main__":
config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True)
assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported."
- nextn_layer_id = get_nexn_layer_id(config)
+ nextn_layer_id = get_nextn_layer_id(config)
os.makedirs(args.output_dir, exist_ok=True)
copy_non_safetensors_files(args.input_dir, args.output_dir)
update_and_save_config(config, args.output_dir)