From f6f713797bcbc63d225136d66deaa00495cdedfe Mon Sep 17 00:00:00 2001
From: James Xu
Date: Thu, 21 Nov 2024 17:24:25 -0500
Subject: [PATCH] Add support for Qwen2-VL-based embedding models (#2055)

---
 README.md                                |  2 +-
 python/sglang/srt/models/qwen2_vl.py     | 17 +++++++++----
 python/sglang/test/runners.py            | 31 +++++++++++++++++++-----
 test/srt/models/test_embedding_models.py |  1 +
 4 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index f719b3d9b..2132ed860 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ The core features include:
 
 - **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
 - **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
-- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte) and reward models (Skywork), with easy extensibility for integrating new models.
+- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte, mcdse) and reward models (Skywork), with easy extensibility for integrating new models.
 - **Active Community**: SGLang is open-source and backed by an active community with industry adoption.
 
 ## Getting Started
diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py
index cfd2a2ce7..3d3876243 100644
--- a/python/sglang/srt/models/qwen2_vl.py
+++ b/python/sglang/srt/models/qwen2_vl.py
@@ -44,6 +44,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
 )
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -559,6 +560,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
 
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
@@ -577,6 +579,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        get_embedding: bool = False,
     ):
         """Run forward pass for Qwen2-VL.
 
@@ -599,8 +602,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_inputs = [
             img for img in forward_batch.image_inputs if img is not None
         ]
-
-        positions = forward_batch.mrope_positions
+        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            positions = forward_batch.mrope_positions
         if (
             forward_batch.forward_mode.is_decode()
             or image_inputs is None
@@ -655,9 +658,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
         )
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
-        )
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py
index 3870c4503..c622a1b25 100644
--- a/python/sglang/test/runners.py
+++ b/python/sglang/test/runners.py
@@ -58,6 +58,28 @@ def get_top_logprobs(logits, k):
     return logprobs
 
 
+def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
+    from sentence_transformers import SentenceTransformer
+    from sentence_transformers.util import is_sentence_transformer_model
+
+    if is_sentence_transformer_model(model_path):
+        model = SentenceTransformer(
+            model_path,
+            model_kwargs={"torch_dtype": torch_dtype},
+        )
+    else:  # if no pre-trained sentence-transformers model
+        from sentence_transformers import models
+
+        word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype)
+        pooling_model = models.Pooling(
+            word_embedding_model.get_word_embedding_dimension(),
+            pooling_mode="lasttoken",
+        )
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+    return model.cuda()
+
+
 @dataclass
 class ModelOutput:
     output_strs: List[str] = None
@@ -114,12 +136,9 @@ class HFRunner:
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":
-            from sentence_transformers import SentenceTransformer
-
-            self.model = SentenceTransformer(
-                model_path,
-                model_kwargs={"torch_dtype": torch_dtype},
-            ).cuda()
+            self.model = _get_sentence_transformer_embedding_model(
+                model_path, torch_dtype
+            )
         elif self.model_type == "reward":
             from transformers import AutoModelForSequenceClassification
 
diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py
index f3ed4cdd7..d04943b28 100644
--- a/test/srt/models/test_embedding_models.py
+++ b/test/srt/models/test_embedding_models.py
@@ -25,6 +25,7 @@ from sglang.test.test_utils import get_similarities
 
 MODELS = [
     ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
    ("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
+    ("marco/mcdse-2b-v1", 1, 1e-5),
 ]
 TORCH_DTYPES = [torch.float16]
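
For reviewers who want to exercise the new path locally, below is a minimal smoke-test sketch modeled on test/srt/models/test_embedding_models.py. It compares embeddings from the sentence-transformers reference path (HFRunner) against the SGLang runtime path (SRTRunner), which now routes Qwen2-VL through the pooler when get_embedding is set. The prompts are illustrative, and the exact runner keyword arguments and the `embed_logits` field of ModelOutput are assumptions based on the test file; verify them against sglang.test.runners on this branch.

```python
import torch

from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities

# Illustrative inputs; any short prompts work for a smoke test.
prompts = ["The capital of France is Paris.", "A quick brown fox."]
model_path = "marco/mcdse-2b-v1"  # the Qwen2-VL-based embedding model added to MODELS

# Reference embeddings via sentence-transformers (HF side).
with HFRunner(model_path, torch_dtype=torch.float16, model_type="embedding") as hf_runner:
    hf_outputs = hf_runner.forward(prompts)

# Embeddings from the SGLang runtime.
with SRTRunner(
    model_path, tp_size=1, torch_dtype=torch.float16, model_type="embedding"
) as srt_runner:
    srt_outputs = srt_runner.forward(prompts)

# Compare the two backends prompt by prompt with cosine similarity;
# the test expects values close to 1 within the per-model tolerance.
for i in range(len(prompts)):
    hf_embed = torch.tensor(hf_outputs.embed_logits[i])
    srt_embed = torch.tensor(srt_outputs.embed_logits[i])
    print(f"prompt {i}: similarity = {float(get_similarities(hf_embed, srt_embed)):.6f}")
```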