Add support for Qwen2-VL-based embedding models (#2055)
This commit is contained in:
@@ -44,6 +44,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
)
|
||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
@@ -559,6 +560,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
@@ -577,6 +579,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
get_embedding: bool = False,
|
||||
):
|
||||
"""Run forward pass for Qwen2-VL.
|
||||
|
||||
@@ -599,8 +602,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
image_inputs = [
|
||||
img for img in forward_batch.image_inputs if img is not None
|
||||
]
|
||||
|
||||
positions = forward_batch.mrope_positions
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
positions = forward_batch.mrope_positions
|
||||
if (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or image_inputs is None
|
||||
@@ -655,9 +658,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
forward_batch=forward_batch,
|
||||
input_embeds=inputs_embeds,
|
||||
)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
|
||||
@@ -58,6 +58,28 @@ def get_top_logprobs(logits, k):
|
||||
return logprobs
|
||||
|
||||
|
||||
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sentence_transformers.util import is_sentence_transformer_model
|
||||
|
||||
if is_sentence_transformer_model(model_path):
|
||||
model = SentenceTransformer(
|
||||
model_path,
|
||||
model_kwargs={"torch_dtype": torch_dtype},
|
||||
)
|
||||
else: # if no pre-trained sentence-transformers model
|
||||
from sentence_transformers import models
|
||||
|
||||
word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype)
|
||||
pooling_model = models.Pooling(
|
||||
word_embedding_model.get_word_embedding_dimension(),
|
||||
pooling_mode="lasttoken",
|
||||
)
|
||||
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
|
||||
|
||||
return model.cuda()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelOutput:
|
||||
output_strs: List[str] = None
|
||||
@@ -114,12 +136,9 @@ class HFRunner:
|
||||
low_cpu_mem_usage=True,
|
||||
).cuda()
|
||||
elif self.model_type == "embedding":
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
self.model = SentenceTransformer(
|
||||
model_path,
|
||||
model_kwargs={"torch_dtype": torch_dtype},
|
||||
).cuda()
|
||||
self.model = _get_sentence_transformer_embedding_model(
|
||||
model_path, torch_dtype
|
||||
)
|
||||
elif self.model_type == "reward":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
|
||||
Reference in New Issue
Block a user