[Model] Support Qwen2ForSequenceClassification (#4609)
Co-authored-by: ximing.wxm <ximing.wxm@antgroup.com>
This commit is contained in:
@@ -54,6 +54,8 @@
|
||||
- `python -m sglang.launch_server --model-path internlm/internlm2-7b-reward --is-embedding --trust-remote-code`
|
||||
- Qwen2ForRewardModel
|
||||
- `python -m sglang.launch_server --model-path Qwen/Qwen2.5-Math-RM-72B --is-embedding --trust-remote-code --tp-size=4`
|
||||
- Qwen2ForSequenceClassification
|
||||
- `python -m sglang.launch_server --model-path jason9693/Qwen2.5-1.5B-apeach --is-embedding --trust-remote-code`
|
||||
## How to Support a New Language Model
|
||||
|
||||
To support a new model in SGLang, you only need to add a single file under [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models).
|
||||
|
||||
@@ -453,6 +453,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
|
||||
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
||||
or "InternLM2ForRewardModel" in model_architectures
|
||||
or "Qwen2ForRewardModel" in model_architectures
|
||||
or "Qwen2ForSequenceClassification" in model_architectures
|
||||
):
|
||||
return False
|
||||
else:
|
||||
|
||||
75
python/sglang/srt/models/qwen2_classification.py
Normal file
75
python/sglang/srt/models/qwen2_classification.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Qwen2Config
|
||||
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class Qwen2ForSequenceClassification(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
|
||||
|
||||
self.eos_token_id = config.eos_token_id
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = True,
|
||||
) -> EmbeddingPoolerOutput:
|
||||
assert (
|
||||
get_embedding
|
||||
), "Qwen2ForSequenceClassification is only used for embedding"
|
||||
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
logits = self.score(hidden_states)
|
||||
pooled_logits = self.pooler(logits, forward_batch).embeddings
|
||||
|
||||
return EmbeddingPoolerOutput(pooled_logits)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# Filter out lm_head weights of Qwen2ForCausalLM
|
||||
filtered_weights = [
|
||||
(name, w) for name, w in weights if not name.startswith("lm_head")
|
||||
]
|
||||
return Qwen2ForCausalLM.load_weights(self, filtered_weights)
|
||||
|
||||
|
||||
EntryClass = [
|
||||
Qwen2ForSequenceClassification,
|
||||
]
|
||||
@@ -13,18 +13,20 @@
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import get_similarities
|
||||
from sglang.test.test_utils import get_similarities, is_in_ci
|
||||
|
||||
MODELS = [
|
||||
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
|
||||
("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
|
||||
("marco/mcdse-2b-v1", 1, 1e-5),
|
||||
("jason9693/Qwen2.5-1.5B-apeach", 1, 1e-5),
|
||||
]
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
|
||||
@@ -91,7 +93,12 @@ class TestEmbeddingModels(unittest.TestCase):
|
||||
), "embeddings are not all close"
|
||||
|
||||
def test_prefill_logits(self):
|
||||
for model, tp_size, prefill_tolerance in MODELS:
|
||||
models_to_test = MODELS
|
||||
|
||||
if is_in_ci():
|
||||
models_to_test = [random.choice(MODELS)]
|
||||
|
||||
for model, tp_size, prefill_tolerance in models_to_test:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
self.assert_close_prefill_logits(
|
||||
DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
|
||||
|
||||
Reference in New Issue
Block a user