[2/2] Support deterministic inference for temperature > 0 (#10678)
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com> Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -65,6 +65,7 @@ class Sampler(nn.Module):
|
|||||||
return_logprob: bool,
|
return_logprob: bool,
|
||||||
top_logprobs_nums: List[int],
|
top_logprobs_nums: List[int],
|
||||||
token_ids_logprobs: List[List[int]],
|
token_ids_logprobs: List[List[int]],
|
||||||
|
positions: torch.Tensor,
|
||||||
):
|
):
|
||||||
"""Run a sampler & compute logprobs and update logits_output accordingly.
|
"""Run a sampler & compute logprobs and update logits_output accordingly.
|
||||||
|
|
||||||
@@ -77,6 +78,8 @@ class Sampler(nn.Module):
|
|||||||
batch_next_token_ids: next token IDs. If set, skip sampling and only
|
batch_next_token_ids: next token IDs. If set, skip sampling and only
|
||||||
compute output logprobs It is used for speculative decoding which
|
compute output logprobs It is used for speculative decoding which
|
||||||
performs sampling in draft workers.
|
performs sampling in draft workers.
|
||||||
|
positions: The positions of the tokens in the sequence. Used for deterministic sampling
|
||||||
|
to get the unique seed for each position.
|
||||||
"""
|
"""
|
||||||
logits = logits_output.next_token_logits
|
logits = logits_output.next_token_logits
|
||||||
|
|
||||||
@@ -124,6 +127,8 @@ class Sampler(nn.Module):
|
|||||||
sampling_info.top_ps,
|
sampling_info.top_ps,
|
||||||
sampling_info.min_ps,
|
sampling_info.min_ps,
|
||||||
sampling_info.need_min_p_sampling,
|
sampling_info.need_min_p_sampling,
|
||||||
|
sampling_info.sampling_seed,
|
||||||
|
positions,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -189,6 +194,7 @@ class Sampler(nn.Module):
|
|||||||
Optimized for prefill-only scoring requests that need token probabilities
|
Optimized for prefill-only scoring requests that need token probabilities
|
||||||
but don't require next token generation.
|
but don't require next token generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if logits_output.next_token_logits is None:
|
if logits_output.next_token_logits is None:
|
||||||
logger.warning("No logits available for logprob computation")
|
logger.warning("No logits available for logprob computation")
|
||||||
return
|
return
|
||||||
@@ -230,8 +236,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
|
|||||||
top_ps: torch.Tensor,
|
top_ps: torch.Tensor,
|
||||||
min_ps: torch.Tensor,
|
min_ps: torch.Tensor,
|
||||||
need_min_p_sampling: bool,
|
need_min_p_sampling: bool,
|
||||||
|
sampling_seed: Optional[torch.Tensor],
|
||||||
|
positions: torch.Tensor,
|
||||||
):
|
):
|
||||||
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
|
"""
|
||||||
|
A top-k, top-p and min-p sampling implementation with native pytorch operations.
|
||||||
|
When sampling_seed is not None, deterministic inference will be enabled, it will sample
|
||||||
|
with the sampling_seed of each request.
|
||||||
|
"""
|
||||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
probs_sort[
|
probs_sort[
|
||||||
@@ -243,14 +255,50 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
|
|||||||
if need_min_p_sampling:
|
if need_min_p_sampling:
|
||||||
min_p_thresholds = probs_sort[:, 0] * min_ps
|
min_p_thresholds = probs_sort[:, 0] * min_ps
|
||||||
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
||||||
|
if sampling_seed is not None:
|
||||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
sampled_index = multinomial_with_seed(probs_sort, sampling_seed, positions)
|
||||||
|
else:
|
||||||
|
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
||||||
# int32 range is enough to represent the token ids
|
# int32 range is enough to represent the token ids
|
||||||
probs_idx = probs_idx.to(torch.int32)
|
probs_idx = probs_idx.to(torch.int32)
|
||||||
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
|
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
|
||||||
return batch_next_token_ids
|
return batch_next_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
def multinomial_with_seed(
|
||||||
|
inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Samples n elements from an input tensor `inputs` of shape (n, m) using
|
||||||
|
a unique random seed for each row. This is a deterministic batched alternative to
|
||||||
|
`torch.multinomial`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A float tensor of shape (n, m) representing n categorical
|
||||||
|
distributions with m categories each. The values are treated
|
||||||
|
as weights and do not need to sum to 1.
|
||||||
|
seed: An integer tensor of shape (n,) containing the random seed
|
||||||
|
for each corresponding row in `inputs`.
|
||||||
|
positions: The positions of the tokens in the sequence. Used for deterministic sampling
|
||||||
|
to get the unique seed for each position.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor of shape (n,) where the i-th element is an index sampled
|
||||||
|
from the distribution in `inputs[i]` using `seed[i]`.
|
||||||
|
"""
|
||||||
|
n, m = inputs.shape
|
||||||
|
col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
|
||||||
|
step_seed = seed * 19349663 ^ positions * 73856093
|
||||||
|
seed_expanded = step_seed.unsqueeze(-1)
|
||||||
|
hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
|
||||||
|
uniform_samples = (hashed % (2**24)).float() / (2**24)
|
||||||
|
epsilon = 1e-9
|
||||||
|
gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
|
||||||
|
log_probs = torch.log(inputs + epsilon)
|
||||||
|
perturbed_log_probs = log_probs + gumbel_noise
|
||||||
|
return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
def sampling_from_probs_torch(probs: torch.Tensor):
|
def sampling_from_probs_torch(probs: torch.Tensor):
|
||||||
"""A sampling implementation with native pytorch operations, without
|
"""A sampling implementation with native pytorch operations, without
|
||||||
top-k, top-p, or min-p filtering."""
|
top-k, top-p, or min-p filtering."""
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
|||||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import flatten_nested_list, support_triton
|
from sglang.srt.utils import flatten_nested_list, support_triton
|
||||||
|
|
||||||
|
|||||||
@@ -270,9 +270,7 @@ class TpModelWorker:
|
|||||||
logits_output, model_worker_batch
|
logits_output, model_worker_batch
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
next_token_ids = self.model_runner.sample(
|
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
|
||||||
logits_output, model_worker_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
return logits_output, next_token_ids, can_run_cuda_graph
|
return logits_output, next_token_ids, can_run_cuda_graph
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -2049,7 +2049,6 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._preprocess_logits(logits_output, forward_batch.sampling_info)
|
self._preprocess_logits(logits_output, forward_batch.sampling_info)
|
||||||
|
|
||||||
# Sample the next tokens
|
# Sample the next tokens
|
||||||
next_token_ids = self.sampler(
|
next_token_ids = self.sampler(
|
||||||
logits_output,
|
logits_output,
|
||||||
@@ -2057,6 +2056,12 @@ class ModelRunner:
|
|||||||
forward_batch.return_logprob,
|
forward_batch.return_logprob,
|
||||||
forward_batch.top_logprobs_nums,
|
forward_batch.top_logprobs_nums,
|
||||||
forward_batch.token_ids_logprobs,
|
forward_batch.token_ids_logprobs,
|
||||||
|
# For prefill, we only use the position of the last token.
|
||||||
|
(
|
||||||
|
forward_batch.positions
|
||||||
|
if forward_batch.forward_mode.is_decode()
|
||||||
|
else forward_batch.seq_lens - 1
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return next_token_ids
|
return next_token_ids
|
||||||
|
|
||||||
|
|||||||
@@ -60,6 +60,9 @@ class SamplingBatchInfo:
|
|||||||
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
|
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
# Used for deterministic sampling
|
||||||
|
sampling_seed: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Device
|
# Device
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
|
|
||||||
@@ -93,6 +96,15 @@ class SamplingBatchInfo:
|
|||||||
min_ps = torch.tensor(
|
min_ps = torch.tensor(
|
||||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
|
||||||
)
|
)
|
||||||
|
sampling_seed = (
|
||||||
|
torch.tensor(
|
||||||
|
[r.sampling_params.sampling_seed for r in reqs],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
if enable_deterministic
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
logit_bias = None
|
logit_bias = None
|
||||||
if any(r.sampling_params.logit_bias is not None for r in reqs):
|
if any(r.sampling_params.logit_bias is not None for r in reqs):
|
||||||
@@ -158,6 +170,7 @@ class SamplingBatchInfo:
|
|||||||
top_ps=top_ps,
|
top_ps=top_ps,
|
||||||
top_ks=top_ks,
|
top_ks=top_ks,
|
||||||
min_ps=min_ps,
|
min_ps=min_ps,
|
||||||
|
sampling_seed=sampling_seed,
|
||||||
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
||||||
need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
|
need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
|
||||||
need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
|
need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
|
||||||
@@ -239,9 +252,11 @@ class SamplingBatchInfo:
|
|||||||
"top_ps",
|
"top_ps",
|
||||||
"top_ks",
|
"top_ks",
|
||||||
"min_ps",
|
"min_ps",
|
||||||
|
"sampling_seed",
|
||||||
]:
|
]:
|
||||||
value = getattr(self, item, None)
|
value = getattr(self, item, None)
|
||||||
setattr(self, item, value[keep_indices_device])
|
if value is not None:
|
||||||
|
setattr(self, item, value[keep_indices_device])
|
||||||
|
|
||||||
if self.logit_bias is not None:
|
if self.logit_bias is not None:
|
||||||
self.logit_bias = self.logit_bias[keep_indices_device]
|
self.logit_bias = self.logit_bias[keep_indices_device]
|
||||||
@@ -343,10 +358,12 @@ class SamplingBatchInfo:
|
|||||||
"top_ps",
|
"top_ps",
|
||||||
"top_ks",
|
"top_ks",
|
||||||
"min_ps",
|
"min_ps",
|
||||||
|
"sampling_seed",
|
||||||
]:
|
]:
|
||||||
self_val = getattr(self, item, None)
|
self_val = getattr(self, item, None)
|
||||||
other_val = getattr(other, item, None)
|
other_val = getattr(other, item, None)
|
||||||
setattr(self, item, torch.cat([self_val, other_val]))
|
if self_val is not None and other_val is not None:
|
||||||
|
setattr(self, item, torch.cat([self_val, other_val]))
|
||||||
|
|
||||||
self.is_all_greedy &= other.is_all_greedy
|
self.is_all_greedy &= other.is_all_greedy
|
||||||
self.need_top_p_sampling |= other.need_top_p_sampling
|
self.need_top_p_sampling |= other.need_top_p_sampling
|
||||||
|
|||||||
@@ -15,8 +15,11 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-6
|
_SAMPLING_EPS = 1e-6
|
||||||
TOP_K_ALL = 1 << 30
|
TOP_K_ALL = 1 << 30
|
||||||
|
DEFAULT_SAMPLING_SEED = 42
|
||||||
|
|
||||||
|
|
||||||
class SamplingParams:
|
class SamplingParams:
|
||||||
@@ -53,6 +56,7 @@ class SamplingParams:
|
|||||||
custom_params: Optional[Dict[str, Any]] = None,
|
custom_params: Optional[Dict[str, Any]] = None,
|
||||||
stream_interval: Optional[int] = None,
|
stream_interval: Optional[int] = None,
|
||||||
logit_bias: Optional[Dict[str, float]] = None,
|
logit_bias: Optional[Dict[str, float]] = None,
|
||||||
|
sampling_seed: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.max_new_tokens = max_new_tokens
|
self.max_new_tokens = max_new_tokens
|
||||||
self.stop_strs = stop
|
self.stop_strs = stop
|
||||||
@@ -80,6 +84,14 @@ class SamplingParams:
|
|||||||
self.custom_params = custom_params
|
self.custom_params = custom_params
|
||||||
self.stream_interval = stream_interval
|
self.stream_interval = stream_interval
|
||||||
self.logit_bias = logit_bias
|
self.logit_bias = logit_bias
|
||||||
|
# Used for deterministic sampling
|
||||||
|
if (
|
||||||
|
get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE")
|
||||||
|
and sampling_seed is None
|
||||||
|
):
|
||||||
|
# If deterministic inference is enabled and sampling_seed is not set, use the default seed
|
||||||
|
sampling_seed = DEFAULT_SAMPLING_SEED
|
||||||
|
self.sampling_seed = sampling_seed
|
||||||
|
|
||||||
# Process some special cases
|
# Process some special cases
|
||||||
if 0 <= self.temperature < _SAMPLING_EPS:
|
if 0 <= self.temperature < _SAMPLING_EPS:
|
||||||
|
|||||||
@@ -988,6 +988,12 @@ class ServerArgs:
|
|||||||
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
|
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check some settings
|
||||||
|
self.sampling_backend = "pytorch"
|
||||||
|
logger.warning(
|
||||||
|
"Sampling backend is set to pytorch for deterministic inference."
|
||||||
|
)
|
||||||
|
# Currently, only FA3 supports radix cache. Support for other backends is in progress
|
||||||
if self.attention_backend != "fa3":
|
if self.attention_backend != "fa3":
|
||||||
self.disable_radix_cache = True
|
self.disable_radix_cache = True
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ class BenchArgs:
|
|||||||
port: int = 30000
|
port: int = 30000
|
||||||
batch_size: int = 1
|
batch_size: int = 1
|
||||||
temperature: float = 0.0
|
temperature: float = 0.0
|
||||||
|
sampling_seed: int = None
|
||||||
max_new_tokens: int = 100
|
max_new_tokens: int = 100
|
||||||
frequency_penalty: float = 0.0
|
frequency_penalty: float = 0.0
|
||||||
presence_penalty: float = 0.0
|
presence_penalty: float = 0.0
|
||||||
@@ -45,6 +46,9 @@ class BenchArgs:
|
|||||||
parser.add_argument("--port", type=int, default=BenchArgs.port)
|
parser.add_argument("--port", type=int, default=BenchArgs.port)
|
||||||
parser.add_argument("--n-trials", type=int, default=50)
|
parser.add_argument("--n-trials", type=int, default=50)
|
||||||
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sampling-seed", type=int, default=BenchArgs.sampling_seed
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
|
"--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
|
||||||
)
|
)
|
||||||
@@ -92,6 +96,7 @@ def send_single(
|
|||||||
"max_new_tokens": args.max_new_tokens,
|
"max_new_tokens": args.max_new_tokens,
|
||||||
"frequency_penalty": args.frequency_penalty,
|
"frequency_penalty": args.frequency_penalty,
|
||||||
"presence_penalty": args.presence_penalty,
|
"presence_penalty": args.presence_penalty,
|
||||||
|
"sampling_seed": args.sampling_seed,
|
||||||
},
|
},
|
||||||
"return_logprob": args.return_logprob,
|
"return_logprob": args.return_logprob,
|
||||||
"stream": args.stream,
|
"stream": args.stream,
|
||||||
@@ -140,6 +145,7 @@ def send_mixed(args, batch_size: int):
|
|||||||
"max_new_tokens": args.max_new_tokens,
|
"max_new_tokens": args.max_new_tokens,
|
||||||
"frequency_penalty": args.frequency_penalty,
|
"frequency_penalty": args.frequency_penalty,
|
||||||
"presence_penalty": args.presence_penalty,
|
"presence_penalty": args.presence_penalty,
|
||||||
|
"sampling_seed": args.sampling_seed,
|
||||||
},
|
},
|
||||||
"return_logprob": args.return_logprob,
|
"return_logprob": args.return_logprob,
|
||||||
"stream": args.stream,
|
"stream": args.stream,
|
||||||
@@ -186,6 +192,7 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
|
|||||||
"max_new_tokens": args.max_new_tokens,
|
"max_new_tokens": args.max_new_tokens,
|
||||||
"frequency_penalty": args.frequency_penalty,
|
"frequency_penalty": args.frequency_penalty,
|
||||||
"presence_penalty": args.presence_penalty,
|
"presence_penalty": args.presence_penalty,
|
||||||
|
"sampling_seed": args.sampling_seed,
|
||||||
},
|
},
|
||||||
"return_logprob": args.return_logprob,
|
"return_logprob": args.return_logprob,
|
||||||
"stream": args.stream,
|
"stream": args.stream,
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ fn default_completion_request() -> CompletionRequest {
|
|||||||
lora_path: None,
|
lora_path: None,
|
||||||
session_params: None,
|
session_params: None,
|
||||||
return_hidden_states: false,
|
return_hidden_states: false,
|
||||||
|
sampling_seed: None,
|
||||||
other: serde_json::Map::new(),
|
other: serde_json::Map::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -367,6 +367,10 @@ pub struct ChatCompletionRequest {
|
|||||||
/// Return model hidden states
|
/// Return model hidden states
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub return_hidden_states: bool,
|
pub return_hidden_states: bool,
|
||||||
|
|
||||||
|
/// Random seed for sampling for deterministic outputs
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sampling_seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GenerationRequest for ChatCompletionRequest {
|
impl GenerationRequest for ChatCompletionRequest {
|
||||||
@@ -608,6 +612,10 @@ pub struct CompletionRequest {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub return_hidden_states: bool,
|
pub return_hidden_states: bool,
|
||||||
|
|
||||||
|
/// Sampling seed for deterministic outputs
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sampling_seed: Option<u64>,
|
||||||
|
|
||||||
/// Additional fields including bootstrap info for PD routing
|
/// Additional fields including bootstrap info for PD routing
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub other: serde_json::Map<String, serde_json::Value>,
|
pub other: serde_json::Map<String, serde_json::Value>,
|
||||||
@@ -1749,6 +1757,8 @@ pub struct SamplingParams {
|
|||||||
pub stop_token_ids: Option<Vec<i32>>,
|
pub stop_token_ids: Option<Vec<i32>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub no_stop_trim: Option<bool>,
|
pub no_stop_trim: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sampling_seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
|||||||
@@ -240,6 +240,7 @@ impl super::super::RouterTrait for OpenAIRouter {
|
|||||||
"chat_template_kwargs",
|
"chat_template_kwargs",
|
||||||
"return_hidden_states",
|
"return_hidden_states",
|
||||||
"repetition_penalty",
|
"repetition_penalty",
|
||||||
|
"sampling_seed",
|
||||||
] {
|
] {
|
||||||
obj.remove(key);
|
obj.remove(key);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ fn create_minimal_completion_request() -> CompletionRequest {
|
|||||||
lora_path: None,
|
lora_path: None,
|
||||||
session_params: None,
|
session_params: None,
|
||||||
return_hidden_states: false,
|
return_hidden_states: false,
|
||||||
|
sampling_seed: None,
|
||||||
other: serde_json::Map::new(),
|
other: serde_json::Map::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user