Organize image inputs (#1531)

This commit is contained in:
Liangsheng Yin
2024-09-28 23:28:55 -07:00
committed by GitHub
parent e165a9fc1b
commit fd9ad817ec
8 changed files with 121 additions and 132 deletions

View File

@@ -194,10 +194,9 @@ class TokenizerManager:
)
if self.is_generation:
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
obj.image_data if not_use_index else obj.image_data[index]
image_inputs = await self._get_image_inputs(
obj, obj.image_data if not_use_index else obj.image_data[index]
)
modalities = obj.modalities
return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index]
)
@@ -248,10 +247,7 @@ class TokenizerManager:
sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
obj.image_data[0]
)
modalities = obj.modalities
image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0]
@@ -262,15 +258,12 @@ class TokenizerManager:
rid,
input_text,
input_ids,
pixel_values,
image_hashes,
image_sizes,
image_inputs,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
obj.stream,
modalities,
(
obj.lora_path[index]
if isinstance(obj.lora_path, list)
@@ -369,24 +362,20 @@ class TokenizerManager:
sampling_params = self._get_sampling_params(obj.sampling_params[index])
if self.is_generation:
pixel_values, image_hashes, image_sizes = (
await self._get_pixel_values(obj.image_data[index])
image_inputs = await self._get_image_inputs(
obj, obj.image_data[index]
)
modalities = obj.modalities
tokenized_obj = TokenizedGenerateReqInput(
rid,
input_text,
input_ids,
pixel_values,
image_hashes,
image_sizes,
image_inputs,
sampling_params,
obj.return_logprob[index],
obj.logprob_start_len[index],
obj.top_logprobs_num[index],
obj.stream,
modalities,
(
obj.lora_path[index]
if isinstance(obj.lora_path, list)
@@ -697,10 +686,11 @@ class TokenizerManager:
)
return top_logprobs
async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
if not image_data:
return None, None, None
return None
# TODO: move this into a processor for each vision architecture
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
@@ -741,7 +731,12 @@ class TokenizerManager:
else:
raise ValueError(f"Invalid image data: {image_data}")
return pixel_values, image_hashes, image_sizes
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": obj.modalities,
}
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str