ci: simplify multi-modality tests by using mixins (#9006)
This commit is contained in:
@@ -217,9 +217,9 @@ class BaseMultimodalProcessor(ABC):
|
||||
if videos:
|
||||
kwargs["videos"] = videos
|
||||
if audios:
|
||||
if self.arch in {
|
||||
"Gemma3nForConditionalGeneration",
|
||||
"Qwen2AudioForConditionalGeneration",
|
||||
if self._processor.__class__.__name__ in {
|
||||
"Gemma3nProcessor",
|
||||
"Qwen2AudioProcessor",
|
||||
}:
|
||||
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
|
||||
kwargs["audio"] = audios
|
||||
|
||||
@@ -18,7 +18,7 @@ from sglang.srt.models.llavavid import LlavaVidForCausalLM
|
||||
from sglang.srt.models.mistral import Mistral3ForConditionalGeneration
|
||||
from sglang.srt.multimodal.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
|
||||
from sglang.srt.utils import load_image, logger
|
||||
from sglang.srt.utils import ImageData, load_image, logger
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(
|
||||
image_data: Union[str, bytes],
|
||||
image_data: Union[str, bytes, ImageData],
|
||||
image_aspect_ratio: Optional[str] = None,
|
||||
image_grid_pinpoints: Optional[str] = None,
|
||||
processor=None,
|
||||
@@ -44,10 +44,11 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
image_processor = processor.image_processor
|
||||
|
||||
try:
|
||||
image, image_size = load_image(image_data)
|
||||
url = image_data.url if isinstance(image_data, ImageData) else image_data
|
||||
image, image_size = load_image(url)
|
||||
if image_size is not None:
|
||||
# It is a video with multiple images
|
||||
image_hash = hash(image_data)
|
||||
image_hash = hash(url)
|
||||
pixel_values = image_processor(image)["pixel_values"]
|
||||
for _ in range(len(pixel_values)):
|
||||
pixel_values[_] = pixel_values[_].astype(np.float16)
|
||||
@@ -55,7 +56,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
return pixel_values, image_hash, image_size
|
||||
else:
|
||||
# It is an image
|
||||
image_hash = hash(image_data)
|
||||
image_hash = hash(url)
|
||||
if image_aspect_ratio == "pad":
|
||||
image = expand2square(
|
||||
image,
|
||||
@@ -82,7 +83,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
async def _process_single_image(
|
||||
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
||||
self,
|
||||
image_data: Union[bytes, str, ImageData],
|
||||
aspect_ratio: str,
|
||||
grid_pinpoints: str,
|
||||
):
|
||||
if self.cpu_executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -104,7 +108,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
image_data: List[Union[str, bytes, ImageData]],
|
||||
input_text,
|
||||
request_obj,
|
||||
*args,
|
||||
|
||||
Reference in New Issue
Block a user