Organize image inputs (#1531)

Author: Liangsheng Yin
Date: 2024-09-28 23:28:55 -07:00
Committed by: GitHub
Parent: e165a9fc1b
Commit: fd9ad817ec

8 changed files with 121 additions and 132 deletions


@@ -35,25 +35,22 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 
 
 class LlavaBaseForCausalLM(nn.Module):
-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+
         # hardcode for spatial_unpad + anyres
         image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
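
For orientation, here is a minimal sketch of the container this refactor threads through the model. Only the five fields visible in this diff are included; the real `ImageInputs` in `sglang.srt.managers.schedule_batch` may carry more, and the field types here are assumptions.

```python
from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class ImageInputs:
    """Sketch of the per-request image container implied by this diff."""

    pixel_values: Optional[np.ndarray] = None      # raw image tensor(s)
    image_sizes: Optional[List[List[int]]] = None  # (height, width) per image
    pad_values: Optional[List[int]] = None         # token ids used to pad image spans
    image_offsets: Optional[List[int]] = None      # filled in by pad_input_ids
    modalities: Optional[List[str]] = None         # e.g. "image", "multi-images"
```

Bundling these fields lets the new `pad_input_ids(self, input_ids, image_inputs)` signature replace four loose parameters with one object.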
@@ -92,8 +89,8 @@ class LlavaBaseForCausalLM(nn.Module):
                         new_w = int(new_w // times)
                 new_image_feature_len += new_h * (new_w + 1)
 
-            pad_ids = pad_value * (
-                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            pad_ids = pad_values * (
+                (new_image_feature_len + len(pad_values)) // len(pad_values)
             )
             # print("calculated new_image_feature_len: ", new_image_feature_len)
             try:
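
This hunk only renames `pad_value` to `pad_values`; the arithmetic is unchanged. The expression repeats the per-image pad signature until it covers the computed image span: `(a + b) // b` rounds up with up to one repetition of slack, and the surrounding code (outside this hunk) slices to the exact length. A toy run with made-up numbers:

```python
# Toy numbers, not from the commit.
pad_values = [101, 102, 103]  # per-image pad "signature"
new_image_feature_len = 7     # image span to cover

# Repeat the signature until it is at least new_image_feature_len long.
pad_ids = pad_values * ((new_image_feature_len + len(pad_values)) // len(pad_values))
assert len(pad_ids) >= new_image_feature_len  # 9 >= 7 here

# The caller splices pad_ids[:new_image_feature_len] into input_ids
# at the image token's offset (visible in the next hunk's context).
print(pad_ids[:new_image_feature_len])  # [101, 102, 103, 101, 102, 103, 101]
```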
@@ -107,7 +104,9 @@ class LlavaBaseForCausalLM(nn.Module):
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
-        return input_ids, offset_list
+
+        image_inputs.image_offsets = offset_list
+        return input_ids
 
     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
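
The return contract changes here: `pad_input_ids` used to hand back `(input_ids, offset_list)`, and now returns only `input_ids`, recording the offsets on the `ImageInputs` object instead. A hedged caller-side sketch of the new contract, reusing the `ImageInputs` sketch above; the stub below mimics the side-effect pattern, not the real LLaVA padding logic, and the token id is a placeholder:

```python
IMAGE_TOKEN_ID = 32000  # placeholder <image> token id, not from the commit


def pad_input_ids_stub(input_ids, image_inputs):
    # One image for simplicity; the real method loops over image_sizes.
    offset = input_ids.index(IMAGE_TOKEN_ID)
    pad_ids = image_inputs.pad_values * 3         # stand-in for the real math
    image_inputs.image_offsets = [offset]         # side effect replaces 2nd return value
    return input_ids[:offset] + pad_ids + input_ids[offset + 1 :]


inputs = ImageInputs(pad_values=[7, 8, 9])
ids = pad_input_ids_stub([1, 2, IMAGE_TOKEN_ID, 3], inputs)
print(ids, inputs.image_offsets)  # padded ids; offsets now ride on the object
```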
@@ -132,32 +131,39 @@ class LlavaBaseForCausalLM(nn.Module):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
+        image_inputs = input_metadata.image_inputs
+
         if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
 
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
-            for modalities in input_metadata.modalities:
-                if modalities is not None:
-                    modalities_list.extend(modalities)
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.modalities is not None:
+                    modalities_list.extend(im.modalities)
+                if im and im.image_offsets is not None:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)
 
             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
 
             # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
             start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            need_vision = start_positions <= np.array(max_image_offset)
 
             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_sizes = [
+                    image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]
 
                 ########## Encode Image ########
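
The gating above decides, per request in the batch, whether the vision tower must run: a request whose extend chunk starts past its last image token, or that has no image at all (sentinel -1), skips the vision path. A standalone numpy illustration with toy values:

```python
import numpy as np

# Per-request largest image-token offset; -1 is the "no image" sentinel.
max_image_offset = [5, -1, 12]
# Position at which each request's extend chunk starts (toy values).
start_positions = np.array([0, 0, 20])

# Vision is needed only if the chunk begins at or before the last image token:
# request 0 overlaps its image, request 1 has none, request 2 resumes past it.
need_vision = start_positions <= np.array(max_image_offset)
print(need_vision)  # [ True False False]
```

Computing `max_image_offset` inside the same loop that flattens `modalities` is what lets the old per-batch `np.array([...])` comprehension, and the three loose forward arguments, be deleted.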