[Feature] Support reward model LxzGordon/URM-LLaMa-3.1-8B (#1525)
This commit is contained in:
@@ -46,8 +46,10 @@ from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
RewardReqInput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
TokenizedRewardReqInput,
|
||||
UpdateWeightReqInput,
|
||||
UpdateWeightReqOutput,
|
||||
)
|
||||
@@ -142,7 +144,7 @@ class TokenizerManager:
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
if self.to_create_loop:
|
||||
@@ -163,7 +165,7 @@ class TokenizerManager:
|
||||
|
||||
async def _handle_single_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
index: Optional[int] = None,
|
||||
is_cache_for_prefill: Optional[bool] = False,
|
||||
@@ -173,7 +175,13 @@ class TokenizerManager:
|
||||
|
||||
rid = obj.rid if not_use_index else obj.rid[index]
|
||||
input_text = obj.text if not_use_index else obj.text[index]
|
||||
if obj.input_ids is None:
|
||||
if hasattr(obj, "conv"):
|
||||
# reward model
|
||||
assert self.tokenizer is not None
|
||||
conv = obj.conv if not_use_index else obj.conv[index]
|
||||
input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
elif obj.input_ids is None:
|
||||
assert self.tokenizer is not None
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
else:
|
||||
@@ -269,13 +277,21 @@ class TokenizerManager:
|
||||
else obj.lora_path
|
||||
),
|
||||
)
|
||||
else: # is embedding
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
else:
|
||||
assert isinstance(obj, RewardReqInput)
|
||||
tokenized_obj = TokenizedRewardReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
self.send_to_controller.send_pyobj(tokenized_obj)
|
||||
|
||||
# Recv results
|
||||
@@ -292,7 +308,7 @@ class TokenizerManager:
|
||||
|
||||
async def _handle_batch_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
batch_size = obj.batch_size
|
||||
@@ -329,9 +345,16 @@ class TokenizerManager:
|
||||
rid = obj.rid[index]
|
||||
if parallel_sample_num == 1:
|
||||
## select operation
|
||||
if obj.input_ids is None:
|
||||
if hasattr(obj, "conv"):
|
||||
# reward model
|
||||
conv = obj.conv[i]
|
||||
input_text = self.tokenizer.apply_chat_template(
|
||||
conv, tokenize=False
|
||||
)
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
elif obj.input_ids is None:
|
||||
input_text = obj.text[i]
|
||||
input_ids = self.tokenizer.encode(obj.text[i])
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
else:
|
||||
input_text = None
|
||||
input_ids = obj.input_ids[i]
|
||||
@@ -370,13 +393,21 @@ class TokenizerManager:
|
||||
else obj.lora_path
|
||||
),
|
||||
)
|
||||
else:
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
else:
|
||||
assert isinstance(obj, RewardReqInput)
|
||||
tokenized_obj = TokenizedRewardReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
self.send_to_controller.send_pyobj(tokenized_obj)
|
||||
|
||||
event = asyncio.Event()
|
||||
@@ -442,7 +473,7 @@ class TokenizerManager:
|
||||
async def _wait_for_response(
|
||||
self,
|
||||
state: ReqState,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||
rid: str,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
index: Optional[int] = None,
|
||||
@@ -469,7 +500,7 @@ class TokenizerManager:
|
||||
),
|
||||
obj.return_text_in_logprobs,
|
||||
)
|
||||
else: # isinstance(obj, EmbeddingReqInput)
|
||||
else: # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
|
||||
out = state.out_list[-1]
|
||||
|
||||
out["index"] = response_index
|
||||
|
||||
Reference in New Issue
Block a user