diff --git a/.github/workflows/execute-notebook.yml b/.github/workflows/execute-notebook.yml
index e03edd6ce..49d649797 100644
--- a/.github/workflows/execute-notebook.yml
+++ b/.github/workflows/execute-notebook.yml
@@ -42,7 +42,7 @@ jobs:
           python -m ipykernel install --user --name python3 --display-name "Python 3"
 
       - name: Execute notebooks
-        timeout-minutes: 30
+        timeout-minutes: 40
        run: |
          cd docs
          make clean
diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb
index 3de80aadf..05e7108e6 100644
--- a/docs/backend/function_calling.ipynb
+++ b/docs/backend/function_calling.ipynb
@@ -507,7 +507,15 @@
  ],
  "metadata": {
   "language_info": {
-   "name": "python"
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
   }
  },
  "nbformat": 4,
diff --git a/docs/backend/offline_engine_api.ipynb b/docs/backend/offline_engine_api.ipynb
index 7ce89d435..58d24ac3f 100644
--- a/docs/backend/offline_engine_api.ipynb
+++ b/docs/backend/offline_engine_api.ipynb
@@ -37,7 +37,7 @@
    "outputs": [],
    "source": [
     "# launch the offline engine\n",
-    "\n",
+    "from sglang.utils import stream_and_merge, async_stream_and_merge\n",
     "import sglang as sgl\n",
     "import asyncio\n",
     "\n",
@@ -86,20 +86,22 @@
    "outputs": [],
    "source": [
     "prompts = [\n",
-    "    \"Hello, my name is\",\n",
-    "    \"The capital of France is\",\n",
-    "    \"The future of AI is\",\n",
+    "    \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n",
+    "    \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n",
+    "    \"Explain possible future trends in artificial intelligence. The future of AI is\",\n",
     "]\n",
-    "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n",
     "\n",
-    "print(\"\\n=== Testing synchronous streaming generation ===\")\n",
+    "sampling_params = {\n",
+    "    \"temperature\": 0.2,\n",
+    "    \"top_p\": 0.9,\n",
+    "}\n",
+    "\n",
+    "print(\"\\n=== Testing synchronous streaming generation with overlap removal ===\\n\")\n",
     "\n",
     "for prompt in prompts:\n",
-    "    print(f\"\\nPrompt: {prompt}\")\n",
-    "    print(\"Generated text: \", end=\"\", flush=True)\n",
-    "\n",
-    "    for chunk in llm.generate(prompt, sampling_params, stream=True):\n",
-    "        print(chunk[\"text\"], end=\"\", flush=True)\n",
+    "    print(f\"Prompt: {prompt}\")\n",
+    "    merged_output = stream_and_merge(llm, prompt, sampling_params)\n",
+    "    print(\"Generated text:\", merged_output)\n",
     "    print()"
    ]
   },
@@ -117,9 +119,9 @@
    "outputs": [],
    "source": [
     "prompts = [\n",
-    "    \"Hello, my name is\",\n",
-    "    \"The capital of France is\",\n",
-    "    \"The future of AI is\",\n",
+    "    \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n",
+    "    \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n",
+    "    \"Explain possible future trends in artificial intelligence. The future of AI is\",\n",
     "]\n",
     "\n",
     "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n",
@@ -152,13 +154,14 @@
    "outputs": [],
    "source": [
     "prompts = [\n",
-    "    \"Hello, my name is\",\n",
-    "    \"The capital of France is\",\n",
-    "    \"The future of AI is\",\n",
+    "    \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n",
+    "    \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n",
+    "    \"Explain possible future trends in artificial intelligence. The future of AI is\",\n",
     "]\n",
+    "\n",
     "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n",
     "\n",
-    "print(\"\\n=== Testing asynchronous streaming generation ===\")\n",
+    "print(\"\\n=== Testing asynchronous streaming generation (no repeats) ===\")\n",
     "\n",
     "\n",
     "async def main():\n",
@@ -166,10 +169,11 @@
     "        print(f\"\\nPrompt: {prompt}\")\n",
     "        print(\"Generated text: \", end=\"\", flush=True)\n",
     "\n",
-    "        generator = await llm.async_generate(prompt, sampling_params, stream=True)\n",
-    "        async for chunk in generator:\n",
-    "            print(chunk[\"text\"], end=\"\", flush=True)\n",
-    "        print()\n",
+    "        # Replace direct calls to async_generate with our custom overlap-aware version\n",
+    "        async for cleaned_chunk in async_stream_and_merge(llm, prompt, sampling_params):\n",
+    "            print(cleaned_chunk, end=\"\", flush=True)\n",
+    "\n",
+    "        print()  # New line after each prompt\n",
     "\n",
     "\n",
     "asyncio.run(main())"
diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb
index 391050a0d..d69436eed 100644
--- a/docs/backend/speculative_decoding.ipynb
+++ b/docs/backend/speculative_decoding.ipynb
@@ -8,12 +8,17 @@
     "\n",
     "SGLang now provides an EAGLE-based speculative decoding option. The implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n",
     "\n",
+    "To run the following tests or benchmarks, you also need to install [**cutex**](https://pypi.org/project/cutex/): \n",
+    "> ```bash\n",
+    "> pip install cutex\n",
+    "> ```\n",
+    "\n",
     "### Performance Highlights\n",
     "\n",
-    "- **Official EAGLE code** ([SafeAILab/EAGLE](https://github.com/SafeAILab/EAGLE)): ~200 tokens/s\n",
-    "- **Standard SGLang Decoding**: ~156 tokens/s\n",
-    "- **EAGLE Decoding in SGLang**: ~297 tokens/s\n",
-    "- **EAGLE Decoding in SGLang (w/ `torch.compile`)**: ~316 tokens/s\n",
+    "- Official EAGLE code ([SafeAILab/EAGLE](https://github.com/SafeAILab/EAGLE)): ~200 tokens/s\n",
+    "- Standard SGLang Decoding: ~156 tokens/s\n",
+    "- EAGLE Decoding in SGLang: ~297 tokens/s\n",
+    "- EAGLE Decoding in SGLang (w/ `torch.compile`): ~316 tokens/s\n",
     "\n",
     "All benchmarks below were run on a single H100."
    ]
diff --git a/docs/start/install.md b/docs/start/install.md
index bd39947a1..90964ac6b 100644
--- a/docs/start/install.md
+++ b/docs/start/install.md
@@ -5,6 +5,7 @@ You can install SGLang using any of the methods below.
 ## Method 1: With pip
 ```
 pip install --upgrade pip
+pip install sgl-kernel --force-reinstall --no-deps
 pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
 ```
 
@@ -17,10 +18,11 @@
 git clone -b v0.4.2 https://github.com/sgl-project/sglang.git
 cd sglang
 
 pip install --upgrade pip
+pip install sgl-kernel --force-reinstall --no-deps
 pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
 ```
 
-Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. If you encounter an error like **ImportError: cannot import name `_grouped_size_compiled_for_decode_kernels`**, installing an older FlashInfer version such as 0.1.6 instead of the latest one may resolve it.
 
 Note: To AMD ROCm system with Instinct/MI GPUs, do following instead:
@@ -30,6 +32,7 @@
 git clone -b v0.4.2 https://github.com/sgl-project/sglang.git
 cd sglang
 
 pip install --upgrade pip
+pip install sgl-kernel --force-reinstall --no-deps
 pip install -e "python[all_hip]"
 ```
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 742eebc3b..399427ef3 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -373,3 +373,45 @@ class TypeBasedDispatcher:
             if isinstance(obj, ty):
                 return fn(obj)
         raise ValueError(f"Invalid object: {obj}")
+
+
+def trim_overlap(existing_text, new_chunk):
+    """
+    Finds the largest suffix of 'existing_text' that is a prefix of 'new_chunk'
+    and removes that overlap from the start of 'new_chunk'.
+    """
+    max_overlap = 0
+    max_possible = min(len(existing_text), len(new_chunk))
+    for i in range(max_possible, 0, -1):
+        if existing_text.endswith(new_chunk[:i]):
+            max_overlap = i
+            break
+    return new_chunk[max_overlap:]
+
+
+def stream_and_merge(llm, prompt, sampling_params):
+    """
+    1) Streams the text,
+    2) Removes chunk overlaps,
+    3) Returns the merged text.
+    """
+    final_text = ""
+    for chunk in llm.generate(prompt, sampling_params, stream=True):
+        chunk_text = chunk["text"]
+        cleaned_chunk = trim_overlap(final_text, chunk_text)
+        final_text += cleaned_chunk
+    return final_text
+
+
+async def async_stream_and_merge(llm, prompt, sampling_params):
+    """
+    Streams tokens asynchronously, removes chunk overlaps,
+    and yields the cleaned chunk in real time for printing.
+    """
+    final_text = ""
+    generator = await llm.async_generate(prompt, sampling_params, stream=True)
+    async for chunk in generator:
+        chunk_text = chunk["text"]
+        cleaned_chunk = trim_overlap(final_text, chunk_text)
+        final_text += cleaned_chunk
+        yield cleaned_chunk  # yield the non-overlapping portion
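Below is a minimal sketch (not part of the patch) of how the overlap-trimming helpers added to `python/sglang/utils.py` are expected to behave. The example strings are made up, and the commented-out engine setup (model path included) is an assumption based on the `offline_engine_api.ipynb` cells above, not something this diff pins down.

```python
# Illustration of trim_overlap: the new chunk re-sends the tail "Par" of the
# text accumulated so far, so only the non-overlapping part "is." is kept.
from sglang.utils import trim_overlap, stream_and_merge

existing = "The capital of France is Par"
chunk = "Paris."
print(trim_overlap(existing, chunk))  # -> "is."

# With an offline engine launched as in the notebook (hypothetical model path):
# llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
# merged = stream_and_merge(llm, "The capital of France is", {"temperature": 0.2})
# print(merged)
```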