From f18b9c72520dc403c6cc00d57321f499ca42803f Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 12 Nov 2024 07:09:58 +0800 Subject: [PATCH] support internlm2-reward (#1994) --- docs/references/supported_models.md | 2 + python/sglang/srt/configs/model_config.py | 1 + python/sglang/srt/models/internlm2_reward.py | 62 ++++++++++++++++++++ 3 files changed, 65 insertions(+) create mode 100644 python/sglang/srt/models/internlm2_reward.py diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index ce178280b..3be83aaee 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -42,6 +42,8 @@ - `python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --is-embedding` - Gemma2ForSequenceClassification - `python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Gemma-2-27B-v0.2 --is-embedding` +- InternLM2ForRewardModel + - `python -m sglang.launch_server --model-path internlm/internlm2-7b-reward --is-embedding --trust-remote-code` ## How to Support a New Model diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 2ce6d7459..b14d2c3d4 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -210,6 +210,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal or "MistralModel" in model_architectures or "LlamaForSequenceClassification" in model_architectures or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures + or "InternLM2ForRewardModel" in model_architectures ): return False else: diff --git a/python/sglang/srt/models/internlm2_reward.py b/python/sglang/srt/models/internlm2_reward.py new file mode 100644 index 000000000..7ab6d034a --- /dev/null +++ b/python/sglang/srt/models/internlm2_reward.py @@ -0,0 +1,62 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.internlm2 import InternLM2ForCausalLM, InternLM2Model + + +class InternLM2ForRewardModel(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config=None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + self.model = InternLM2Model(config, quant_config) + self.v_head = nn.Linear(config.hidden_size, 1, bias=False) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = True, + ) -> EmbeddingPoolerOutput: + assert get_embedding, "InternLM2ForRewardModel is only used for embedding" + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings + scores = self.v_head(last_token_hidden) + return EmbeddingPoolerOutput(scores) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + return InternLM2ForCausalLM.load_weights(self, weights) + + +EntryClass = InternLM2ForRewardModel