[feat] Add detail in image_data (#8596)
This commit is contained in:
@@ -30,8 +30,10 @@ import re
|
|||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
from sglang.srt.utils import read_system_prompt_from_file
|
from sglang.srt.utils import ImageData, read_system_prompt_from_file
|
||||||
|
|
||||||
|
|
||||||
class SeparatorStyle(IntEnum):
|
class SeparatorStyle(IntEnum):
|
||||||
@@ -91,7 +93,7 @@ class Conversation:
|
|||||||
video_token: str = "<video>"
|
video_token: str = "<video>"
|
||||||
audio_token: str = "<audio>"
|
audio_token: str = "<audio>"
|
||||||
|
|
||||||
image_data: Optional[List[str]] = None
|
image_data: Optional[List[ImageData]] = None
|
||||||
video_data: Optional[List[str]] = None
|
video_data: Optional[List[str]] = None
|
||||||
modalities: Optional[List[str]] = None
|
modalities: Optional[List[str]] = None
|
||||||
stop_token_ids: Optional[int] = None
|
stop_token_ids: Optional[int] = None
|
||||||
@@ -381,9 +383,9 @@ class Conversation:
|
|||||||
"""Append a new message."""
|
"""Append a new message."""
|
||||||
self.messages.append([role, message])
|
self.messages.append([role, message])
|
||||||
|
|
||||||
def append_image(self, image: str):
|
def append_image(self, image: str, detail: Literal["auto", "low", "high"]):
|
||||||
"""Append a new image."""
|
"""Append a new image."""
|
||||||
self.image_data.append(image)
|
self.image_data.append(ImageData(url=image, detail=detail))
|
||||||
|
|
||||||
def append_video(self, video: str):
|
def append_video(self, video: str):
|
||||||
"""Append a new video."""
|
"""Append a new video."""
|
||||||
@@ -627,7 +629,9 @@ def generate_chat_conv(
|
|||||||
real_content = image_token + real_content
|
real_content = image_token + real_content
|
||||||
else:
|
else:
|
||||||
real_content += image_token
|
real_content += image_token
|
||||||
conv.append_image(content.image_url.url)
|
conv.append_image(
|
||||||
|
content.image_url.url, content.image_url.detail
|
||||||
|
)
|
||||||
elif content.type == "video_url":
|
elif content.type == "video_url":
|
||||||
real_content += video_token
|
real_content += video_token
|
||||||
conv.append_video(content.video_url.url)
|
conv.append_video(content.video_url.url)
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import logging
|
|||||||
import jinja2
|
import jinja2
|
||||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||||
|
|
||||||
|
from sglang.srt.utils import ImageData
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -140,7 +142,12 @@ def process_content_for_template_format(
|
|||||||
chunk_type = chunk.get("type")
|
chunk_type = chunk.get("type")
|
||||||
|
|
||||||
if chunk_type == "image_url":
|
if chunk_type == "image_url":
|
||||||
image_data.append(chunk["image_url"]["url"])
|
image_data.append(
|
||||||
|
ImageData(
|
||||||
|
url=chunk["image_url"]["url"],
|
||||||
|
detail=chunk["image_url"].get("detail", "auto"),
|
||||||
|
)
|
||||||
|
)
|
||||||
if chunk.get("modalities"):
|
if chunk.get("modalities"):
|
||||||
modalities.append(chunk.get("modalities"))
|
modalities.append(chunk.get("modalities"))
|
||||||
# Normalize to simple 'image' type for template compatibility
|
# Normalize to simple 'image' type for template compatibility
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from sglang.srt.lora.lora_registry import LoRARef
|
|||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
from sglang.srt.multimodal.mm_utils import has_valid_data
|
from sglang.srt.multimodal.mm_utils import has_valid_data
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
|
from sglang.srt.utils import ImageData
|
||||||
|
|
||||||
# Handle serialization of Image for pydantic
|
# Handle serialization of Image for pydantic
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -45,7 +46,7 @@ class SessionParams:
|
|||||||
|
|
||||||
# Type definitions for multimodal input data
|
# Type definitions for multimodal input data
|
||||||
# Individual data item types for each modality
|
# Individual data item types for each modality
|
||||||
ImageDataInputItem = Union[Image, str, Dict]
|
ImageDataInputItem = Union[Image, str, ImageData, Dict]
|
||||||
AudioDataInputItem = Union[str, Dict]
|
AudioDataInputItem = Union[str, Dict]
|
||||||
VideoDataInputItem = Union[str, Dict]
|
VideoDataInputItem = Union[str, Dict]
|
||||||
# Union type for any multimodal data item
|
# Union type for any multimodal data item
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ import traceback
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
@@ -84,6 +85,7 @@ from torch.library import Library
|
|||||||
from torch.profiler import ProfilerActivity, profile, record_function
|
from torch.profiler import ProfilerActivity, profile, record_function
|
||||||
from torch.utils._contextlib import _DecoratorContextManager
|
from torch.utils._contextlib import _DecoratorContextManager
|
||||||
from triton.runtime.cache import FileCacheManager
|
from triton.runtime.cache import FileCacheManager
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||||
|
|
||||||
@@ -736,9 +738,18 @@ def load_audio(
|
|||||||
return audio
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImageData:
|
||||||
|
url: str
|
||||||
|
detail: Optional[Literal["auto", "low", "high"]] = "auto"
|
||||||
|
|
||||||
|
|
||||||
def load_image(
|
def load_image(
|
||||||
image_file: Union[Image.Image, str, bytes],
|
image_file: Union[Image.Image, str, ImageData, bytes],
|
||||||
) -> tuple[Image.Image, tuple[int, int]]:
|
) -> tuple[Image.Image, tuple[int, int]]:
|
||||||
|
if isinstance(image_file, ImageData):
|
||||||
|
image_file = image_file.url
|
||||||
|
|
||||||
image = image_size = None
|
image = image_size = None
|
||||||
if isinstance(image_file, Image.Image):
|
if isinstance(image_file, Image.Image):
|
||||||
image = image_file
|
image = image_file
|
||||||
@@ -762,7 +773,7 @@ def load_image(
|
|||||||
elif isinstance(image_file, str):
|
elif isinstance(image_file, str):
|
||||||
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
|
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid image: {image}")
|
raise ValueError(f"Invalid image: {image_file}")
|
||||||
|
|
||||||
return image, image_size
|
return image, image_size
|
||||||
|
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class TestTemplateContentFormatDetection(CustomTestCase):
|
|||||||
|
|
||||||
# Check that image_data was extracted
|
# Check that image_data was extracted
|
||||||
self.assertEqual(len(image_data), 1)
|
self.assertEqual(len(image_data), 1)
|
||||||
self.assertEqual(image_data[0], "http://example.com/image.jpg")
|
self.assertEqual(image_data[0].url, "http://example.com/image.jpg")
|
||||||
|
|
||||||
# Check that content was normalized
|
# Check that content was normalized
|
||||||
expected_content = [
|
expected_content = [
|
||||||
|
|||||||
Reference in New Issue
Block a user