[VLM RLHF] Take Image input for verl vlm rollout (#4915)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: GeLee <leege233@gmail.com>
This commit is contained in:
@@ -151,10 +151,6 @@ class Engine:
|
|||||||
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
||||||
Please refer to `GenerateReqInput` for the documentation.
|
Please refer to `GenerateReqInput` for the documentation.
|
||||||
"""
|
"""
|
||||||
modalities_list = []
|
|
||||||
if image_data is not None:
|
|
||||||
modalities_list.append("image")
|
|
||||||
|
|
||||||
obj = GenerateReqInput(
|
obj = GenerateReqInput(
|
||||||
text=prompt,
|
text=prompt,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -165,7 +161,6 @@ class Engine:
|
|||||||
top_logprobs_num=top_logprobs_num,
|
top_logprobs_num=top_logprobs_num,
|
||||||
token_ids_logprob=token_ids_logprob,
|
token_ids_logprob=token_ids_logprob,
|
||||||
lora_path=lora_path,
|
lora_path=lora_path,
|
||||||
modalities=modalities_list,
|
|
||||||
custom_logit_processor=custom_logit_processor,
|
custom_logit_processor=custom_logit_processor,
|
||||||
return_hidden_states=return_hidden_states,
|
return_hidden_states=return_hidden_states,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
|||||||
@@ -139,8 +139,6 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
else:
|
else:
|
||||||
multimodal_tokens.image_token = multimodal_tokens.image_token
|
multimodal_tokens.image_token = multimodal_tokens.image_token
|
||||||
|
|
||||||
assert isinstance(prompt, str)
|
|
||||||
|
|
||||||
if isinstance(prompt, list) and return_text:
|
if isinstance(prompt, list) and return_text:
|
||||||
assert len(prompt) and isinstance(prompt[0], int)
|
assert len(prompt) and isinstance(prompt[0], int)
|
||||||
prompt = self._processor.tokenizer.decode(prompt)
|
prompt = self._processor.tokenizer.decode(prompt)
|
||||||
@@ -204,7 +202,16 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
image_sizes += frames[0].size * len(frames)
|
image_sizes += frames[0].size * len(frames)
|
||||||
hashes += [hash(image_file)] * len(frames)
|
|
||||||
|
# Generate a hashable value for the image file
|
||||||
|
if isinstance(image_file, Image.Image):
|
||||||
|
# For PIL.Image objects, use the ID as a hashable value
|
||||||
|
hash_value = hash(id(image_file))
|
||||||
|
else:
|
||||||
|
# For other types (strings, etc.), use the regular hash
|
||||||
|
hash_value = hash(image_file)
|
||||||
|
|
||||||
|
hashes += [hash_value] * len(frames)
|
||||||
images += frames
|
images += frames
|
||||||
image_index += 1
|
image_index += 1
|
||||||
if frames_to_process != 0:
|
if frames_to_process != 0:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import List, Union
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
|
|||||||
@@ -566,10 +566,14 @@ def encode_video(video_path, frame_count_limit=None):
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
|
def load_image(
|
||||||
|
image_file: Union[Image.Image, str, bytes]
|
||||||
|
) -> tuple[Image.Image, tuple[int, int]]:
|
||||||
image = image_size = None
|
image = image_size = None
|
||||||
|
if isinstance(image_file, Image.Image):
|
||||||
if isinstance(image_file, bytes):
|
image = image_file
|
||||||
|
image_size = (image.width, image.height)
|
||||||
|
elif isinstance(image_file, bytes):
|
||||||
image = Image.open(BytesIO(image_file))
|
image = Image.open(BytesIO(image_file))
|
||||||
elif image_file.startswith("http://") or image_file.startswith("https://"):
|
elif image_file.startswith("http://") or image_file.startswith("https://"):
|
||||||
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
|
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
|
||||||
|
|||||||
Reference in New Issue
Block a user