[2/2] Support deterministic inference for temperature > 0 (#10678)
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
@@ -1,5 +1,5 @@
import logging
from typing import List, Tuple
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
@@ -65,6 +65,7 @@ class Sampler(nn.Module):
return_logprob: bool,
top_logprobs_nums: List[int],
token_ids_logprobs: List[List[int]],
positions: torch.Tensor,
):
"""Run a sampler & compute logprobs and update logits_output accordingly.

@@ -77,6 +78,8 @@ class Sampler(nn.Module):
batch_next_token_ids: next token IDs. If set, skip sampling and only
compute output logprobs It is used for speculative decoding which
performs sampling in draft workers.
positions: The positions of the tokens in the sequence. Used for deterministic sampling
to get the unique seed for each position.
"""
logits = logits_output.next_token_logits

@@ -124,6 +127,8 @@ class Sampler(nn.Module):
sampling_info.top_ps,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
sampling_info.sampling_seed,
positions,
)
else:
raise ValueError(
@@ -189,6 +194,7 @@ class Sampler(nn.Module):
Optimized for prefill-only scoring requests that need token probabilities
but don't require next token generation.
"""

if logits_output.next_token_logits is None:
logger.warning("No logits available for logprob computation")
return
@@ -230,8 +236,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
top_ps: torch.Tensor,
min_ps: torch.Tensor,
need_min_p_sampling: bool,
sampling_seed: Optional[torch.Tensor],
positions: torch.Tensor,
):
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
"""
A top-k, top-p and min-p sampling implementation with native pytorch operations.
When sampling_seed is not None, deterministic inference will be enabled, it will sample
with the sampling_seed of each request.
"""
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[
@@ -243,14 +255,50 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
if need_min_p_sampling:
min_p_thresholds = probs_sort[:, 0] * min_ps
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0

sampled_index = torch.multinomial(probs_sort, num_samples=1)
if sampling_seed is not None:
sampled_index = multinomial_with_seed(probs_sort, sampling_seed, positions)
else:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
# int32 range is enough to represent the token ids
probs_idx = probs_idx.to(torch.int32)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
return batch_next_token_ids


def multinomial_with_seed(
inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor
) -> torch.Tensor:
"""
Samples n elements from an input tensor `inputs` of shape (n, m) using
a unique random seed for each row. This is a deterministic batched alternative to
`torch.multinomial`.

Args:
inputs: A float tensor of shape (n, m) representing n categorical
distributions with m categories each. The values are treated
as weights and do not need to sum to 1.
seed: An integer tensor of shape (n,) containing the random seed
for each corresponding row in `inputs`.
positions: The positions of the tokens in the sequence. Used for deterministic sampling
to get the unique seed for each position.

Returns:
A tensor of shape (n,) where the i-th element is an index sampled
from the distribution in `inputs[i]` using `seed[i]`.
"""
n, m = inputs.shape
col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
step_seed = seed * 19349663 ^ positions * 73856093
seed_expanded = step_seed.unsqueeze(-1)
hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
uniform_samples = (hashed % (2**24)).float() / (2**24)
epsilon = 1e-9
gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
log_probs = torch.log(inputs + epsilon)
perturbed_log_probs = log_probs + gumbel_noise
return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)


def sampling_from_probs_torch(probs: torch.Tensor):
"""A sampling implementation with native pytorch operations, without
top-k, top-p, or min-p filtering."""
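
The seeded branch above replaces torch.multinomial with a hash-based Gumbel-max draw: each row's pseudo-uniform noise is derived only from its sampling_seed and positions entry (note that `*` binds tighter than `^`, so `seed * 19349663 ^ positions * 73856093` combines the two products), which makes the draw reproducible across runs and batch compositions. A minimal standalone sketch of that behavior; it mirrors the diff, but the probs/seed/positions values are made up for illustration:

# Minimal standalone sketch of the seeded Gumbel-max draw; mirrors the diff,
# but the demo inputs below are made up for illustration.
import torch

def seeded_multinomial(inputs, seed, positions):
    _, m = inputs.shape
    col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
    # `*` binds tighter than `^`: (seed * 19349663) ^ (positions * 73856093)
    step_seed = seed * 19349663 ^ positions * 73856093
    hashed = step_seed.unsqueeze(-1) * 8589934591 ^ col_indices * 479001599
    uniform = (hashed % (2**24)).float() / (2**24)        # pseudo-uniform in [0, 1)
    eps = 1e-9
    gumbel = -torch.log(-torch.log(uniform + eps) + eps)  # Gumbel(0, 1) noise
    return torch.argmax(torch.log(inputs + eps) + gumbel, dim=1, keepdim=True)

probs = torch.rand(4, 8)                  # 4 rows, 8 categories; need not sum to 1
seed = torch.tensor([42, 42, 7, 7])       # per-request sampling seeds
positions = torch.tensor([3, 3, 3, 9])    # per-token positions
first = seeded_multinomial(probs, seed, positions)
second = seeded_multinomial(probs, seed, positions)
assert torch.equal(first, second)         # same (seed, position) -> same token index
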
@@ -67,7 +67,7 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton

@@ -270,9 +270,7 @@ class TpModelWorker:
logits_output, model_worker_batch
)
else:
next_token_ids = self.model_runner.sample(
logits_output, model_worker_batch
)
next_token_ids = self.model_runner.sample(logits_output, forward_batch)

return logits_output, next_token_ids, can_run_cuda_graph
else:

@@ -2049,7 +2049,6 @@ class ModelRunner:
)

self._preprocess_logits(logits_output, forward_batch.sampling_info)

# Sample the next tokens
next_token_ids = self.sampler(
logits_output,
@@ -2057,6 +2056,12 @@ class ModelRunner:
forward_batch.return_logprob,
forward_batch.top_logprobs_nums,
forward_batch.token_ids_logprobs,
# For prefill, we only use the position of the last token.
(
forward_batch.positions
if forward_batch.forward_mode.is_decode()
else forward_batch.seq_lens - 1
),
)
return next_token_ids
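
The positions handed to the sampler differ by mode: decode batches pass each request's current position, while prefill batches only sample the token that follows the prompt, so seq_lens - 1 serves as that token's position for the seed hash. A hedged sketch of this selection, with plain tensors standing in for the ForwardBatch fields:

# Hedged sketch of the position selection above; plain tensors stand in for
# the ForwardBatch fields used in the diff.
import torch

is_decode = False                     # pretend this batch is a prefill batch
positions = torch.tensor([5, 2, 9])   # per-token positions (what decode would use)
seq_lens = torch.tensor([6, 3, 10])   # prompt lengths per request

# Decode: sample at each request's current position.
# Prefill: only the token after the prompt is sampled, so use seq_lens - 1.
sampler_positions = positions if is_decode else seq_lens - 1
print(sampler_positions)              # tensor([5, 2, 9])
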
@@ -60,6 +60,9 @@ class SamplingBatchInfo:
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
] = None

# Used for deterministic sampling
sampling_seed: Optional[torch.Tensor] = None

# Device
device: str = "cuda"

@@ -93,6 +96,15 @@ class SamplingBatchInfo:
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)
sampling_seed = (
torch.tensor(
[r.sampling_params.sampling_seed for r in reqs],
dtype=torch.int32,
device=device,
)
if enable_deterministic
else None
)

logit_bias = None
if any(r.sampling_params.logit_bias is not None for r in reqs):
@@ -158,6 +170,7 @@ class SamplingBatchInfo:
top_ps=top_ps,
top_ks=top_ks,
min_ps=min_ps,
sampling_seed=sampling_seed,
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
@@ -239,9 +252,11 @@ class SamplingBatchInfo:
"top_ps",
"top_ks",
"min_ps",
"sampling_seed",
]:
value = getattr(self, item, None)
setattr(self, item, value[keep_indices_device])
if value is not None:
setattr(self, item, value[keep_indices_device])

if self.logit_bias is not None:
self.logit_bias = self.logit_bias[keep_indices_device]
@@ -343,10 +358,12 @@ class SamplingBatchInfo:
"top_ps",
"top_ks",
"min_ps",
"sampling_seed",
]:
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
setattr(self, item, torch.cat([self_val, other_val]))
if self_val is not None and other_val is not None:
setattr(self, item, torch.cat([self_val, other_val]))

self.is_all_greedy &= other.is_all_greedy
self.need_top_p_sampling |= other.need_top_p_sampling

@@ -15,8 +15,11 @@

from typing import Any, Dict, List, Optional, Union

from sglang.srt.utils import get_bool_env_var

_SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30
DEFAULT_SAMPLING_SEED = 42


class SamplingParams:
@@ -53,6 +56,7 @@ class SamplingParams:
custom_params: Optional[Dict[str, Any]] = None,
stream_interval: Optional[int] = None,
logit_bias: Optional[Dict[str, float]] = None,
sampling_seed: Optional[int] = None,
) -> None:
self.max_new_tokens = max_new_tokens
self.stop_strs = stop
@@ -80,6 +84,14 @@ class SamplingParams:
self.custom_params = custom_params
self.stream_interval = stream_interval
self.logit_bias = logit_bias
# Used for deterministic sampling
if (
get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE")
and sampling_seed is None
):
# If deterministic inference is enabled and sampling_seed is not set, use the default seed
sampling_seed = DEFAULT_SAMPLING_SEED
self.sampling_seed = sampling_seed

# Process some special cases
if 0 <= self.temperature < _SAMPLING_EPS:
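
With SGLANG_ENABLE_DETERMINISTIC_INFERENCE set, a request that omits sampling_seed falls back to DEFAULT_SAMPLING_SEED (42), while an explicit seed is kept as-is. A small sketch of that rule; resolve_sampling_seed is an illustrative helper, not part of the diff, and the env check only approximates get_bool_env_var:

# Hedged sketch of the default-seed rule above; the helper name and env check
# are illustrative, not part of the diff.
import os

DEFAULT_SAMPLING_SEED = 42

def resolve_sampling_seed(sampling_seed=None):
    deterministic = os.environ.get(
        "SGLANG_ENABLE_DETERMINISTIC_INFERENCE", "false"
    ).lower() in ("1", "true")
    if deterministic and sampling_seed is None:
        # No per-request seed given: fall back to the shared default seed.
        return DEFAULT_SAMPLING_SEED
    return sampling_seed

print(resolve_sampling_seed())      # 42 when the env var is set, otherwise None
print(resolve_sampling_seed(1234))  # an explicit per-request seed is kept as-is
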
@@ -988,6 +988,12 @@ class ServerArgs:
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
)

# Check some settings
self.sampling_backend = "pytorch"
logger.warning(
"Sampling backend is set to pytorch for deterministic inference."
)
# Currently, only FA3 supports radix cache. Support for other backends is in progress
if self.attention_backend != "fa3":
self.disable_radix_cache = True
logger.warning(

@@ -29,6 +29,7 @@ class BenchArgs:
port: int = 30000
batch_size: int = 1
temperature: float = 0.0
sampling_seed: int = None
max_new_tokens: int = 100
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
@@ -45,6 +46,9 @@ class BenchArgs:
parser.add_argument("--port", type=int, default=BenchArgs.port)
parser.add_argument("--n-trials", type=int, default=50)
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
parser.add_argument(
"--sampling-seed", type=int, default=BenchArgs.sampling_seed
)
parser.add_argument(
"--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
)
@@ -92,6 +96,7 @@ def send_single(
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
@@ -140,6 +145,7 @@ def send_mixed(args, batch_size: int):
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
@@ -186,6 +192,7 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
},
"return_logprob": args.return_logprob,
"stream": args.stream,