[Feature] Support reward model LxzGordon/URM-LLaMa-3.1-8B (#1525)

Author: Ying Sheng
Date: 2024-09-27 23:32:11 -07:00
Committed by: GitHub
Parent: b1e330bcb0
Commit: 9aa6553d2a
13 changed files with 478 additions and 44 deletions


@@ -215,12 +215,11 @@ class EmbeddingReqInput:
             raise ValueError("Either text or input_ids should be provided.")
 
         if self.text is not None:
-            is_single = isinstance(self.text, str)
+            self.is_single = isinstance(self.text, str)
         else:
-            is_single = isinstance(self.input_ids[0], int)
-        self.is_single = is_single
+            self.is_single = isinstance(self.input_ids[0], int)
 
-        if is_single:
+        if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
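The hunk above replaces the local is_single variable with an attribute set directly on the request object. A minimal sketch of the resulting behavior (the import path and the constructor calls are assumptions for illustration, not taken from this diff):

from sglang.srt.managers.io_struct import EmbeddingReqInput  # import path assumed

# Single string prompt: post_init() sets is_single=True and fills in a request id.
single = EmbeddingReqInput(text="hello world")
single.post_init()
assert single.is_single is True

# Batch of token-id lists: the first element is a list rather than an int,
# so post_init() sets is_single=False.
batch = EmbeddingReqInput(input_ids=[[1, 2, 3], [4, 5, 6]])
batch.post_init()
assert batch.is_single is False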
@@ -254,6 +253,52 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams
 
 
+@dataclass
+class RewardReqInput:
+    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
+    conv: Union[List[List[Dict]], List[Dict]]
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Dummy sampling params for compatibility
+    sampling_params: Union[List[Dict], Dict] = None
+
+    is_single: bool = True
+
+    def post_init(self):
+        self.is_single = isinstance(self.conv[0], dict)
+
+        if self.is_single:
+            if self.rid is None:
+                self.rid = uuid.uuid4().hex
+            if self.sampling_params is None:
+                self.sampling_params = {}
+            self.sampling_params["max_new_tokens"] = 1
+        else:
+            # support select operation
+            self.batch_size = len(self.conv)
+            if self.rid is None:
+                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
+            else:
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
+            if self.sampling_params is None:
+                self.sampling_params = [{}] * self.batch_size
+            for i in range(self.batch_size):
+                self.sampling_params[i]["max_new_tokens"] = 1
+
+
+@dataclass
+class TokenizedRewardReqInput:
+    # The request id
+    rid: str
+    # The input text
+    input_text: str
+    # The input token ids
+    input_ids: List[int]
+    # Dummy sampling params for compatibility
+    sampling_params: SamplingParams
+
+
 @dataclass
 class BatchTokenIDOut:
     # The request id
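The new RewardReqInput mirrors EmbeddingReqInput but takes chat-format conversations and pins max_new_tokens to 1, since the sampling params are only dummies for compatibility. A short usage sketch, assuming the same import path and an OpenAI-style {"role", "content"} message layout (neither is spelled out in this hunk):

from sglang.srt.managers.io_struct import RewardReqInput  # import path assumed

# Single conversation: conv[0] is a dict, so post_init() marks the request as single.
single = RewardReqInput(
    conv=[
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
)
single.post_init()
assert single.is_single and single.sampling_params["max_new_tokens"] == 1

# Batch of conversations: conv[0] is a list, so post_init() expands rid and
# sampling_params into one entry per conversation.
batch = RewardReqInput(
    conv=[
        [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
        [{"role": "user", "content": "2 + 2?"}, {"role": "assistant", "content": "4."}],
    ]
)
batch.post_init()
assert batch.batch_size == 2 and len(batch.rid) == 2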