Fix runtime.generate when sampling param is not passed (#1582)
This commit is contained in:
@@ -77,7 +77,7 @@ class GenerateReqInput:
|
|||||||
|
|
||||||
if self.sampling_params is None:
|
if self.sampling_params is None:
|
||||||
self.parallel_sample_num = 1
|
self.parallel_sample_num = 1
|
||||||
if isinstance(self.sampling_params, dict):
|
elif isinstance(self.sampling_params, dict):
|
||||||
self.parallel_sample_num = self.sampling_params.get("n", 1)
|
self.parallel_sample_num = self.sampling_params.get("n", 1)
|
||||||
else: # isinstance(self.sampling_params, list):
|
else: # isinstance(self.sampling_params, list):
|
||||||
self.parallel_sample_num = self.sampling_params[0].get("n", 1)
|
self.parallel_sample_num = self.sampling_params[0].get("n", 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user