[Model] Support Qwen2ForSequenceClassification (#4609)
Co-authored-by: ximing.wxm <ximing.wxm@antgroup.com>
This commit is contained in:
@@ -13,18 +13,20 @@
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import get_similarities
|
||||
from sglang.test.test_utils import get_similarities, is_in_ci
|
||||
|
||||
MODELS = [
|
||||
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
|
||||
("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
|
||||
("marco/mcdse-2b-v1", 1, 1e-5),
|
||||
("jason9693/Qwen2.5-1.5B-apeach", 1, 1e-5),
|
||||
]
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
|
||||
@@ -91,7 +93,12 @@ class TestEmbeddingModels(unittest.TestCase):
|
||||
), "embeddings are not all close"
|
||||
|
||||
def test_prefill_logits(self):
|
||||
for model, tp_size, prefill_tolerance in MODELS:
|
||||
models_to_test = MODELS
|
||||
|
||||
if is_in_ci():
|
||||
models_to_test = [random.choice(MODELS)]
|
||||
|
||||
for model, tp_size, prefill_tolerance in models_to_test:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
self.assert_close_prefill_logits(
|
||||
DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
|
||||
|
||||
Reference in New Issue
Block a user