[Feature] Add a function to convert sampling_params to kwargs (#1170)

Co-authored-by: lzhang <zhanglei@modelbest.cn>
This commit is contained in:
rainred
2024-08-22 05:28:35 +08:00
committed by GitHub
parent 1fb9459908
commit d6aeb9fa15

View File

@@ -123,3 +123,17 @@ class SamplingParams:
else:
stop_str_max_len = max(stop_str_max_len, len(stop_str))
self.stop_str_max_len = stop_str_max_len
def to_srt_kwargs(self):
return {
"max_new_tokens": self.max_new_tokens,
"stop": self.stop_strs,
"stop_token_ids": list(self.stop_token_ids),
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"ignore_eos": self.ignore_eos,
"regex": self.regex,
}