refactor: unify names of the feature field of MultimodalDataItem (#8075)
This commit is contained in:
@@ -78,7 +78,7 @@ class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
output_lengths = (input_lengths - 2) // 2 + 1
|
output_lengths = (input_lengths - 2) // 2 + 1
|
||||||
|
|
||||||
item = MultimodalDataItem(
|
item = MultimodalDataItem(
|
||||||
audio_features=res["input_features"],
|
feature=res["input_features"],
|
||||||
audio_feature_lens=output_lengths,
|
audio_feature_lens=output_lengths,
|
||||||
audio_offsets=audio_offsets,
|
audio_offsets=audio_offsets,
|
||||||
modality=Modality.AUDIO,
|
modality=Modality.AUDIO,
|
||||||
|
|||||||
@@ -207,13 +207,12 @@ class MultimodalDataItem:
|
|||||||
modality: Modality
|
modality: Modality
|
||||||
hash: int = None
|
hash: int = None
|
||||||
pad_value: int = None
|
pad_value: int = None
|
||||||
image_sizes: Tuple[int, int] = None
|
|
||||||
offsets: Optional[list] = None
|
offsets: Optional[list] = None
|
||||||
|
# the raw features returned by processor, e.g. pixel_values or audio_features
|
||||||
|
feature: Union[torch.Tensor, np.ndarray] = None
|
||||||
|
|
||||||
|
image_sizes: Tuple[int, int] = None
|
||||||
|
|
||||||
# the real data, pixel_values or audio_features
|
|
||||||
# data: Union[List[torch.Tensor], List[np.ndarray]]
|
|
||||||
pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
|
|
||||||
audio_features: Union[torch.Tensor, np.ndarray] = None
|
|
||||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||||
audio_offsets: Optional[List[Tuple[int, int]]] = None
|
audio_offsets: Optional[List[Tuple[int, int]]] = None
|
||||||
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
||||||
@@ -238,7 +237,6 @@ class MultimodalDataItem:
|
|||||||
image_grid_hws: Optional[List[torch.Tensor]] = None
|
image_grid_hws: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
# For gemma3n
|
# For gemma3n
|
||||||
input_features: Optional[torch.Tensor] = None
|
|
||||||
input_features_mask: Optional[torch.Tensor] = None
|
input_features_mask: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -254,18 +252,11 @@ class MultimodalDataItem:
|
|||||||
from sglang.srt.managers.mm_utils import hash_feature
|
from sglang.srt.managers.mm_utils import hash_feature
|
||||||
|
|
||||||
if self.hash is None:
|
if self.hash is None:
|
||||||
if self.precomputed_features is not None:
|
if self.feature is not None:
|
||||||
self.hash = hash_feature(self.precomputed_features)
|
hashed_feature = self.feature
|
||||||
elif self.is_audio():
|
|
||||||
if self.audio_features is not None:
|
|
||||||
self.hash = hash_feature(self.audio_features)
|
|
||||||
elif self.input_features is not None:
|
|
||||||
self.hash = hash_feature(self.input_features)
|
|
||||||
elif self.is_video():
|
|
||||||
self.hash = hash_feature(self.pixel_values_videos)
|
|
||||||
else:
|
else:
|
||||||
self.hash = hash_feature(self.pixel_values)
|
hashed_feature = self.precomputed_features
|
||||||
|
self.hash = hash_feature(hashed_feature)
|
||||||
assert self.hash is not None
|
assert self.hash is not None
|
||||||
self.pad_value = self.hash % (1 << 30)
|
self.pad_value = self.hash % (1 << 30)
|
||||||
|
|
||||||
@@ -275,8 +266,7 @@ class MultimodalDataItem:
|
|||||||
def is_audio(self):
|
def is_audio(self):
|
||||||
return (self.modality == Modality.AUDIO) and (
|
return (self.modality == Modality.AUDIO) and (
|
||||||
self.precomputed_features is not None
|
self.precomputed_features is not None
|
||||||
or not MultimodalDataItem.is_empty_list(self.audio_features)
|
or not MultimodalDataItem.is_empty_list(self.feature)
|
||||||
or not MultimodalDataItem.is_empty_list(self.input_features)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_image(self):
|
def is_image(self):
|
||||||
@@ -284,13 +274,13 @@ class MultimodalDataItem:
|
|||||||
self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
|
self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
|
||||||
) and (
|
) and (
|
||||||
self.precomputed_features is not None
|
self.precomputed_features is not None
|
||||||
or not MultimodalDataItem.is_empty_list(self.pixel_values)
|
or not MultimodalDataItem.is_empty_list(self.feature)
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_video(self):
|
def is_video(self):
|
||||||
return (self.modality == Modality.VIDEO) and (
|
return (self.modality == Modality.VIDEO) and (
|
||||||
self.precomputed_features is not None
|
self.precomputed_features is not None
|
||||||
or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
|
or not MultimodalDataItem.is_empty_list(self.feature)
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_valid(self) -> bool:
|
def is_valid(self) -> bool:
|
||||||
@@ -311,7 +301,7 @@ class MultimodalDataItem:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def merge(self, other):
|
def merge(self, other):
|
||||||
self.pixel_values += other.pixel_values
|
self.feature += other.feature
|
||||||
self.image_sizes += other.image_sizes
|
self.image_sizes += other.image_sizes
|
||||||
self.image_offsets += other.image_offsets
|
self.image_offsets += other.image_offsets
|
||||||
self.hash = hash((self.hash, other.hash))
|
self.hash = hash((self.hash, other.hash))
|
||||||
@@ -354,7 +344,6 @@ class MultimodalInputs:
|
|||||||
|
|
||||||
assert isinstance(ret.mm_items, list)
|
assert isinstance(ret.mm_items, list)
|
||||||
ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
|
ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
|
||||||
|
|
||||||
for item in ret.mm_items:
|
for item in ret.mm_items:
|
||||||
item.set_pad_value()
|
item.set_pad_value()
|
||||||
|
|
||||||
@@ -1278,11 +1267,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
if mm_input is None:
|
if mm_input is None:
|
||||||
continue
|
continue
|
||||||
for mm_item in mm_input.mm_items:
|
for mm_item in mm_input.mm_items:
|
||||||
pixel_values = getattr(mm_item, "pixel_values", None)
|
pixel_values = getattr(mm_item, "feature", None)
|
||||||
if isinstance(pixel_values, torch.Tensor):
|
if isinstance(pixel_values, torch.Tensor):
|
||||||
mm_item.pixel_values = pixel_values.to(
|
mm_item.feature = pixel_values.to(self.device, non_blocking=True)
|
||||||
self.device, non_blocking=True
|
|
||||||
)
|
|
||||||
self.multimodal_inputs = multimodal_inputs
|
self.multimodal_inputs = multimodal_inputs
|
||||||
self.token_type_ids = token_type_ids_tensor
|
self.token_type_ids = token_type_ids_tensor
|
||||||
self.seq_lens_sum = sum(seq_lens)
|
self.seq_lens_sum = sum(seq_lens)
|
||||||
|
|||||||
@@ -463,7 +463,7 @@ class CLIPModel(nn.Module):
|
|||||||
if forward_batch.mm_inputs is not None:
|
if forward_batch.mm_inputs is not None:
|
||||||
mm_inputs = forward_batch.mm_inputs
|
mm_inputs = forward_batch.mm_inputs
|
||||||
pixel_values_list = [
|
pixel_values_list = [
|
||||||
item.pixel_values
|
item.feature
|
||||||
for item in flatten_nested_list(
|
for item in flatten_nested_list(
|
||||||
[mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
|
[mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1960,7 +1960,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
|
pixel_values = torch.concat([item.feature for item in items], dim=0)
|
||||||
bs, n = pixel_values.shape[0:2]
|
bs, n = pixel_values.shape[0:2]
|
||||||
pixel_values = pixel_values.to(
|
pixel_values = pixel_values.to(
|
||||||
device=self.vision_model.device, dtype=self.vision_model.dtype
|
device=self.vision_model.device, dtype=self.vision_model.dtype
|
||||||
|
|||||||
@@ -268,9 +268,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
|||||||
# TODO: can it be batched ?
|
# TODO: can it be batched ?
|
||||||
images_in_this_batch = []
|
images_in_this_batch = []
|
||||||
for item in items:
|
for item in items:
|
||||||
assert item.pixel_values.dim() == 4
|
assert item.feature.dim() == 4
|
||||||
image_feature = self.vision.forward_features(
|
image_feature = self.vision.forward_features(
|
||||||
item.pixel_values.type(next(self.vision.parameters()).dtype).to(
|
item.feature.type(next(self.vision.parameters()).dtype).to(
|
||||||
device=next(self.vision.parameters()).device
|
device=next(self.vision.parameters()).device
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -283,7 +283,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
# Process images one by one to handle flatten_batch=True constraint in vision_tower
|
# Process images one by one to handle flatten_batch=True constraint in vision_tower
|
||||||
all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
|
all_pixel_values = flatten_nested_list([item.feature for item in items])
|
||||||
vision_outputs_list = []
|
vision_outputs_list = []
|
||||||
|
|
||||||
for pixel_values_batch in all_pixel_values:
|
for pixel_values_batch in all_pixel_values:
|
||||||
|
|||||||
@@ -265,7 +265,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
|
|||||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
# Process images one by one to handle flatten_batch=True constraint in vision_tower
|
# Process images one by one to handle flatten_batch=True constraint in vision_tower
|
||||||
all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
|
all_pixel_values = flatten_nested_list([item.feature for item in items])
|
||||||
vision_outputs_list = []
|
vision_outputs_list = []
|
||||||
|
|
||||||
for pixel_values_batch in all_pixel_values:
|
for pixel_values_batch in all_pixel_values:
|
||||||
@@ -316,9 +316,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
|
|||||||
audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`).
|
audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
# Extract audio features and masks from items
|
# Extract audio features and masks from items
|
||||||
all_input_features = flatten_nested_list(
|
all_input_features = flatten_nested_list([item.feature for item in items])
|
||||||
[item.input_features for item in items]
|
|
||||||
)
|
|
||||||
all_input_features_mask = flatten_nested_list(
|
all_input_features_mask = flatten_nested_list(
|
||||||
[~item.input_features_mask for item in items]
|
[~item.input_features_mask for item in items]
|
||||||
) # Note(Xinyuan): reverse the mask according to the HF implementation
|
) # Note(Xinyuan): reverse the mask according to the HF implementation
|
||||||
|
|||||||
@@ -510,7 +510,7 @@ class InternVLChatModel(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
pixel_values = torch.cat([item.pixel_values for item in items])
|
pixel_values = torch.cat([item.feature for item in items])
|
||||||
image_features = self.extract_feature(pixel_values)
|
image_features = self.extract_feature(pixel_values)
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class KimiVLForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
pixel_values = (
|
pixel_values = (
|
||||||
torch.cat([item.pixel_values for item in items], dim=0)
|
torch.cat([item.feature for item in items], dim=0)
|
||||||
.type(self.vision_tower.dtype)
|
.type(self.vision_tower.dtype)
|
||||||
.to(self.vision_tower.device)
|
.to(self.vision_tower.device)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
bs = forward_batch.batch_size
|
bs = forward_batch.batch_size
|
||||||
pixel_values = flatten_nested_list(
|
pixel_values = flatten_nested_list(
|
||||||
[
|
[
|
||||||
[item.pixel_values for item in image_inputs[i].mm_items]
|
[item.feature for item in image_inputs[i].mm_items]
|
||||||
for i in range(bs)
|
for i in range(bs)
|
||||||
if need_vision[i]
|
if need_vision[i]
|
||||||
]
|
]
|
||||||
@@ -753,7 +753,7 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
|
|||||||
features = []
|
features = []
|
||||||
for item in items:
|
for item in items:
|
||||||
# in each item, we assume pixel_values is always batched
|
# in each item, we assume pixel_values is always batched
|
||||||
pixel_values, image_sizes = item.pixel_values, item.image_sizes
|
pixel_values, image_sizes = item.feature, item.image_sizes
|
||||||
image_outputs = self.vision_tower(
|
image_outputs = self.vision_tower(
|
||||||
pixel_values, image_sizes, output_hidden_states=True
|
pixel_values, image_sizes, output_hidden_states=True
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
if need_vision.any():
|
if need_vision.any():
|
||||||
pixel_values = flatten_nested_list(
|
pixel_values = flatten_nested_list(
|
||||||
[
|
[
|
||||||
[item.pixel_values for item in image_inputs[i].mm_items]
|
[item.feature for item in image_inputs[i].mm_items]
|
||||||
for i in range(bs)
|
for i in range(bs)
|
||||||
if need_vision[i]
|
if need_vision[i]
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1552,9 +1552,7 @@ class MiniCPMO(MiniCPMBaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
List[List[torch.Tensor]]: audio embeddings
|
List[List[torch.Tensor]]: audio embeddings
|
||||||
"""
|
"""
|
||||||
wavforms = flatten_nested_list(
|
wavforms = flatten_nested_list([item.feature for item in items if item.feature])
|
||||||
[item.audio_features for item in items if item.audio_features]
|
|
||||||
)
|
|
||||||
# list, [[x1, x2], [y1], [z1]]
|
# list, [[x1, x2], [y1], [z1]]
|
||||||
audio_feature_lens_raw = flatten_nested_list(
|
audio_feature_lens_raw = flatten_nested_list(
|
||||||
[item.audio_feature_lens for item in items if item.audio_feature_lens]
|
[item.audio_feature_lens for item in items if item.audio_feature_lens]
|
||||||
@@ -1659,9 +1657,7 @@ class MiniCPMO(MiniCPMBaseModel):
|
|||||||
List[List[torch.Tensor]]: audio embeddings
|
List[List[torch.Tensor]]: audio embeddings
|
||||||
"""
|
"""
|
||||||
# (bs, 80, frames) or [], multi audios need filled in advance
|
# (bs, 80, frames) or [], multi audios need filled in advance
|
||||||
wavforms = flatten_nested_list(
|
wavforms = flatten_nested_list([item.feature for item in items if item.feature])
|
||||||
[item.audio_features for item in items if item.audio_features]
|
|
||||||
)
|
|
||||||
# list, [[x1, x2], [y1], [z1]]
|
# list, [[x1, x2], [y1], [z1]]
|
||||||
audio_feature_lens_raw = flatten_nested_list(
|
audio_feature_lens_raw = flatten_nested_list(
|
||||||
[item.audio_feature_lens for item in items if item.audio_feature_lens]
|
[item.audio_feature_lens for item in items if item.audio_feature_lens]
|
||||||
@@ -1778,7 +1774,7 @@ class MiniCPMO(MiniCPMBaseModel):
|
|||||||
|
|
||||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
# list of tensors
|
# list of tensors
|
||||||
pixel_values = flatten_nested_list([item.pixel_values for item in items])
|
pixel_values = flatten_nested_list([item.feature for item in items])
|
||||||
tgt_sizes = torch.stack(
|
tgt_sizes = torch.stack(
|
||||||
flatten_nested_list([item.tgt_size for item in items]), dim=0
|
flatten_nested_list([item.tgt_size for item in items]), dim=0
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -724,7 +724,7 @@ class MiniCPMV2_6(MiniCPMBaseModel):
|
|||||||
|
|
||||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
# list of tensors
|
# list of tensors
|
||||||
pixel_values = flatten_nested_list([item.pixel_values for item in items])
|
pixel_values = flatten_nested_list([item.feature for item in items])
|
||||||
tgt_sizes = torch.stack(
|
tgt_sizes = torch.stack(
|
||||||
flatten_nested_list([item.tgt_size for item in items]), dim=0
|
flatten_nested_list([item.tgt_size for item in items]), dim=0
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class Mistral3ForConditionalGeneration:
|
|||||||
features = []
|
features = []
|
||||||
for item in items:
|
for item in items:
|
||||||
# in each item, we assume pixel_values is always batched
|
# in each item, we assume pixel_values is always batched
|
||||||
pixel_values, image_sizes = item.pixel_values, item.image_sizes
|
pixel_values, image_sizes = item.feature, item.image_sizes
|
||||||
image_outputs = self.vision_tower(
|
image_outputs = self.vision_tower(
|
||||||
pixel_values, image_sizes, output_hidden_states=True
|
pixel_values, image_sizes, output_hidden_states=True
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -838,9 +838,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
self.logits_processor = LogitsProcessor(config.text_config)
|
self.logits_processor = LogitsProcessor(config.text_config)
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||||
pixel_values = torch.cat(
|
pixel_values = torch.cat([item.feature for item in mm_inputs.mm_items], dim=0)
|
||||||
[item.pixel_values for item in mm_inputs.mm_items], dim=0
|
|
||||||
)
|
|
||||||
pad_values = [item.pad_value for item in mm_inputs.mm_items]
|
pad_values = [item.pad_value for item in mm_inputs.mm_items]
|
||||||
|
|
||||||
num_concurrent_media, num_tiles = pixel_values.shape[1:3]
|
num_concurrent_media, num_tiles = pixel_values.shape[1:3]
|
||||||
@@ -862,7 +860,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
if not forward_batch.encoder_cached[i] and mm_input is not None:
|
if not forward_batch.encoder_cached[i] and mm_input is not None:
|
||||||
pixel_values = torch.cat(
|
pixel_values = torch.cat(
|
||||||
[item.pixel_values for item in mm_input.mm_items], dim=0
|
[item.feature for item in mm_input.mm_items], dim=0
|
||||||
)
|
)
|
||||||
max_num_images = max(max_num_images, pixel_values.shape[1])
|
max_num_images = max(max_num_images, pixel_values.shape[1])
|
||||||
|
|
||||||
@@ -897,7 +895,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
encoder_lens_need.append(forward_batch.encoder_lens[k])
|
encoder_lens_need.append(forward_batch.encoder_lens[k])
|
||||||
pixel_values = torch.cat(
|
pixel_values = torch.cat(
|
||||||
[item.pixel_values for item in mm_input.mm_items], dim=0
|
[item.feature for item in mm_input.mm_items], dim=0
|
||||||
)
|
)
|
||||||
for j in range(pixel_values.shape[1]):
|
for j in range(pixel_values.shape[1]):
|
||||||
img = pixel_values[0, j]
|
img = pixel_values[0, j]
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
raise ValueError("Vision model not available for text-only checkpoint")
|
raise ValueError("Vision model not available for text-only checkpoint")
|
||||||
|
|
||||||
pixel_values = (
|
pixel_values = (
|
||||||
torch.concat([item.pixel_values for item in items])
|
torch.concat([item.feature for item in items])
|
||||||
.to(next(self.vision_model.parameters()).device)
|
.to(next(self.vision_model.parameters()).device)
|
||||||
.type(next(self.vision_model.parameters()).dtype)
|
.type(next(self.vision_model.parameters()).dtype)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -422,9 +422,7 @@ class Phi4MMForCausalLM(nn.Module):
|
|||||||
|
|
||||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
dtype = next(self.vision_encoder.parameters()).dtype
|
dtype = next(self.vision_encoder.parameters()).dtype
|
||||||
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
|
||||||
dtype
|
|
||||||
)
|
|
||||||
image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0)
|
image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0)
|
||||||
image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
|
image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
|
||||||
image_embeds = self.vision_encoder(
|
image_embeds = self.vision_encoder(
|
||||||
|
|||||||
@@ -497,7 +497,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
# in qwen-vl, last dim is the same
|
# in qwen-vl, last dim is the same
|
||||||
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
|
||||||
self.visual.dtype
|
self.visual.dtype
|
||||||
)
|
)
|
||||||
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
||||||
@@ -508,9 +508,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
# in qwen-vl, last dim is the same
|
# in qwen-vl, last dim is the same
|
||||||
pixel_values = torch.cat(
|
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
|
||||||
[getattr(item, "pixel_values_videos") for item in items], dim=0
|
self.visual.dtype
|
||||||
).type(self.visual.dtype)
|
)
|
||||||
video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
|
video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
|
||||||
assert pixel_values.dim() == 2, pixel_values.dim()
|
assert pixel_values.dim() == 2, pixel_values.dim()
|
||||||
assert video_grid_thw.dim() == 2, video_grid_thw.dim()
|
assert video_grid_thw.dim() == 2, video_grid_thw.dim()
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
# Extract audio features from input items
|
# Extract audio features from input items
|
||||||
input_features = torch.cat([item.audio_features for item in items], dim=0).type(
|
input_features = torch.cat([item.feature for item in items], dim=0).type(
|
||||||
self.audio_tower.dtype
|
self.audio_tower.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -484,7 +484,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
# in qwen-vl, last dim is the same
|
# in qwen-vl, last dim is the same
|
||||||
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
|
||||||
self.visual.dtype
|
self.visual.dtype
|
||||||
)
|
)
|
||||||
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
||||||
@@ -495,9 +495,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
# in qwen-vl, last dim is the same
|
# in qwen-vl, last dim is the same
|
||||||
pixel_values = torch.cat(
|
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
|
||||||
[item.pixel_values_videos for item in items], dim=0
|
self.visual.dtype
|
||||||
).type(self.visual.dtype)
|
)
|
||||||
video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
|
video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
|
||||||
assert pixel_values.dim() == 2, pixel_values.dim()
|
assert pixel_values.dim() == 2, pixel_values.dim()
|
||||||
assert video_grid_thw.dim() == 2, video_grid_thw.dim()
|
assert video_grid_thw.dim() == 2, video_grid_thw.dim()
|
||||||
|
|||||||
@@ -237,7 +237,7 @@ class VILAForConditionalGeneration(nn.Module):
|
|||||||
return cast(LogitsProcessorOutput, output)
|
return cast(LogitsProcessorOutput, output)
|
||||||
|
|
||||||
def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
|
def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
|
||||||
pixel_values = cast(Tensor, mm_input[0].pixel_values)
|
pixel_values = cast(Tensor, mm_input[0].feature)
|
||||||
|
|
||||||
##### BEGIN COPY modeling_vila.py #####
|
##### BEGIN COPY modeling_vila.py #####
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import multiprocessing as mp
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import lru_cache
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -156,6 +155,10 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
# "precomputed_features" - handled specially as it can be any modality
|
# "precomputed_features" - handled specially as it can be any modality
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# name of the feature field
|
||||||
|
# TODO: pass from processors
|
||||||
|
self.FEATURE_NAMES = ["pixel_values", "pixel_values_videos", "audio_features"]
|
||||||
|
|
||||||
def process_mm_data(
|
def process_mm_data(
|
||||||
self, input_text, images=None, videos=None, audios=None, **kwargs
|
self, input_text, images=None, videos=None, audios=None, **kwargs
|
||||||
):
|
):
|
||||||
@@ -524,6 +527,9 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
if modality not in items:
|
if modality not in items:
|
||||||
items[modality] = MultimodalDataItem(modality=modality)
|
items[modality] = MultimodalDataItem(modality=modality)
|
||||||
|
|
||||||
|
if attr_name in self.FEATURE_NAMES:
|
||||||
|
attr_name = "feature"
|
||||||
|
|
||||||
# Set attribute
|
# Set attribute
|
||||||
setattr(items[modality], attr_name, value)
|
setattr(items[modality], attr_name, value)
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class ClipImageProcessor(BaseMultimodalProcessor):
|
|||||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||||
image_inputs["mm_items"] = [
|
image_inputs["mm_items"] = [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
|
feature=image_inputs["pixel_values"], modality=Modality.IMAGE
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
|||||||
input_ids=input_ids, mm_token_id=self._processor.image_token_id
|
input_ids=input_ids, mm_token_id=self._processor.image_token_id
|
||||||
)
|
)
|
||||||
item = MultimodalDataItem(
|
item = MultimodalDataItem(
|
||||||
pixel_values=res["images"],
|
feature=res["images"],
|
||||||
offsets=image_offsets,
|
offsets=image_offsets,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
image_emb_mask=images_seq_mask,
|
image_emb_mask=images_seq_mask,
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
|||||||
)
|
)
|
||||||
items = [
|
items = [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=pixel_values,
|
feature=pixel_values,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
offsets=image_offsets,
|
offsets=image_offsets,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
|||||||
return {
|
return {
|
||||||
"mm_items": [
|
"mm_items": [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=res["pixel_values"],
|
feature=res["pixel_values"],
|
||||||
image_emb_mask=res["images_emb_mask"],
|
image_emb_mask=res["images_emb_mask"],
|
||||||
offsets=image_offsets,
|
offsets=image_offsets,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
|||||||
return {
|
return {
|
||||||
"mm_items": [
|
"mm_items": [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=pixel_values,
|
feature=pixel_values,
|
||||||
image_sizes=image_sizes,
|
image_sizes=image_sizes,
|
||||||
modality=modality,
|
modality=modality,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
|
|
||||||
if len(pixel_values) != 0:
|
if len(pixel_values) != 0:
|
||||||
item = MultimodalDataItem(
|
item = MultimodalDataItem(
|
||||||
pixel_values=pixel_values,
|
feature=pixel_values,
|
||||||
offsets=image_offsets,
|
offsets=image_offsets,
|
||||||
tgt_size=tgt_sizes_flat,
|
tgt_size=tgt_sizes_flat,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
@@ -135,7 +135,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
else:
|
else:
|
||||||
audio_offsets = None
|
audio_offsets = None
|
||||||
item = MultimodalDataItem(
|
item = MultimodalDataItem(
|
||||||
audio_features=[res["audio_features"]],
|
feature=[res["audio_features"]],
|
||||||
audio_feature_lens=res["audio_feature_lens"],
|
audio_feature_lens=res["audio_feature_lens"],
|
||||||
offsets=audio_offsets,
|
offsets=audio_offsets,
|
||||||
modality=Modality.AUDIO,
|
modality=Modality.AUDIO,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
|
|||||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||||
image_inputs["mm_items"] = [
|
image_inputs["mm_items"] = [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=image_inputs["pixel_values"],
|
feature=image_inputs["pixel_values"],
|
||||||
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
|
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
|
||||||
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
|
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
|||||||
# Add metadata for image processing
|
# Add metadata for image processing
|
||||||
processor_output["mm_items"] = [
|
processor_output["mm_items"] = [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=processor_output["pixel_values"],
|
feature=processor_output["pixel_values"],
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
offsets=image_offsets,
|
offsets=image_offsets,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
|
|||||||
|
|
||||||
items = [
|
items = [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=res["input_image_embeds"],
|
feature=res["input_image_embeds"],
|
||||||
image_sizes=res["image_sizes"],
|
image_sizes=res["image_sizes"],
|
||||||
image_emb_mask=res["image_attention_mask"],
|
image_emb_mask=res["image_attention_mask"],
|
||||||
offsets=image_offsets,
|
offsets=image_offsets,
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
|||||||
)
|
)
|
||||||
mm_items = [
|
mm_items = [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=processor_output["pixel_values"],
|
feature=processor_output["pixel_values"],
|
||||||
image_sizes=processor_output["image_sizes"],
|
image_sizes=processor_output["image_sizes"],
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
offsets=image_offsets,
|
offsets=image_offsets,
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
|||||||
MultimodalInputs(
|
MultimodalInputs(
|
||||||
mm_items=[
|
mm_items=[
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=pixel_values_flat,
|
feature=pixel_values_flat,
|
||||||
offsets=image_offsets,
|
offsets=image_offsets,
|
||||||
tgt_size=tgt_sizes_flat,
|
tgt_size=tgt_sizes_flat,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
|
|||||||
Reference in New Issue
Block a user