support internlm2-reward (#1994)
This commit is contained in:
@@ -210,6 +210,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
|
||||
or "MistralModel" in model_architectures
|
||||
or "LlamaForSequenceClassification" in model_architectures
|
||||
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
||||
or "InternLM2ForRewardModel" in model_architectures
|
||||
):
|
||||
return False
|
||||
else:
|
||||
|
||||
62
python/sglang/srt/models/internlm2_reward.py
Normal file
62
python/sglang/srt/models/internlm2_reward.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.internlm2 import InternLM2ForCausalLM, InternLM2Model
|
||||
|
||||
|
||||
class InternLM2ForRewardModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.model = InternLM2Model(config, quant_config)
|
||||
self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = True,
|
||||
) -> EmbeddingPoolerOutput:
|
||||
assert get_embedding, "InternLM2ForRewardModel is only used for embedding"
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
|
||||
scores = self.v_head(last_token_hidden)
|
||||
return EmbeddingPoolerOutput(scores)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
return InternLM2ForCausalLM.load_weights(self, weights)
|
||||
|
||||
|
||||
EntryClass = InternLM2ForRewardModel
|
||||
Reference in New Issue
Block a user