diff --git a/docs/advanced_features/deterministic_inference.md b/docs/advanced_features/deterministic_inference.md
new file mode 100644
index 000000000..b5b6b5216
--- /dev/null
+++ b/docs/advanced_features/deterministic_inference.md
@@ -0,0 +1,154 @@
+# Deterministic Inference
+
+## Why Deterministic Inference Matters
+
+Deterministic inference ensures consistent LLM outputs across runs, which is critical for:
+- **Reinforcement Learning**: Ensures consistent logprobs across runs, reducing stochastic noise and making RL training more stable, reproducible, and debuggable.
+- **Testing & Debugging**: Enables reproducible validation, so failures can be replayed and debugged reliably.
+- **Production**: Provides consistent behavior, improving reliability and user experience.
+
+Even with `temperature=0`, standard LLM inference can produce different outputs across runs due to dynamic batching and varying reduction orders in GPU kernels.
+
+## The Root Cause of Non-Determinism
+
+The main source is **varying batch sizes**. Different batch sizes cause GPU kernels to split reduction operations differently, leading to different addition orders. Because floating-point addition is not associative (`(a + b) + c ≠ a + (b + c)`), this produces different results even for identical inputs.
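+
+The effect is easy to reproduce outside SGLang. The following sketch (not part of SGLang; it assumes PyTorch and a CUDA GPU are available) multiplies the same row by the same weight matrix twice, once inside a large batch and once as a batch of one, and checks whether the results match bitwise:
+
+```python
+import torch
+
+torch.manual_seed(0)
+x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)
+w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
+
+# Row 0 computed as part of a 2048-row batch vs. as a batch of one.
+# Different shapes may select different kernels and reduction splits,
+# so the two results are not guaranteed to be bitwise identical.
+out_batched = torch.mm(x, w)[0]
+out_single = torch.mm(x[:1], w)[0]
+
+print(torch.equal(out_batched, out_single))  # often False without batch-invariant kernels
+```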
+
+## SGLang's Solution
+
+Building on [Thinking Machines Lab's batch-invariant operators](https://github.com/thinking-machines-lab/batch_invariant_ops), SGLang achieves fully deterministic inference while maintaining compatibility with chunked prefill, CUDA graphs, radix cache, and non-greedy sampling. The development roadmap for deterministic inference features can be found in this [issue](https://github.com/sgl-project/sglang/issues/10278).
+
+### Supported Backends
+
+Deterministic inference is currently supported only with the following three attention backends: **FlashInfer**, **FlashAttention 3 (FA3)**, and **Triton**.
+
+The following table shows feature compatibility for deterministic inference across different attention backends:
+
+| Attention Backend | CUDA Graph | Chunked Prefill | Radix Cache | Non-greedy Sampling (Temp > 0) |
+|-------------------|------------|-----------------|-------------|--------------------------------|
+| **FlashInfer** | ✅ Yes | ✅ Yes | ❌ No | ✅ Yes |
+| **FlashAttention 3 (FA3)** | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
+| **Triton** | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
+
+## Usage
+
+### Basic Usage
+
+Enable deterministic inference by adding the `--enable-deterministic-inference` flag:
+
+```bash
+python3 -m sglang.launch_server \
+    --model-path Qwen/Qwen3-8B \
+    --attention-backend fa3 \
+    --enable-deterministic-inference
+```
+
+### Server Arguments
+
+| Argument | Type/Default | Description |
+|----------|--------------|-------------|
+| `--enable-deterministic-inference` | flag; default: disabled | Enable deterministic inference with batch-invariant operations |
+| `--attention-backend` | string; default: fa3 | Choose the attention backend (`flashinfer`, `fa3`, or `triton`) |
+
+### Example Configurations
+
+#### Qwen3-8B
+```bash
+python3 -m sglang.launch_server \
+    --model-path Qwen/Qwen3-8B \
+    --attention-backend flashinfer \
+    --enable-deterministic-inference
+```
+
+#### Llama Models
+```bash
+python3 -m sglang.launch_server \
+    --model-path meta-llama/Llama-3.1-8B-Instruct \
+    --attention-backend fa3 \
+    --enable-deterministic-inference
+```
+
+#### Qwen3-30B-A3B (MoE Model)
+```bash
+python3 -m sglang.launch_server \
+    --model-path Qwen/Qwen3-30B-A3B \
+    --attention-backend fa3 \
+    --enable-deterministic-inference
+```
+
+### Deterministic Inference with Non-Greedy Sampling (Temperature > 0)
+
+SGLang supports deterministic inference even with non-greedy sampling by using sampling seeds. This is particularly useful for reinforcement learning scenarios such as GRPO (Group Relative Policy Optimization), where you need multiple diverse but reproducible responses.
+
+#### Default Behavior
+
+By default, SGLang uses a sampling seed of `42` for reproducible sampling:
+
+```python
+import requests
+
+response = requests.post(
+    "http://localhost:30000/generate",
+    json={
+        "text": "Tell me a joke",
+        "sampling_params": {
+            "temperature": 0.8,  # Non-greedy sampling
+            "max_new_tokens": 128,
+        },
+    },
+)
+print(response.json())
+# This will always produce the same response across runs
+```
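+
+A quick way to confirm this end-to-end is to send the identical request several times and compare the generated text. This is only a sketch: it assumes the server launched above is listening on `localhost:30000` and that the generated string is returned in the `text` field of the response.
+
+```python
+import requests
+
+payload = {
+    "text": "Tell me a joke",
+    "sampling_params": {"temperature": 0.8, "max_new_tokens": 128},
+}
+
+# With deterministic inference enabled, repeated identical requests should
+# yield identical generations even at temperature > 0, since the default
+# sampling seed is used for every request.
+texts = {
+    requests.post("http://localhost:30000/generate", json=payload).json()["text"]
+    for _ in range(5)
+}
+print(len(texts))  # expected: 1
+```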
+
+#### Generating Multiple Reproducible Responses
+
+To sample different responses from the same prompt while maintaining reproducibility (e.g., for GRPO training), provide different sampling seeds in your requests:
+
+```python
+import requests
+
+# Prepare a list of sampling seeds for different responses
+sampling_seeds = [42, 43, 44, 45, 46]
+
+responses = []
+for seed in sampling_seeds:
+    response = requests.post(
+        "http://localhost:30000/generate",
+        json={
+            "text": "Tell me a joke",
+            "sampling_params": {
+                "temperature": 0.8,
+                "max_new_tokens": 128,
+                "sampling_seed": seed,  # Specify sampling seed
+            },
+        },
+    )
+    responses.append(response.json())
+
+# Each seed will produce a different but reproducible response
+# Using the same seed will always produce the same response
+```
+
+This approach ensures that:
+- Different seeds produce diverse responses
+- The same seed always produces the same response across different runs
+- Results are reproducible for debugging and evaluation
+
+## Verification
+
+Run deterministic tests to verify consistent outputs:
+
+```bash
+# Single test: same prompt, varying batch sizes
+python3 -m sglang.test.test_deterministic --test-mode single --n-trials 50
+
+# Prefix test: prompts with different prefix lengths
+python3 -m sglang.test.test_deterministic --test-mode prefix --n-trials 50
+
+# Radix Cache Consistency mode: test radix cache determinism (cached vs uncached prefill)
+python3 -m sglang.test.test_deterministic --test-mode radix_cache
+```
+
+Expected result: All tests should show `Unique samples: 1` (perfectly deterministic).
diff --git a/docs/index.rst b/docs/index.rst
index 293f75984..4eb596394 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -50,6 +50,7 @@ Its core features include:
    advanced_features/pd_multiplexing.md
    advanced_features/vlm_query.ipynb
    advanced_features/router.md
+   advanced_features/deterministic_inference.md
    advanced_features/observability.md
 
 .. toctree::