Let reward model take text inputs instead of message lists (#1907)
Co-authored-by: Kyle Corbitt <kyle@corbt.com>
This commit is contained in:
@@ -30,6 +30,10 @@ TORCH_DTYPES = [torch.float16]
|
||||
|
||||
class TestEmbeddingModels(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def assert_close_prefill_logits(
|
||||
self,
|
||||
prompts,
|
||||
@@ -74,9 +78,4 @@ class TestEmbeddingModels(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
unittest.main()
|
||||
|
||||
@@ -63,9 +63,10 @@ TORCH_DTYPES = [torch.float16]
|
||||
|
||||
|
||||
class TestGenerationModels(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn")
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def assert_close_logits_and_output_strs(
|
||||
self,
|
||||
|
||||
@@ -18,10 +18,10 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||
from sglang.test.runners import HFRunner, SRTRunner
|
||||
|
||||
MODELS = [
|
||||
("LxzGordon/URM-LLaMa-3.1-8B", 1, 2e-2),
|
||||
("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2),
|
||||
]
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
|
||||
@@ -43,6 +43,10 @@ CONVS = [
|
||||
|
||||
class TestRewardModels(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def assert_close_reward_scores(
|
||||
self,
|
||||
convs,
|
||||
@@ -63,12 +67,13 @@ class TestRewardModels(unittest.TestCase):
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="reward",
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.forward(convs)
|
||||
prompts = srt_runner.tokenizer.apply_chat_template(convs, tokenize=False)
|
||||
srt_outputs = srt_runner.forward(prompts)
|
||||
|
||||
hf_scores = torch.tensor(hf_outputs.scores)
|
||||
srt_scores = torch.tensor(srt_outputs.scores)
|
||||
print(hf_scores)
|
||||
print(srt_scores)
|
||||
print(f"{hf_scores=}")
|
||||
print(f"{srt_scores=}")
|
||||
|
||||
assert torch.all(
|
||||
abs(hf_scores - srt_scores) < tolerance
|
||||
@@ -83,9 +88,4 @@ class TestRewardModels(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
unittest.main()
|
||||
|
||||
@@ -8,7 +8,7 @@ suites = {
|
||||
"models/test_embedding_models.py",
|
||||
"models/test_generation_models.py",
|
||||
"models/test_lora.py",
|
||||
# "models/test_reward_models.py",
|
||||
"models/test_reward_models.py",
|
||||
"sampling/penaltylib",
|
||||
"test_chunked_prefill.py",
|
||||
"test_double_sparsity.py",
|
||||
|
||||
Reference in New Issue
Block a user