From 20b765a26e39f56ed24721f499c132c0dcd14c7c Mon Sep 17 00:00:00 2001 From: simveit <69345428+simveit@users.noreply.github.com> Date: Fri, 21 Feb 2025 23:38:21 +0100 Subject: [PATCH] Model: Support Qwen 72B RM model. (#3772) --- docs/references/supported_models.md | 3 +- python/sglang/srt/configs/model_config.py | 1 + python/sglang/srt/models/qwen2.py | 2 + python/sglang/srt/models/qwen2_rm.py | 70 +++++++++++++++++++++++ 4 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/models/qwen2_rm.py diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index f8b8306bc..43ba4c5fb 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -47,7 +47,8 @@ - `python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Gemma-2-27B-v0.2 --is-embedding` - InternLM2ForRewardModel - `python -m sglang.launch_server --model-path internlm/internlm2-7b-reward --is-embedding --trust-remote-code` - +- Qwen2ForRewardModel + - `python -m sglang.launch_server --model-path Qwen/Qwen2.5-Math-RM-72B --is-embedding --trust-remote-code --tp-size=4` ## How to Support a New Language Model To support a new model in SGLang, you only need to add a single file under [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models). diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index b4653dc87..b0fc512c1 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -389,6 +389,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal or "LlamaForSequenceClassification" in model_architectures or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures or "InternLM2ForRewardModel" in model_architectures + or "Qwen2ForRewardModel" in model_architectures ): return False else: diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 46b62f837..4afd9f2a3 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -379,6 +379,8 @@ class Qwen2ForCausalLM(nn.Module): continue if name.startswith("model.vision_tower") and name not in params_dict: continue + if name.startswith("lm_head"): + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: diff --git a/python/sglang/srt/models/qwen2_rm.py b/python/sglang/srt/models/qwen2_rm.py new file mode 100644 index 000000000..c7aaa7697 --- /dev/null +++ b/python/sglang/srt/models/qwen2_rm.py @@ -0,0 +1,70 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import Qwen2Config + +from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model + + +class Qwen2ForRewardModel(nn.Module): + def __init__( + self, + config: Qwen2Config, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.num_labels = 1 + self.model = Qwen2Model(config, quant_config=quant_config) + self.score = nn.Sequential( + nn.Linear(config.hidden_size, config.hidden_size), + nn.ReLU(), + nn.Linear(config.hidden_size, self.num_labels), + ) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False) + + self.eos_token_id = config.eos_token_id + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = True, + ) -> EmbeddingPoolerOutput: + assert get_embedding, "Qwen2ForRewardModel is only used for embedding" + + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + logits = self.score(hidden_states) + pooled_logits = self.pooler(logits, forward_batch).embeddings + + return EmbeddingPoolerOutput(pooled_logits) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + return Qwen2ForCausalLM.load_weights(self, weights) + + +EntryClass = [ + Qwen2ForRewardModel, +]