Let reward model take text inputs instead of message lists (#1907)
Co-authored-by: Kyle Corbitt <kyle@corbt.com>
This commit is contained in:
@@ -34,6 +34,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
||||
@@ -303,6 +304,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
@@ -311,11 +313,15 @@ class LlamaForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = False,
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
def get_hidden_dim(self, module_name):
|
||||
# return input_dim, output_dim
|
||||
|
||||
@@ -36,9 +36,7 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
def load_weights(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
|
||||
):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -49,7 +47,7 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
]
|
||||
params_dict = dict(self.model.named_parameters())
|
||||
|
||||
def load_weights_per_param(name, loaded_weight):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
return
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
@@ -78,12 +76,6 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if name is None or loaded_weight is None:
|
||||
for name, loaded_weight in weights:
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
else:
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
|
||||
|
||||
class MistralModel(LlamaEmbeddingModel):
|
||||
pass
|
||||
|
||||
@@ -52,7 +52,12 @@ class LlamaForSequenceClassification(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = True,
|
||||
) -> EmbeddingPoolerOutput:
|
||||
assert (
|
||||
get_embedding
|
||||
), "LlamaForSequenceClassification is only used for embedding"
|
||||
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
scores = self.score(hidden_states)
|
||||
|
||||
|
||||
@@ -618,7 +618,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
||||
for i, image in enumerate(forward_batch.image_inputs):
|
||||
if image == None:
|
||||
if image is None:
|
||||
continue
|
||||
start_idx = extend_start_loc_cpu[i]
|
||||
prefix_len = prefix_lens_cpu[i]
|
||||
|
||||
Reference in New Issue
Block a user