fix sampling_seed handling when deterministic is enabled (#11096)

Signed-off-by: Alex Chi <iskyzh@gmail.com>
This commit is contained in:
Alex Chi Z
2025-10-03 23:41:46 -04:00
committed by GitHub
parent c70e58e837
commit d01b921482
2 changed files with 13 additions and 3 deletions

View File

@@ -142,6 +142,9 @@ class SamplingParams:
f"logit_bias must has keys in [0, {vocab_size - 1}], got "
f"{token_id}."
)
if self.sampling_seed is None:
raise ValueError("sampling_seed should not be None")
grammars = [
self.json_schema,
self.regex,

View File

@@ -96,12 +96,15 @@ def send_single(
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
if args.sampling_seed is not None:
# sglang server cannot parse None value for sampling_seed
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
if profile:
run_profile(
base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
@@ -145,12 +148,14 @@ def send_mixed(args, batch_size: int):
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
if args.sampling_seed is not None:
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
response = requests.post(
f"http://{args.host}:{args.port}/generate",
json=json_data,
@@ -192,12 +197,14 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
if args.sampling_seed is not None:
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
response = requests.post(
f"http://{args.host}:{args.port}/generate",
json=json_data,