Organize image inputs (#1531)
This commit is contained in:
@@ -35,25 +35,22 @@ from vllm.config import CacheConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.mm_utils import (
|
||||
get_anyres_image_grid_shape,
|
||||
unpad_image,
|
||||
unpad_image_shape,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.models.mistral import MistralForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
|
||||
|
||||
class LlavaBaseForCausalLM(nn.Module):
|
||||
def pad_input_ids(
|
||||
self,
|
||||
input_ids: List[int],
|
||||
pad_value: List[int],
|
||||
pixel_values: List,
|
||||
image_sizes: List[List[int]],
|
||||
):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
|
||||
|
||||
# hardcode for spatial_unpad + anyres
|
||||
image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
|
||||
offset_list = []
|
||||
@@ -92,8 +89,8 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
new_w = int(new_w // times)
|
||||
new_image_feature_len += new_h * (new_w + 1)
|
||||
|
||||
pad_ids = pad_value * (
|
||||
(new_image_feature_len + len(pad_value)) // len(pad_value)
|
||||
pad_ids = pad_values * (
|
||||
(new_image_feature_len + len(pad_values)) // len(pad_values)
|
||||
)
|
||||
# print("calculated new_image_feature_len: ", new_image_feature_len)
|
||||
try:
|
||||
@@ -107,7 +104,9 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
+ input_ids[offset + 1 :]
|
||||
)
|
||||
offset_list.append(offset)
|
||||
return input_ids, offset_list
|
||||
|
||||
image_inputs.image_offsets = offset_list
|
||||
return input_ids
|
||||
|
||||
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
@@ -132,32 +131,39 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
pixel_values: Optional[List[Optional[np.array]]] = None,
|
||||
image_sizes: Optional[List[List[int]]] = None,
|
||||
image_offsets: Optional[List[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
image_inputs = input_metadata.image_inputs
|
||||
|
||||
if input_metadata.forward_mode.is_extend():
|
||||
bs = input_metadata.batch_size
|
||||
# Got List[List[str]] extend it to List[str]
|
||||
# The length of the List should be equal to batch size
|
||||
modalities_list = []
|
||||
for modalities in input_metadata.modalities:
|
||||
if modalities is not None:
|
||||
modalities_list.extend(modalities)
|
||||
max_image_offset = []
|
||||
for im in image_inputs:
|
||||
if im and im.modalities is not None:
|
||||
modalities_list.extend(im.modalities)
|
||||
if im and im.image_offsets is not None:
|
||||
max_image_offset.append(max(im.image_offsets))
|
||||
else:
|
||||
max_image_offset.append(-1)
|
||||
|
||||
# Embed text inputs
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
|
||||
# Whether the requests need vision inputs
|
||||
max_image_offset = np.array(
|
||||
[max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
|
||||
)
|
||||
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
|
||||
need_vision = start_positions <= max_image_offset
|
||||
need_vision = start_positions <= np.array(max_image_offset)
|
||||
|
||||
if need_vision.any():
|
||||
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
|
||||
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
|
||||
pixel_values = [
|
||||
image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
|
||||
]
|
||||
image_sizes = [
|
||||
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
|
||||
]
|
||||
image_offsets = [
|
||||
image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
|
||||
]
|
||||
|
||||
########## Encode Image ########
|
||||
|
||||
|
||||
Reference in New Issue
Block a user