Add support for Qwen2-VL-based embedding models (#2055)
This commit is contained in:
@@ -44,6 +44,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
)
|
||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
@@ -559,6 +560,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
@@ -577,6 +579,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
get_embedding: bool = False,
|
||||
):
|
||||
"""Run forward pass for Qwen2-VL.
|
||||
|
||||
@@ -599,8 +602,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
image_inputs = [
|
||||
img for img in forward_batch.image_inputs if img is not None
|
||||
]
|
||||
|
||||
positions = forward_batch.mrope_positions
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
positions = forward_batch.mrope_positions
|
||||
if (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or image_inputs is None
|
||||
@@ -655,9 +658,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
forward_batch=forward_batch,
|
||||
input_embeds=inputs_embeds,
|
||||
)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
|
||||
Reference in New Issue
Block a user