[Model] Support Qwen2ForSequenceClassification (#4609)

Co-authored-by: ximing.wxm <ximing.wxm@antgroup.com>
This commit is contained in:
Ximingwang-09
2025-03-25 10:13:44 +08:00
committed by GitHub
parent 4c584fc632
commit 22c3702e1e
4 changed files with 87 additions and 2 deletions

View File

@@ -13,18 +13,20 @@
# ==============================================================================
import multiprocessing as mp
import random
import unittest
import torch
from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities
from sglang.test.test_utils import get_similarities, is_in_ci
MODELS = [
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
("marco/mcdse-2b-v1", 1, 1e-5),
("jason9693/Qwen2.5-1.5B-apeach", 1, 1e-5),
]
TORCH_DTYPES = [torch.float16]
@@ -91,7 +93,12 @@ class TestEmbeddingModels(unittest.TestCase):
), "embeddings are not all close"
def test_prefill_logits(self):
for model, tp_size, prefill_tolerance in MODELS:
models_to_test = MODELS
if is_in_ci():
models_to_test = [random.choice(MODELS)]
for model, tp_size, prefill_tolerance in models_to_test:
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance