From e2ac7888b8cb1fd6c33a7ec58d27a5f5b5b24e0c Mon Sep 17 00:00:00 2001
From: Qiaolin Yu
Date: Sun, 21 Sep 2025 19:36:08 -0700
Subject: [PATCH] [2/2] Support deterministic inference for temperature > 0 (#10678)

Co-authored-by: Baizhou Zhang
Co-authored-by: hebiao064
---
 python/sglang/srt/layers/sampler.py            | 56 +++++++++++++++++--
 python/sglang/srt/managers/schedule_batch.py   |  2 +-
 python/sglang/srt/managers/tp_worker.py        |  4 +-
 .../sglang/srt/model_executor/model_runner.py  |  7 ++-
 .../srt/sampling/sampling_batch_info.py        | 21 ++++++-
 python/sglang/srt/sampling/sampling_params.py  | 12 ++++
 python/sglang/srt/server_args.py               |  6 ++
 python/sglang/test/test_deterministic.py       |  7 +++
 sgl-router/benches/request_processing.rs       |  1 +
 sgl-router/src/protocols/spec.rs               | 10 ++++
 sgl-router/src/routers/http/openai_router.rs   |  1 +
 sgl-router/tests/test_openai_routing.rs        |  1 +
 12 files changed, 117 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 03972c58b..47e66506d 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -65,6 +65,7 @@ class Sampler(nn.Module):
         return_logprob: bool,
         top_logprobs_nums: List[int],
         token_ids_logprobs: List[List[int]],
+        positions: torch.Tensor,
     ):
         """Run a sampler & compute logprobs and update logits_output accordingly.
@@ -77,6 +78,8 @@
             batch_next_token_ids: next token IDs. If set, skip sampling and only
                 compute output logprobs
                 It is used for speculative decoding which performs sampling in draft workers.
+            positions: The positions of the tokens in the sequence. Used in deterministic
+                sampling to derive a unique seed for each position.
         """
 
         logits = logits_output.next_token_logits
@@ -124,6 +127,8 @@
                     sampling_info.top_ps,
                     sampling_info.min_ps,
                     sampling_info.need_min_p_sampling,
+                    sampling_info.sampling_seed,
+                    positions,
                 )
             else:
                 raise ValueError(
@@ -189,6 +194,7 @@
         Optimized for prefill-only scoring requests that need token
         probabilities but don't require next token generation.
         """
+
         if logits_output.next_token_logits is None:
             logger.warning("No logits available for logprob computation")
             return
@@ -230,8 +236,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
     need_min_p_sampling: bool,
+    sampling_seed: Optional[torch.Tensor],
+    positions: torch.Tensor,
 ):
-    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
+    """
+    A top-k, top-p and min-p sampling implementation with native pytorch operations.
+    When sampling_seed is not None, deterministic sampling is enabled and each
+    request is sampled with its own sampling_seed.
+    """
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
     probs_sort[
@@ -243,14 +255,50 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     if need_min_p_sampling:
         min_p_thresholds = probs_sort[:, 0] * min_ps
         probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-
-    sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    if sampling_seed is not None:
+        sampled_index = multinomial_with_seed(probs_sort, sampling_seed, positions)
+    else:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
 
     # int32 range is enough to represent the token ids
     probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
 
 
+def multinomial_with_seed(
+    inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor
+) -> torch.Tensor:
+    """
+    Samples n elements from an input tensor `inputs` of shape (n, m) using
+    a unique random seed for each row. This is a deterministic batched alternative to
+    `torch.multinomial`.
+
+    Args:
+        inputs: A float tensor of shape (n, m) representing n categorical
+            distributions with m categories each. The values are treated
+            as weights and do not need to sum to 1.
+        seed: An integer tensor of shape (n,) containing the random seed
+            for each corresponding row in `inputs`.
+        positions: The positions of the tokens in the sequence. Used in deterministic
+            sampling to derive a unique seed for each position.
+
+    Returns:
+        A tensor of shape (n,) where the i-th element is an index sampled
+        from the distribution in `inputs[i]` using `seed[i]`.
+    """
+    n, m = inputs.shape
+    col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
+    step_seed = seed * 19349663 ^ positions * 73856093
+    seed_expanded = step_seed.unsqueeze(-1)
+    hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
+    uniform_samples = (hashed % (2**24)).float() / (2**24)
+    epsilon = 1e-9
+    gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
+    log_probs = torch.log(inputs + epsilon)
+    perturbed_log_probs = log_probs + gumbel_noise
+    return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
+
+
 def sampling_from_probs_torch(probs: torch.Tensor):
     """A sampling implementation with native pytorch operations, without
     top-k, top-p, or min-p filtering."""
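
As a minimal standalone sketch of the technique used by `multinomial_with_seed`
(the helper name and the toy test harness below are illustrative, not part of the
patch): instead of consuming stateful RNG, the sampler hashes the triple
(seed, position, category) into a deterministic uniform sample and then applies
the Gumbel-max trick, so the same request seed at the same position always yields
the same token:

    import torch

    def seeded_multinomial(probs, seed, positions):
        n, m = probs.shape
        cols = torch.arange(m).unsqueeze(0)                               # (1, m) category ids
        step_seed = seed * 19349663 ^ positions * 73856093                # (n,) per-step seed
        hashed = step_seed.unsqueeze(-1) * 8589934591 ^ cols * 479001599  # (n, m) hash grid
        uniform = (hashed % (2**24)).float() / (2**24)                    # deterministic U[0, 1)
        eps = 1e-9
        gumbel = -torch.log(-torch.log(uniform + eps) + eps)              # Gumbel-max trick
        return torch.argmax(torch.log(probs + eps) + gumbel, dim=1)

    probs = torch.tensor([[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]])
    seed = torch.tensor([42, 42])
    positions = torch.tensor([7, 8])
    assert torch.equal(
        seeded_multinomial(probs, seed, positions),
        seeded_multinomial(probs, seed, positions),
    )  # no RNG state is read or advanced, so results cannot drift across runs
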
+ """ probs_sort, probs_idx = probs.sort(dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) probs_sort[ @@ -243,14 +255,50 @@ def top_k_top_p_min_p_sampling_from_probs_torch( if need_min_p_sampling: min_p_thresholds = probs_sort[:, 0] * min_ps probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 - - sampled_index = torch.multinomial(probs_sort, num_samples=1) + if sampling_seed is not None: + sampled_index = multinomial_with_seed(probs_sort, sampling_seed, positions) + else: + sampled_index = torch.multinomial(probs_sort, num_samples=1) # int32 range is enough to represent the token ids probs_idx = probs_idx.to(torch.int32) batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1) return batch_next_token_ids +def multinomial_with_seed( + inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor +) -> torch.Tensor: + """ + Samples n elements from an input tensor `inputs` of shape (n, m) using + a unique random seed for each row. This is a deterministic batched alternative to + `torch.multinomial`. + + Args: + inputs: A float tensor of shape (n, m) representing n categorical + distributions with m categories each. The values are treated + as weights and do not need to sum to 1. + seed: An integer tensor of shape (n,) containing the random seed + for each corresponding row in `inputs`. + positions: The positions of the tokens in the sequence. Used for deterministic sampling + to get the unique seed for each position. + + Returns: + A tensor of shape (n,) where the i-th element is an index sampled + from the distribution in `inputs[i]` using `seed[i]`. + """ + n, m = inputs.shape + col_indices = torch.arange(m, device=inputs.device).unsqueeze(0) + step_seed = seed * 19349663 ^ positions * 73856093 + seed_expanded = step_seed.unsqueeze(-1) + hashed = seed_expanded * 8589934591 ^ col_indices * 479001599 + uniform_samples = (hashed % (2**24)).float() / (2**24) + epsilon = 1e-9 + gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon) + log_probs = torch.log(inputs + epsilon) + perturbed_log_probs = log_probs + gumbel_noise + return torch.argmax(perturbed_log_probs, dim=1, keepdim=True) + + def sampling_from_probs_torch(probs: torch.Tensor): """A sampling implementation with native pytorch operations, without top-k, top-p, or min-p filtering.""" diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index d870f969a..f46e160cd 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -67,7 +67,7 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams from sglang.srt.server_args import ServerArgs from sglang.srt.utils import flatten_nested_list, support_triton diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 98bc9a16f..6453b5675 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -270,9 +270,7 @@ class TpModelWorker: logits_output, model_worker_batch ) else: - next_token_ids = self.model_runner.sample( - logits_output, 
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 7fb48a286..8d3e48bc2 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -60,6 +60,9 @@ class SamplingBatchInfo:
         Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
     ] = None
 
+    # Used for deterministic sampling
+    sampling_seed: Optional[torch.Tensor] = None
+
     # Device
     device: str = "cuda"
 
@@ -93,6 +96,15 @@
         min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
         )
+        sampling_seed = (
+            torch.tensor(
+                [r.sampling_params.sampling_seed for r in reqs],
+                dtype=torch.int32,
+                device=device,
+            )
+            if enable_deterministic
+            else None
+        )
 
         logit_bias = None
         if any(r.sampling_params.logit_bias is not None for r in reqs):
@@ -158,6 +170,7 @@
             top_ps=top_ps,
             top_ks=top_ks,
             min_ps=min_ps,
+            sampling_seed=sampling_seed,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
             need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
@@ -239,9 +252,11 @@
             "top_ps",
             "top_ks",
             "min_ps",
+            "sampling_seed",
         ]:
             value = getattr(self, item, None)
-            setattr(self, item, value[keep_indices_device])
+            if value is not None:
+                setattr(self, item, value[keep_indices_device])
 
         if self.logit_bias is not None:
             self.logit_bias = self.logit_bias[keep_indices_device]
@@ -343,10 +358,12 @@
             "top_ps",
             "top_ks",
             "min_ps",
+            "sampling_seed",
         ]:
             self_val = getattr(self, item, None)
             other_val = getattr(other, item, None)
-            setattr(self, item, torch.cat([self_val, other_val]))
+            if self_val is not None and other_val is not None:
+                setattr(self, item, torch.cat([self_val, other_val]))
 
         self.is_all_greedy &= other.is_all_greedy
         self.need_top_p_sampling |= other.need_top_p_sampling
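
The new None guards in filter_batch and merge_batch exist because sampling_seed
is only materialized when deterministic inference is enabled; the other tensors
in those loops always exist. A toy sketch of the two guarded operations (plain
tensors rather than sglang batch objects):

    import torch

    sampling_seed = torch.tensor([42, 43, 44], dtype=torch.int32)  # one seed per request
    keep_indices = torch.tensor([0, 2])
    sampling_seed = sampling_seed[keep_indices]  # filter_batch: keep surviving requests

    other = None  # a merged-in batch running without deterministic sampling
    if sampling_seed is not None and other is not None:
        sampling_seed = torch.cat([sampling_seed, other])  # merge_batch path
    # otherwise the attribute is left untouched, mirroring the guarded setattr
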
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index b7d1a6d6e..c644a9d7e 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -15,8 +15,11 @@
 
 from typing import Any, Dict, List, Optional, Union
 
+from sglang.srt.utils import get_bool_env_var
+
 _SAMPLING_EPS = 1e-6
 TOP_K_ALL = 1 << 30
+DEFAULT_SAMPLING_SEED = 42
 
 
 class SamplingParams:
@@ -53,6 +56,7 @@
         custom_params: Optional[Dict[str, Any]] = None,
         stream_interval: Optional[int] = None,
         logit_bias: Optional[Dict[str, float]] = None,
+        sampling_seed: Optional[int] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop
@@ -80,6 +84,14 @@
         self.custom_params = custom_params
         self.stream_interval = stream_interval
         self.logit_bias = logit_bias
+        # Used for deterministic sampling
+        if (
+            get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE")
+            and sampling_seed is None
+        ):
+            # If deterministic inference is enabled and sampling_seed is not set, use the default seed
+            sampling_seed = DEFAULT_SAMPLING_SEED
+        self.sampling_seed = sampling_seed
 
         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
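
Seed resolution in SamplingParams follows one rule: an explicit per-request
sampling_seed always wins, and DEFAULT_SAMPLING_SEED (42) is filled in only when
the SGLANG_ENABLE_DETERMINISTIC_INFERENCE environment variable is set. A sketch
of that rule with a stand-in for get_bool_env_var (resolve_seed is a toy helper,
not part of the patch):

    import os
    from typing import Optional

    DEFAULT_SAMPLING_SEED = 42

    def resolve_seed(sampling_seed: Optional[int]) -> Optional[int]:
        # stand-in for sglang's get_bool_env_var; exact truthy values may differ
        enabled = os.environ.get(
            "SGLANG_ENABLE_DETERMINISTIC_INFERENCE", "false"
        ).lower() in ("true", "1")
        if enabled and sampling_seed is None:
            return DEFAULT_SAMPLING_SEED
        return sampling_seed

    assert resolve_seed(7) == 7  # explicit seeds are never overridden
    # with the env var set: resolve_seed(None) == 42; without it: None (seeded path off)
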
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index b1e474f77..099b2df8c 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -988,6 +988,12 @@
                 "batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
             )
 
+        # Deterministic inference requires the pytorch sampling backend
+        self.sampling_backend = "pytorch"
+        logger.warning(
+            "Sampling backend is set to pytorch for deterministic inference."
+        )
+
         # Currently, only FA3 supports radix cache. Support for other backends is in progress
         if self.attention_backend != "fa3":
             self.disable_radix_cache = True
             logger.warning(
diff --git a/python/sglang/test/test_deterministic.py b/python/sglang/test/test_deterministic.py
index 7404d201f..8c4e45c7c 100644
--- a/python/sglang/test/test_deterministic.py
+++ b/python/sglang/test/test_deterministic.py
@@ -29,6 +29,7 @@ class BenchArgs:
     port: int = 30000
     batch_size: int = 1
     temperature: float = 0.0
+    sampling_seed: int = None
     max_new_tokens: int = 100
     frequency_penalty: float = 0.0
     presence_penalty: float = 0.0
@@ -45,6 +46,9 @@ class BenchArgs:
         parser.add_argument("--port", type=int, default=BenchArgs.port)
         parser.add_argument("--n-trials", type=int, default=50)
        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument(
+            "--sampling-seed", type=int, default=BenchArgs.sampling_seed
+        )
         parser.add_argument(
             "--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
         )
@@ -92,6 +96,7 @@ def send_single(
                 "max_new_tokens": args.max_new_tokens,
                 "frequency_penalty": args.frequency_penalty,
                 "presence_penalty": args.presence_penalty,
+                "sampling_seed": args.sampling_seed,
             },
             "return_logprob": args.return_logprob,
             "stream": args.stream,
@@ -140,6 +145,7 @@ def send_mixed(args, batch_size: int):
                 "max_new_tokens": args.max_new_tokens,
                 "frequency_penalty": args.frequency_penalty,
                 "presence_penalty": args.presence_penalty,
+                "sampling_seed": args.sampling_seed,
             },
             "return_logprob": args.return_logprob,
             "stream": args.stream,
@@ -186,6 +192,7 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
                 "max_new_tokens": args.max_new_tokens,
                 "frequency_penalty": args.frequency_penalty,
                 "presence_penalty": args.presence_penalty,
+                "sampling_seed": args.sampling_seed,
             },
             "return_logprob": args.return_logprob,
             "stream": args.stream,
diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs
index 3d2d55713..5c8aa389d 100644
--- a/sgl-router/benches/request_processing.rs
+++ b/sgl-router/benches/request_processing.rs
@@ -97,6 +97,7 @@ fn default_completion_request() -> CompletionRequest {
         lora_path: None,
         session_params: None,
         return_hidden_states: false,
+        sampling_seed: None,
         other: serde_json::Map::new(),
     }
 }
diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs
index cb1f0a992..8e1d483ae 100644
--- a/sgl-router/src/protocols/spec.rs
+++ b/sgl-router/src/protocols/spec.rs
@@ -367,6 +367,10 @@ pub struct ChatCompletionRequest {
     /// Return model hidden states
     #[serde(default)]
     pub return_hidden_states: bool,
+
+    /// Random seed for sampling for deterministic outputs
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub sampling_seed: Option<u64>,
 }
 
 impl GenerationRequest for ChatCompletionRequest {
@@ -608,6 +612,10 @@ pub struct CompletionRequest {
     #[serde(default)]
     pub return_hidden_states: bool,
 
+    /// Sampling seed for deterministic outputs
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub sampling_seed: Option<u64>,
+
     /// Additional fields including bootstrap info for PD routing
     #[serde(flatten)]
     pub other: serde_json::Map<String, serde_json::Value>,
@@ -1749,6 +1757,8 @@ pub struct SamplingParams {
     pub stop_token_ids: Option<Vec<i32>>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub no_stop_trim: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub sampling_seed: Option<u64>,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs
index 013ee620f..4ed7cd631 100644
--- a/sgl-router/src/routers/http/openai_router.rs
+++ b/sgl-router/src/routers/http/openai_router.rs
@@ -240,6 +240,7 @@ impl super::super::RouterTrait for OpenAIRouter {
             "chat_template_kwargs",
             "return_hidden_states",
             "repetition_penalty",
+            "sampling_seed",
         ] {
             obj.remove(key);
         }
diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs
index 366c455f8..cfa12389f 100644
--- a/sgl-router/tests/test_openai_routing.rs
+++ b/sgl-router/tests/test_openai_routing.rs
@@ -68,6 +68,7 @@ fn create_minimal_completion_request() -> CompletionRequest {
         lora_path: None,
         session_params: None,
         return_hidden_states: false,
+        sampling_seed: None,
         other: serde_json::Map::new(),
     }
 }
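
End to end, the new field flows from the HTTP request through the router to the
scheduler unchanged. A sketch of how determinism can be checked against a locally
launched server running with deterministic inference enabled (the URL, prompt,
and parameter values are illustrative; the pattern mirrors test_deterministic.py):

    import requests

    payload = {
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0.8,   # > 0: previously non-reproducible
            "max_new_tokens": 32,
            "sampling_seed": 42,  # the parameter added by this patch
        },
    }
    completions = [
        requests.post("http://localhost:30000/generate", json=payload).json()["text"]
        for _ in range(2)
    ]
    assert completions[0] == completions[1]  # identical samples across trials
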