[Feature] Support reward model LxzGordon/URM-LLaMa-3.1-8B (#1525)
This commit is contained in:
@@ -215,12 +215,11 @@ class EmbeddingReqInput:
|
||||
raise ValueError("Either text or input_ids should be provided.")
|
||||
|
||||
if self.text is not None:
|
||||
is_single = isinstance(self.text, str)
|
||||
self.is_single = isinstance(self.text, str)
|
||||
else:
|
||||
is_single = isinstance(self.input_ids[0], int)
|
||||
self.is_single = is_single
|
||||
self.is_single = isinstance(self.input_ids[0], int)
|
||||
|
||||
if is_single:
|
||||
if self.is_single:
|
||||
if self.rid is None:
|
||||
self.rid = uuid.uuid4().hex
|
||||
if self.sampling_params is None:
|
||||
@@ -254,6 +253,52 @@ class TokenizedEmbeddingReqInput:
|
||||
sampling_params: SamplingParams
|
||||
|
||||
|
||||
@dataclass
class RewardReqInput:
    """Request carrying one chat-format conversation, or a batch of them.

    ``post_init`` must be called explicitly after construction (it is a
    regular method, not the dataclass ``__post_init__`` hook). It decides
    whether the request is single or batched, fills in missing request ids,
    and forces ``max_new_tokens`` to 1 in the sampling params.
    """

    # The input prompt in the chat format. It can be a single prompt
    # (a list of message dicts) or a batch of prompts (a list of such lists).
    conv: Union[List[List[Dict]], List[Dict]]
    # The request id: a string for a single request, a list of strings for a batch.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Union[List[Dict], Dict] = None

    # True when `conv` holds one conversation; recomputed by `post_init`.
    is_single: bool = True

    def post_init(self):
        """Normalize the request in place.

        Sets ``is_single`` from the shape of ``conv``, generates ``rid``
        value(s) when none were supplied, and sets ``max_new_tokens`` to 1 in
        every sampling-params dict. In batch mode it also sets
        ``self.batch_size`` to the number of conversations.

        Raises:
            ValueError: if ``conv`` is empty, or if the request is a batch and
                ``rid`` was supplied but is not a list.
        """
        if not self.conv:
            # An empty `conv` cannot be classified as single or batch; fail
            # with a clear message instead of an IndexError on `conv[0]`.
            raise ValueError("conv should not be empty.")

        # A single conversation is a list of message dicts; a batch is a
        # list of such lists, so inspecting the first element is sufficient.
        self.is_single = isinstance(self.conv[0], dict)

        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 1
        else:
            # support select operation
            self.batch_size = len(self.conv)
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            elif not isinstance(self.rid, list):
                raise ValueError("The rid should be a list.")
            if self.sampling_params is None:
                # Build one independent dict per request. `[{}] * n` would
                # repeat a reference to a single dict, so a later change to
                # one request's params would silently affect all of them.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            # Indexed on purpose: a caller-supplied list shorter than the
            # batch still fails loudly rather than being silently accepted.
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 1
|
||||
|
||||
|
||||
@dataclass
class TokenizedRewardReqInput:
    """Tokenized form of a single reward request.

    A plain record: it bundles the request id with the input text, the
    corresponding token ids, and the sampling params object.
    """

    # Identifier of this request.
    rid: str
    # Input as text.
    input_text: str
    # Input as token ids.
    input_ids: List[int]
    # Placeholder sampling params, present only for compatibility.
    sampling_params: SamplingParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchTokenIDOut:
|
||||
# The request id
|
||||
|
||||
Reference in New Issue
Block a user