fix sampling_seed handling when deterministic is enabled (#11096)
Signed-off-by: Alex Chi <iskyzh@gmail.com>
This commit is contained in:
@@ -142,6 +142,9 @@ class SamplingParams:
|
||||
f"logit_bias must has keys in [0, {vocab_size - 1}], got "
|
||||
f"{token_id}."
|
||||
)
|
||||
if self.sampling_seed is None:
|
||||
raise ValueError("sampling_seed should not be None")
|
||||
|
||||
grammars = [
|
||||
self.json_schema,
|
||||
self.regex,
|
||||
|
||||
@@ -96,12 +96,15 @@ def send_single(
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"frequency_penalty": args.frequency_penalty,
|
||||
"presence_penalty": args.presence_penalty,
|
||||
"sampling_seed": args.sampling_seed,
|
||||
},
|
||||
"return_logprob": args.return_logprob,
|
||||
"stream": args.stream,
|
||||
}
|
||||
|
||||
if args.sampling_seed is not None:
|
||||
# sglang server cannot parse None value for sampling_seed
|
||||
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
|
||||
|
||||
if profile:
|
||||
run_profile(
|
||||
base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
|
||||
@@ -145,12 +148,14 @@ def send_mixed(args, batch_size: int):
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"frequency_penalty": args.frequency_penalty,
|
||||
"presence_penalty": args.presence_penalty,
|
||||
"sampling_seed": args.sampling_seed,
|
||||
},
|
||||
"return_logprob": args.return_logprob,
|
||||
"stream": args.stream,
|
||||
}
|
||||
|
||||
if args.sampling_seed is not None:
|
||||
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
|
||||
|
||||
response = requests.post(
|
||||
f"http://{args.host}:{args.port}/generate",
|
||||
json=json_data,
|
||||
@@ -192,12 +197,14 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"frequency_penalty": args.frequency_penalty,
|
||||
"presence_penalty": args.presence_penalty,
|
||||
"sampling_seed": args.sampling_seed,
|
||||
},
|
||||
"return_logprob": args.return_logprob,
|
||||
"stream": args.stream,
|
||||
}
|
||||
|
||||
if args.sampling_seed is not None:
|
||||
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
|
||||
|
||||
response = requests.post(
|
||||
f"http://{args.host}:{args.port}/generate",
|
||||
json=json_data,
|
||||
|
||||
Reference in New Issue
Block a user