Let reward model take text inputs instead of message lists (#1907)
Co-authored-by: Kyle Corbitt <kyle@corbt.com>
This commit is contained in:
@@ -88,11 +88,8 @@ CONTEXT_LENGTH_KEYS = [
|
||||
|
||||
|
||||
def get_context_length(config):
|
||||
"""Get the context length of a model from a huggingface model configs.
|
||||
And here the config should be text_config part if the model is a multimodal
|
||||
LLM.
|
||||
"""
|
||||
text_config = getattr(config, "text_config", config)
|
||||
"""Get the context length of a model from a huggingface model configs."""
|
||||
text_config = config
|
||||
rope_scaling = getattr(text_config, "rope_scaling", None)
|
||||
if rope_scaling:
|
||||
rope_scaling_factor = rope_scaling.get("factor", 1)
|
||||
|
||||
@@ -238,7 +238,7 @@ class EmbeddingReqInput:
|
||||
self.rid = uuid.uuid4().hex
|
||||
if self.sampling_params is None:
|
||||
self.sampling_params = {}
|
||||
self.sampling_params["max_new_tokens"] = 1
|
||||
self.sampling_params["max_new_tokens"] = 0
|
||||
else:
|
||||
if self.rid is None:
|
||||
self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
|
||||
@@ -248,7 +248,7 @@ class EmbeddingReqInput:
|
||||
if self.sampling_params is None:
|
||||
self.sampling_params = [{}] * self.batch_size
|
||||
for i in range(self.batch_size):
|
||||
self.sampling_params[i]["max_new_tokens"] = 1
|
||||
self.sampling_params[i]["max_new_tokens"] = 0
|
||||
|
||||
def regenerate_rid(self):
|
||||
self.rid = uuid.uuid4().hex
|
||||
|
||||
@@ -34,6 +34,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
||||
@@ -303,6 +304,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
@@ -311,11 +313,15 @@ class LlamaForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = False,
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
def get_hidden_dim(self, module_name):
|
||||
# return input_dim, output_dim
|
||||
|
||||
@@ -36,9 +36,7 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
def load_weights(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
|
||||
):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -49,7 +47,7 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
]
|
||||
params_dict = dict(self.model.named_parameters())
|
||||
|
||||
def load_weights_per_param(name, loaded_weight):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
return
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
@@ -78,12 +76,6 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if name is None or loaded_weight is None:
|
||||
for name, loaded_weight in weights:
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
else:
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
|
||||
|
||||
class MistralModel(LlamaEmbeddingModel):
|
||||
pass
|
||||
|
||||
@@ -52,7 +52,12 @@ class LlamaForSequenceClassification(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = True,
|
||||
) -> EmbeddingPoolerOutput:
|
||||
assert (
|
||||
get_embedding
|
||||
), "LlamaForSequenceClassification is only used for embedding"
|
||||
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
scores = self.score(hidden_states)
|
||||
|
||||
|
||||
@@ -618,7 +618,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
||||
for i, image in enumerate(forward_batch.image_inputs):
|
||||
if image == None:
|
||||
if image is None:
|
||||
continue
|
||||
start_idx = extend_start_loc_cpu[i]
|
||||
prefix_len = prefix_lens_cpu[i]
|
||||
|
||||
@@ -254,7 +254,7 @@ app.put("/encode")(encode_request)
|
||||
|
||||
|
||||
async def judge_request(obj: EmbeddingReqInput, request: Request):
|
||||
"""Handle a reward model request."""
|
||||
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
||||
try:
|
||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||
return ret
|
||||
@@ -696,24 +696,8 @@ class Runtime:
|
||||
self,
|
||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||
):
|
||||
if isinstance(prompt, str) or isinstance(prompt[0], str):
|
||||
# embedding
|
||||
json_data = {
|
||||
"text": prompt,
|
||||
}
|
||||
response = requests.post(
|
||||
self.url + "/encode",
|
||||
json=json_data,
|
||||
)
|
||||
else:
|
||||
# reward
|
||||
json_data = {
|
||||
"conv": prompt,
|
||||
}
|
||||
response = requests.post(
|
||||
self.url + "/judge",
|
||||
json=json_data,
|
||||
)
|
||||
json_data = {"text": prompt}
|
||||
response = requests.post(self.url + "/encode", json=json_data)
|
||||
return json.dumps(response.json())
|
||||
|
||||
def __del__(self):
|
||||
|
||||
@@ -273,6 +273,7 @@ class SRTRunner:
|
||||
disable_cuda_graph=disable_cuda_graph,
|
||||
disable_radix_cache=disable_radix_cache,
|
||||
)
|
||||
self.tokenizer = get_tokenizer(model_path)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -366,7 +367,7 @@ class SRTRunner:
|
||||
return ModelOutput(embed_logits=logits)
|
||||
else:
|
||||
scores = [x["embedding"][0] for x in response]
|
||||
return ModelOutput(scores=logits)
|
||||
return ModelOutput(scores=scores)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user