[Feature] Support reward model LxzGordon/URM-LLaMa-3.1-8B (#1525)

This commit is contained in:
Ying Sheng
2024-09-27 23:32:11 -07:00
committed by GitHub
parent b1e330bcb0
commit 9aa6553d2a
13 changed files with 478 additions and 44 deletions

View File

@@ -46,8 +46,10 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
RewardReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
TokenizedRewardReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
@@ -142,7 +144,7 @@ class TokenizerManager:
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
request: Optional[fastapi.Request] = None,
):
if self.to_create_loop:
@@ -163,7 +165,7 @@ class TokenizerManager:
async def _handle_single_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
request: Optional[fastapi.Request] = None,
index: Optional[int] = None,
is_cache_for_prefill: Optional[bool] = False,
@@ -173,7 +175,13 @@ class TokenizerManager:
rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
if obj.input_ids is None:
if hasattr(obj, "conv"):
# reward model
assert self.tokenizer is not None
conv = obj.conv if not_use_index else obj.conv[index]
input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
assert self.tokenizer is not None
input_ids = self.tokenizer.encode(input_text)
else:
@@ -269,13 +277,21 @@ class TokenizerManager:
else obj.lora_path
),
)
else: # is embedding
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
else:
assert isinstance(obj, RewardReqInput)
tokenized_obj = TokenizedRewardReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
self.send_to_controller.send_pyobj(tokenized_obj)
# Recv results
@@ -292,7 +308,7 @@ class TokenizerManager:
async def _handle_batch_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
request: Optional[fastapi.Request] = None,
):
batch_size = obj.batch_size
@@ -329,9 +345,16 @@ class TokenizerManager:
rid = obj.rid[index]
if parallel_sample_num == 1:
## select operation
if obj.input_ids is None:
if hasattr(obj, "conv"):
# reward model
conv = obj.conv[i]
input_text = self.tokenizer.apply_chat_template(
conv, tokenize=False
)
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
input_text = obj.text[i]
input_ids = self.tokenizer.encode(obj.text[i])
input_ids = self.tokenizer.encode(input_text)
else:
input_text = None
input_ids = obj.input_ids[i]
@@ -370,13 +393,21 @@ class TokenizerManager:
else obj.lora_path
),
)
else:
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
else:
assert isinstance(obj, RewardReqInput)
tokenized_obj = TokenizedRewardReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
self.send_to_controller.send_pyobj(tokenized_obj)
event = asyncio.Event()
@@ -442,7 +473,7 @@ class TokenizerManager:
async def _wait_for_response(
self,
state: ReqState,
obj: Union[GenerateReqInput, EmbeddingReqInput],
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
rid: str,
request: Optional[fastapi.Request] = None,
index: Optional[int] = None,
@@ -469,7 +500,7 @@ class TokenizerManager:
),
obj.return_text_in_logprobs,
)
else: # isinstance(obj, EmbeddingReqInput)
else: # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
out = state.out_list[-1]
out["index"] = response_index