[Feature] Support reward model LxzGordon/URM-LLaMa-3.1-8B (#1525)

This commit is contained in:
Ying Sheng
2024-09-27 23:32:11 -07:00
committed by GitHub
parent b1e330bcb0
commit 9aa6553d2a
13 changed files with 478 additions and 44 deletions

View File

@@ -54,6 +54,7 @@ from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
RewardReqInput,
UpdateWeightReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -213,6 +214,21 @@ app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
async def judge_request(obj: RewardReqInput, request: Request):
    """Handle a reward-model (judge) request.

    Feeds the conversation(s) in ``obj`` through the reward model via the
    tokenizer manager and returns the single scoring result, or a 400
    response when the input is rejected.
    """
    try:
        # generate_request returns an async generator; a reward request
        # yields exactly one result, so take only the first item.
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        # Invalid input (e.g. malformed conversation) -> HTTP 400 with the message.
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/judge")(judge_request)
app.put("/judge")(judge_request)
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    """Forward an OpenAI-compatible completions request to the v1 handler."""
    result = await v1_completions(tokenizer_manager, raw_request)
    return result
@@ -635,15 +651,26 @@ class Runtime:
def encode(
    self,
    prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
):
    """Encode *prompt* with the server and return the JSON response as a string.

    Routing depends on the prompt shape:
      - ``str`` or ``List[str]``: sent to the ``/encode`` embedding endpoint
        under the ``"text"`` key.
      - ``List[Dict]`` or ``List[List[Dict]]`` (chat-style conversations):
        sent to the ``/judge`` reward endpoint under the ``"conv"`` key.

    Note: the diff residue in the original span merged the pre- and
    post-change bodies; this is the cleaned post-change implementation.
    """
    # NOTE(review): an empty list would raise IndexError on prompt[0] —
    # callers are assumed to pass at least one item; confirm upstream.
    if isinstance(prompt, str) or isinstance(prompt[0], str):
        # Plain text -> embedding request.
        json_data = {"text": prompt}
        response = requests.post(
            self.url + "/encode",
            json=json_data,
        )
    else:
        # Conversation(s) of message dicts -> reward-model request.
        json_data = {"conv": prompt}
        response = requests.post(
            self.url + "/judge",
            json=json_data,
        )
    return json.dumps(response.json())
def __del__(self):