Let reward model take text inputs instead of message lists (#1907)
Co-authored-by: Kyle Corbitt <kyle@corbt.com>
This commit is contained in:
@@ -88,11 +88,8 @@ CONTEXT_LENGTH_KEYS = [
|
|||||||
|
|
||||||
|
|
||||||
def get_context_length(config):
|
def get_context_length(config):
|
||||||
"""Get the context length of a model from a huggingface model configs.
|
"""Get the context length of a model from a huggingface model configs."""
|
||||||
And here the config should be text_config part if the model is a multimodal
|
text_config = config
|
||||||
LLM.
|
|
||||||
"""
|
|
||||||
text_config = getattr(config, "text_config", config)
|
|
||||||
rope_scaling = getattr(text_config, "rope_scaling", None)
|
rope_scaling = getattr(text_config, "rope_scaling", None)
|
||||||
if rope_scaling:
|
if rope_scaling:
|
||||||
rope_scaling_factor = rope_scaling.get("factor", 1)
|
rope_scaling_factor = rope_scaling.get("factor", 1)
|
||||||
|
|||||||
@@ -238,7 +238,7 @@ class EmbeddingReqInput:
|
|||||||
self.rid = uuid.uuid4().hex
|
self.rid = uuid.uuid4().hex
|
||||||
if self.sampling_params is None:
|
if self.sampling_params is None:
|
||||||
self.sampling_params = {}
|
self.sampling_params = {}
|
||||||
self.sampling_params["max_new_tokens"] = 1
|
self.sampling_params["max_new_tokens"] = 0
|
||||||
else:
|
else:
|
||||||
if self.rid is None:
|
if self.rid is None:
|
||||||
self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
|
self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
|
||||||
@@ -248,7 +248,7 @@ class EmbeddingReqInput:
|
|||||||
if self.sampling_params is None:
|
if self.sampling_params is None:
|
||||||
self.sampling_params = [{}] * self.batch_size
|
self.sampling_params = [{}] * self.batch_size
|
||||||
for i in range(self.batch_size):
|
for i in range(self.batch_size):
|
||||||
self.sampling_params[i]["max_new_tokens"] = 1
|
self.sampling_params[i]["max_new_tokens"] = 0
|
||||||
|
|
||||||
def regenerate_rid(self):
|
def regenerate_rid(self):
|
||||||
self.rid = uuid.uuid4().hex
|
self.rid = uuid.uuid4().hex
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||||
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
||||||
@@ -303,6 +304,7 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
self.model = LlamaModel(config, quant_config=quant_config)
|
self.model = LlamaModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -311,11 +313,15 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
|
get_embedding: bool = False,
|
||||||
) -> LogitsProcessorOutput:
|
) -> LogitsProcessorOutput:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
if not get_embedding:
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
return self.logits_processor(
|
||||||
)
|
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.pooler(hidden_states, forward_batch)
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
def get_hidden_dim(self, module_name):
|
||||||
# return input_dim, output_dim
|
# return input_dim, output_dim
|
||||||
|
|||||||
@@ -36,9 +36,7 @@ class LlamaEmbeddingModel(nn.Module):
|
|||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.pooler(hidden_states, forward_batch)
|
return self.pooler(hidden_states, forward_batch)
|
||||||
|
|
||||||
def load_weights(
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
|
|
||||||
):
|
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
("qkv_proj", "q_proj", "q"),
|
("qkv_proj", "q_proj", "q"),
|
||||||
@@ -49,7 +47,7 @@ class LlamaEmbeddingModel(nn.Module):
|
|||||||
]
|
]
|
||||||
params_dict = dict(self.model.named_parameters())
|
params_dict = dict(self.model.named_parameters())
|
||||||
|
|
||||||
def load_weights_per_param(name, loaded_weight):
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||||
return
|
return
|
||||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||||
@@ -78,12 +76,6 @@ class LlamaEmbeddingModel(nn.Module):
|
|||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
if name is None or loaded_weight is None:
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
load_weights_per_param(name, loaded_weight)
|
|
||||||
else:
|
|
||||||
load_weights_per_param(name, loaded_weight)
|
|
||||||
|
|
||||||
|
|
||||||
class MistralModel(LlamaEmbeddingModel):
|
class MistralModel(LlamaEmbeddingModel):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -52,7 +52,12 @@ class LlamaForSequenceClassification(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
|
get_embedding: bool = True,
|
||||||
) -> EmbeddingPoolerOutput:
|
) -> EmbeddingPoolerOutput:
|
||||||
|
assert (
|
||||||
|
get_embedding
|
||||||
|
), "LlamaForSequenceClassification is only used for embedding"
|
||||||
|
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
scores = self.score(hidden_states)
|
scores = self.score(hidden_states)
|
||||||
|
|
||||||
|
|||||||
@@ -618,7 +618,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||||
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
||||||
for i, image in enumerate(forward_batch.image_inputs):
|
for i, image in enumerate(forward_batch.image_inputs):
|
||||||
if image == None:
|
if image is None:
|
||||||
continue
|
continue
|
||||||
start_idx = extend_start_loc_cpu[i]
|
start_idx = extend_start_loc_cpu[i]
|
||||||
prefix_len = prefix_lens_cpu[i]
|
prefix_len = prefix_lens_cpu[i]
|
||||||
|
|||||||
@@ -254,7 +254,7 @@ app.put("/encode")(encode_request)
|
|||||||
|
|
||||||
|
|
||||||
async def judge_request(obj: EmbeddingReqInput, request: Request):
|
async def judge_request(obj: EmbeddingReqInput, request: Request):
|
||||||
"""Handle a reward model request."""
|
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
||||||
try:
|
try:
|
||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
return ret
|
return ret
|
||||||
@@ -696,24 +696,8 @@ class Runtime:
|
|||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||||
):
|
):
|
||||||
if isinstance(prompt, str) or isinstance(prompt[0], str):
|
json_data = {"text": prompt}
|
||||||
# embedding
|
response = requests.post(self.url + "/encode", json=json_data)
|
||||||
json_data = {
|
|
||||||
"text": prompt,
|
|
||||||
}
|
|
||||||
response = requests.post(
|
|
||||||
self.url + "/encode",
|
|
||||||
json=json_data,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# reward
|
|
||||||
json_data = {
|
|
||||||
"conv": prompt,
|
|
||||||
}
|
|
||||||
response = requests.post(
|
|
||||||
self.url + "/judge",
|
|
||||||
json=json_data,
|
|
||||||
)
|
|
||||||
return json.dumps(response.json())
|
return json.dumps(response.json())
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|||||||
@@ -273,6 +273,7 @@ class SRTRunner:
|
|||||||
disable_cuda_graph=disable_cuda_graph,
|
disable_cuda_graph=disable_cuda_graph,
|
||||||
disable_radix_cache=disable_radix_cache,
|
disable_radix_cache=disable_radix_cache,
|
||||||
)
|
)
|
||||||
|
self.tokenizer = get_tokenizer(model_path)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -366,7 +367,7 @@ class SRTRunner:
|
|||||||
return ModelOutput(embed_logits=logits)
|
return ModelOutput(embed_logits=logits)
|
||||||
else:
|
else:
|
||||||
scores = [x["embedding"][0] for x in response]
|
scores = [x["embedding"][0] for x in response]
|
||||||
return ModelOutput(scores=logits)
|
return ModelOutput(scores=scores)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -30,6 +30,10 @@ TORCH_DTYPES = [torch.float16]
|
|||||||
|
|
||||||
class TestEmbeddingModels(unittest.TestCase):
|
class TestEmbeddingModels(unittest.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
def assert_close_prefill_logits(
|
def assert_close_prefill_logits(
|
||||||
self,
|
self,
|
||||||
prompts,
|
prompts,
|
||||||
@@ -74,9 +78,4 @@ class TestEmbeddingModels(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
|
||||||
mp.set_start_method("spawn")
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -63,9 +63,10 @@ TORCH_DTYPES = [torch.float16]
|
|||||||
|
|
||||||
|
|
||||||
class TestGenerationModels(unittest.TestCase):
|
class TestGenerationModels(unittest.TestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
mp.set_start_method("spawn")
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
def assert_close_logits_and_output_strs(
|
def assert_close_logits_and_output_strs(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -18,10 +18,10 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
from sglang.test.runners import HFRunner, SRTRunner
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
("LxzGordon/URM-LLaMa-3.1-8B", 1, 2e-2),
|
("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2),
|
||||||
]
|
]
|
||||||
TORCH_DTYPES = [torch.float16]
|
TORCH_DTYPES = [torch.float16]
|
||||||
|
|
||||||
@@ -43,6 +43,10 @@ CONVS = [
|
|||||||
|
|
||||||
class TestRewardModels(unittest.TestCase):
|
class TestRewardModels(unittest.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
def assert_close_reward_scores(
|
def assert_close_reward_scores(
|
||||||
self,
|
self,
|
||||||
convs,
|
convs,
|
||||||
@@ -63,12 +67,13 @@ class TestRewardModels(unittest.TestCase):
|
|||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
model_type="reward",
|
model_type="reward",
|
||||||
) as srt_runner:
|
) as srt_runner:
|
||||||
srt_outputs = srt_runner.forward(convs)
|
prompts = srt_runner.tokenizer.apply_chat_template(convs, tokenize=False)
|
||||||
|
srt_outputs = srt_runner.forward(prompts)
|
||||||
|
|
||||||
hf_scores = torch.tensor(hf_outputs.scores)
|
hf_scores = torch.tensor(hf_outputs.scores)
|
||||||
srt_scores = torch.tensor(srt_outputs.scores)
|
srt_scores = torch.tensor(srt_outputs.scores)
|
||||||
print(hf_scores)
|
print(f"{hf_scores=}")
|
||||||
print(srt_scores)
|
print(f"{srt_scores=}")
|
||||||
|
|
||||||
assert torch.all(
|
assert torch.all(
|
||||||
abs(hf_scores - srt_scores) < tolerance
|
abs(hf_scores - srt_scores) < tolerance
|
||||||
@@ -83,9 +88,4 @@ class TestRewardModels(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
|
||||||
mp.set_start_method("spawn")
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ suites = {
|
|||||||
"models/test_embedding_models.py",
|
"models/test_embedding_models.py",
|
||||||
"models/test_generation_models.py",
|
"models/test_generation_models.py",
|
||||||
"models/test_lora.py",
|
"models/test_lora.py",
|
||||||
# "models/test_reward_models.py",
|
"models/test_reward_models.py",
|
||||||
"sampling/penaltylib",
|
"sampling/penaltylib",
|
||||||
"test_chunked_prefill.py",
|
"test_chunked_prefill.py",
|
||||||
"test_double_sparsity.py",
|
"test_double_sparsity.py",
|
||||||
|
|||||||
Reference in New Issue
Block a user