Fix sampling_seed handling when deterministic mode is enabled (#11096)
Signed-off-by: Alex Chi <iskyzh@gmail.com>
This commit is contained in:
@@ -142,6 +142,9 @@ class SamplingParams:
|
|||||||
f"logit_bias must has keys in [0, {vocab_size - 1}], got "
|
f"logit_bias must has keys in [0, {vocab_size - 1}], got "
|
||||||
f"{token_id}."
|
f"{token_id}."
|
||||||
)
|
)
|
||||||
|
if self.sampling_seed is None:
|
||||||
|
raise ValueError("sampling_seed should not be None")
|
||||||
|
|
||||||
grammars = [
|
grammars = [
|
||||||
self.json_schema,
|
self.json_schema,
|
||||||
self.regex,
|
self.regex,
|
||||||
|
|||||||
@@ -96,12 +96,15 @@ def send_single(
|
|||||||
"max_new_tokens": args.max_new_tokens,
|
"max_new_tokens": args.max_new_tokens,
|
||||||
"frequency_penalty": args.frequency_penalty,
|
"frequency_penalty": args.frequency_penalty,
|
||||||
"presence_penalty": args.presence_penalty,
|
"presence_penalty": args.presence_penalty,
|
||||||
"sampling_seed": args.sampling_seed,
|
|
||||||
},
|
},
|
||||||
"return_logprob": args.return_logprob,
|
"return_logprob": args.return_logprob,
|
||||||
"stream": args.stream,
|
"stream": args.stream,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args.sampling_seed is not None:
|
||||||
|
# sglang server cannot parse None value for sampling_seed
|
||||||
|
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
|
||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
run_profile(
|
run_profile(
|
||||||
base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
|
base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
|
||||||
@@ -145,12 +148,14 @@ def send_mixed(args, batch_size: int):
|
|||||||
"max_new_tokens": args.max_new_tokens,
|
"max_new_tokens": args.max_new_tokens,
|
||||||
"frequency_penalty": args.frequency_penalty,
|
"frequency_penalty": args.frequency_penalty,
|
||||||
"presence_penalty": args.presence_penalty,
|
"presence_penalty": args.presence_penalty,
|
||||||
"sampling_seed": args.sampling_seed,
|
|
||||||
},
|
},
|
||||||
"return_logprob": args.return_logprob,
|
"return_logprob": args.return_logprob,
|
||||||
"stream": args.stream,
|
"stream": args.stream,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args.sampling_seed is not None:
|
||||||
|
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"http://{args.host}:{args.port}/generate",
|
f"http://{args.host}:{args.port}/generate",
|
||||||
json=json_data,
|
json=json_data,
|
||||||
@@ -192,12 +197,14 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
|
|||||||
"max_new_tokens": args.max_new_tokens,
|
"max_new_tokens": args.max_new_tokens,
|
||||||
"frequency_penalty": args.frequency_penalty,
|
"frequency_penalty": args.frequency_penalty,
|
||||||
"presence_penalty": args.presence_penalty,
|
"presence_penalty": args.presence_penalty,
|
||||||
"sampling_seed": args.sampling_seed,
|
|
||||||
},
|
},
|
||||||
"return_logprob": args.return_logprob,
|
"return_logprob": args.return_logprob,
|
||||||
"stream": args.stream,
|
"stream": args.stream,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args.sampling_seed is not None:
|
||||||
|
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"http://{args.host}:{args.port}/generate",
|
f"http://{args.host}:{args.port}/generate",
|
||||||
json=json_data,
|
json=json_data,
|
||||||
|
|||||||
Reference in New Issue
Block a user