Let reward model take text inputs instead of message lists (#1907)

Co-authored-by: Kyle Corbitt <kyle@corbt.com>
This commit is contained in:
Lianmin Zheng
2024-11-03 13:27:12 -08:00
committed by GitHub
parent 793b79dbe9
commit 2ce32db6fb
12 changed files with 43 additions and 58 deletions

View File

@@ -30,6 +30,10 @@ TORCH_DTYPES = [torch.float16]
class TestEmbeddingModels(unittest.TestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
def assert_close_prefill_logits(
self,
prompts,
@@ -74,9 +78,4 @@ class TestEmbeddingModels(unittest.TestCase):
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main()

View File

@@ -63,9 +63,10 @@ TORCH_DTYPES = [torch.float16]
class TestGenerationModels(unittest.TestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn")
mp.set_start_method("spawn", force=True)
def assert_close_logits_and_output_strs(
self,

View File

@@ -18,10 +18,10 @@ import unittest
import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.runners import HFRunner, SRTRunner
MODELS = [
("LxzGordon/URM-LLaMa-3.1-8B", 1, 2e-2),
("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2),
]
TORCH_DTYPES = [torch.float16]
@@ -43,6 +43,10 @@ CONVS = [
class TestRewardModels(unittest.TestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
def assert_close_reward_scores(
self,
convs,
@@ -63,12 +67,13 @@ class TestRewardModels(unittest.TestCase):
torch_dtype=torch_dtype,
model_type="reward",
) as srt_runner:
srt_outputs = srt_runner.forward(convs)
prompts = srt_runner.tokenizer.apply_chat_template(convs, tokenize=False)
srt_outputs = srt_runner.forward(prompts)
hf_scores = torch.tensor(hf_outputs.scores)
srt_scores = torch.tensor(srt_outputs.scores)
print(hf_scores)
print(srt_scores)
print(f"{hf_scores=}")
print(f"{srt_scores=}")
assert torch.all(
abs(hf_scores - srt_scores) < tolerance
@@ -83,9 +88,4 @@ class TestRewardModels(unittest.TestCase):
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main()

View File

@@ -8,7 +8,7 @@ suites = {
"models/test_embedding_models.py",
"models/test_generation_models.py",
"models/test_lora.py",
# "models/test_reward_models.py",
"models/test_reward_models.py",
"sampling/penaltylib",
"test_chunked_prefill.py",
"test_double_sparsity.py",