refactor: move image processors to separate files (#4229)
This commit is contained in:
@@ -41,7 +41,6 @@ from torch import nn
|
||||
from torch.nn.init import trunc_normal_
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.activation import get_act_fn
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import (
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.multi_modality_padding import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||
@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
num_heads_per_partition = divide(self.num_heads, tp_size)
|
||||
self.self_attn = VisionAttention(
|
||||
embed_dim=config.hidden_size,
|
||||
num_heads=num_heads_per_partition,
|
||||
num_heads=self.num_heads,
|
||||
projection_size=config.intermediate_size,
|
||||
use_qkv_parallel=True,
|
||||
quant_config=quant_config,
|
||||
dropout=config.attention_dropout,
|
||||
use_context_forward=False,
|
||||
use_full_precision_softmax=True,
|
||||
softmax_in_single_precision=True,
|
||||
flatten_batch=False,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
@@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
pad_values: List[int],
|
||||
im_start_id: torch.Tensor,
|
||||
im_end_id: torch.Tensor,
|
||||
slice_start_id: Optional[torch.Tensor] = None,
|
||||
slice_end_id: Optional[torch.Tensor] = None,
|
||||
im_start_id: int,
|
||||
im_end_id: int,
|
||||
slice_start_id: Optional[int] = None,
|
||||
slice_end_id: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns a tensor indicating the bounds (start and end token ids) of the images
|
||||
"""
|
||||
# All the images in the batch should share the same special image
|
||||
# bound token ids.
|
||||
start_cond = input_ids == im_start_id[0]
|
||||
end_cond = input_ids == im_end_id[0]
|
||||
start_cond = input_ids == im_start_id
|
||||
end_cond = input_ids == im_end_id
|
||||
if slice_start_id is not None:
|
||||
start_cond |= input_ids == slice_start_id[0]
|
||||
end_cond |= input_ids == slice_end_id[0]
|
||||
start_cond |= input_ids == slice_start_id
|
||||
end_cond |= input_ids == slice_end_id
|
||||
|
||||
(image_start_tokens,) = torch.where(start_cond)
|
||||
image_start_tokens += 1
|
||||
@@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
if (
|
||||
len(image_start_tokens) + 1 == len(image_end_tokens)
|
||||
and input_ids[0] in pad_values
|
||||
and len(image_start_tokens) != 0
|
||||
and len(image_end_tokens) != 0
|
||||
and image_end_tokens[0] < image_start_tokens[0]
|
||||
):
|
||||
image_start_tokens = torch.cat(
|
||||
@@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
|
||||
None
|
||||
]:
|
||||
if (
|
||||
forward_batch.image_inputs is not None
|
||||
and len(forward_batch.image_inputs) > 0
|
||||
and forward_batch.image_inputs[0] is not None
|
||||
):
|
||||
# TODO: bath
|
||||
kwargs.update(
|
||||
{
|
||||
"pixel_values": (
|
||||
@@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
return self.resampler(vision_embedding, tgt_sizes)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
if not isinstance(image_inputs.im_start_id, list) or not isinstance(
|
||||
image_inputs.im_end_id, list
|
||||
):
|
||||
return input_ids
|
||||
|
||||
new_input_ids = []
|
||||
last_idx = 0
|
||||
image_idx = -1
|
||||
image_inputs.image_offsets = []
|
||||
|
||||
# Get all special token IDs
|
||||
im_start_id = (
|
||||
image_inputs.im_start_id[0].item()
|
||||
if isinstance(image_inputs.im_start_id[0], torch.Tensor)
|
||||
else image_inputs.im_start_id[0]
|
||||
)
|
||||
im_end_id = (
|
||||
image_inputs.im_end_id[0].item()
|
||||
if isinstance(image_inputs.im_end_id[0], torch.Tensor)
|
||||
else image_inputs.im_end_id[0]
|
||||
)
|
||||
slice_start_id = (
|
||||
image_inputs.slice_start_id[0].item()
|
||||
if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
|
||||
else image_inputs.slice_start_id[0]
|
||||
)
|
||||
slice_end_id = (
|
||||
image_inputs.slice_end_id[0].item()
|
||||
if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
|
||||
else image_inputs.slice_end_id[0]
|
||||
)
|
||||
im_start_id: int = image_inputs.im_start_id
|
||||
im_end_id: int = image_inputs.im_end_id
|
||||
slice_start_id: int = image_inputs.slice_start_id
|
||||
slice_end_id: int = image_inputs.slice_end_id
|
||||
|
||||
# Find all start and end positions for both types
|
||||
start_indices = [
|
||||
i
|
||||
for i, x in enumerate(input_ids)
|
||||
if x == im_start_id or x == slice_start_id
|
||||
]
|
||||
end_indices = [
|
||||
i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
|
||||
]
|
||||
media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
|
||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||
|
||||
if len(start_indices) != len(end_indices):
|
||||
return input_ids
|
||||
# Process each region (both image and slice)
|
||||
for start_idx, end_idx in zip(start_indices, end_indices):
|
||||
# Add non-image tokens before this region
|
||||
new_input_ids.extend(
|
||||
input_ids[last_idx : start_idx + 1]
|
||||
) # include start token
|
||||
|
||||
is_image_start = input_ids[start_idx] == im_start_id
|
||||
|
||||
if is_image_start:
|
||||
image_inputs.image_offsets += [start_idx]
|
||||
image_idx += 1
|
||||
|
||||
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
|
||||
|
||||
# Generate pad_ids
|
||||
pad_values = [image_inputs.pad_values[image_idx]]
|
||||
|
||||
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
|
||||
pad_ids = pad_ids[:num_tokens]
|
||||
|
||||
# Add pad_ids
|
||||
new_input_ids.extend(pad_ids)
|
||||
|
||||
# Update last_idx to after end token
|
||||
last_idx = end_idx
|
||||
|
||||
# Add remaining tokens after last region
|
||||
new_input_ids.extend(input_ids[last_idx:])
|
||||
assert len(input_ids) == len(new_input_ids)
|
||||
return new_input_ids
|
||||
return pattern.pad_input_tokens(input_ids, image_inputs)
|
||||
|
||||
|
||||
_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
|
||||
|
||||
@@ -202,7 +202,7 @@ class MllamaVisionEncoderLayer(nn.Module):
|
||||
quant_config=None,
|
||||
dropout=0.0,
|
||||
use_context_forward=False,
|
||||
use_full_precision_softmax=False,
|
||||
softmax_in_single_precision=False,
|
||||
flatten_batch=False,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
@@ -47,6 +47,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.multi_modality_padding import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
@@ -121,12 +124,12 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||
if attn_implementation == "sdpa":
|
||||
use_context_forward = False
|
||||
use_full_precision_softmax = False
|
||||
softmax_in_single_precision = False
|
||||
elif attn_implementation == "flash_attention_2":
|
||||
use_full_precision_softmax = False
|
||||
softmax_in_single_precision = False
|
||||
use_context_forward = True
|
||||
elif attn_implementation == "eager":
|
||||
use_full_precision_softmax = True
|
||||
softmax_in_single_precision = True
|
||||
use_context_forward = False
|
||||
|
||||
self.attn = VisionAttention(
|
||||
@@ -135,7 +138,7 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
projection_size=dim,
|
||||
use_qkv_parallel=False,
|
||||
use_context_forward=use_context_forward,
|
||||
use_full_precision_softmax=use_full_precision_softmax,
|
||||
softmax_in_single_precision=softmax_in_single_precision,
|
||||
flatten_batch=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
@@ -149,12 +152,17 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
position_embeddings: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.norm1(x)
|
||||
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
|
||||
attn = self.attn(
|
||||
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
attn = rearrange(attn, "b s ... -> s b ...")
|
||||
x = x + attn
|
||||
@@ -443,6 +451,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
)
|
||||
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
position_embeddings = (emb.cos(), emb.sin())
|
||||
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
@@ -457,7 +467,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)
|
||||
x = blk(
|
||||
x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
|
||||
)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
@@ -522,50 +534,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
return num_image_tokens
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
new_input_ids = []
|
||||
last_idx = 0
|
||||
image_idx = -1
|
||||
image_inputs.image_offsets = []
|
||||
|
||||
# Get all special token IDs
|
||||
im_start_id = image_inputs.im_start_id
|
||||
im_end_id = image_inputs.im_end_id
|
||||
im_start_id: int = image_inputs.im_start_id
|
||||
im_end_id: int = image_inputs.im_end_id
|
||||
|
||||
# Find all start and end positions for both types
|
||||
start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
|
||||
end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]
|
||||
media_token_pairs = [(im_start_id, im_end_id)]
|
||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||
|
||||
if len(start_indices) != len(end_indices):
|
||||
return input_ids
|
||||
# Process each region (both image and slice)
|
||||
for start_idx, end_idx in zip(start_indices, end_indices):
|
||||
# Add non-image tokens before this region
|
||||
new_input_ids.extend(input_ids[last_idx : start_idx + 1])
|
||||
|
||||
is_image_start = input_ids[start_idx] == im_start_id
|
||||
|
||||
if is_image_start:
|
||||
image_inputs.image_offsets += [start_idx]
|
||||
image_idx += 1
|
||||
|
||||
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
|
||||
|
||||
# Generate pad_ids
|
||||
pad_values = [image_inputs.pad_values[image_idx]]
|
||||
|
||||
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
|
||||
pad_ids = pad_ids[:num_tokens]
|
||||
|
||||
# Add pad_ids
|
||||
new_input_ids.extend(pad_ids)
|
||||
|
||||
# Update last_idx to after end token
|
||||
last_idx = end_idx
|
||||
|
||||
# Add remaining tokens after last region
|
||||
new_input_ids.extend(input_ids[last_idx:])
|
||||
assert len(input_ids) == len(new_input_ids)
|
||||
return new_input_ids
|
||||
return pattern.pad_input_tokens(input_ids, image_inputs)
|
||||
|
||||
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
@@ -629,7 +605,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||
for i, image in enumerate(forward_batch.image_inputs):
|
||||
if image is None:
|
||||
if image is None or image.pixel_values is None:
|
||||
continue
|
||||
start_idx = extend_start_loc_cpu[i]
|
||||
prefix_len = prefix_lens_cpu[i]
|
||||
|
||||
@@ -42,6 +42,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.multi_modality_padding import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
@@ -137,12 +140,12 @@ class Qwen2VisionBlock(nn.Module):
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
if attn_implementation == "sdpa":
|
||||
use_context_forward = False
|
||||
use_full_precision_softmax = False
|
||||
softmax_in_single_precision = False
|
||||
elif attn_implementation == "flash_attention_2":
|
||||
use_full_precision_softmax = False
|
||||
softmax_in_single_precision = False
|
||||
use_context_forward = True
|
||||
elif attn_implementation == "eager":
|
||||
use_full_precision_softmax = True
|
||||
softmax_in_single_precision = True
|
||||
use_context_forward = False
|
||||
|
||||
self.attn = VisionAttention(
|
||||
@@ -151,7 +154,7 @@ class Qwen2VisionBlock(nn.Module):
|
||||
projection_size=dim,
|
||||
use_qkv_parallel=False,
|
||||
use_context_forward=use_context_forward,
|
||||
use_full_precision_softmax=use_full_precision_softmax,
|
||||
softmax_in_single_precision=softmax_in_single_precision,
|
||||
flatten_batch=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
@@ -165,12 +168,17 @@ class Qwen2VisionBlock(nn.Module):
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
position_embeddings: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.norm1(x)
|
||||
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
|
||||
attn = self.attn(
|
||||
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
attn = rearrange(attn, "b s ... -> s b ...")
|
||||
x = x + attn
|
||||
@@ -392,7 +400,8 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
position_embeddings = (emb.cos(), emb.sin())
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
@@ -402,7 +411,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
# transformers
|
||||
x = x.unsqueeze(1)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
||||
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
@@ -425,40 +434,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
)
|
||||
return num_image_tokens
|
||||
|
||||
# Use grid_t * grid_w * grid_h to pad tokens for each image
|
||||
# add replaced padding by unique image hash
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
image_grid_thws = image_inputs.image_grid_thws
|
||||
pad_values = image_inputs.pad_values
|
||||
|
||||
image_indices = [
|
||||
idx
|
||||
for idx, token in enumerate(input_ids)
|
||||
if token == self.config.image_token_id
|
||||
]
|
||||
image_inputs.image_offsets = []
|
||||
|
||||
input_ids_with_image = []
|
||||
for image_cnt, _ in enumerate(image_grid_thws):
|
||||
num_image_tokens = self.calculate_num_image_tokens(
|
||||
image_grid_thws[image_cnt]
|
||||
)
|
||||
if image_cnt == 0:
|
||||
non_image_tokens = input_ids[: image_indices[image_cnt]]
|
||||
else:
|
||||
non_image_tokens = input_ids[
|
||||
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
|
||||
]
|
||||
input_ids_with_image.extend(non_image_tokens)
|
||||
image_inputs.image_offsets.append(len(input_ids_with_image))
|
||||
pad_ids = pad_values * (
|
||||
(num_image_tokens + len(pad_values)) // len(pad_values)
|
||||
)
|
||||
input_ids_with_image.extend(pad_ids[:num_image_tokens])
|
||||
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
|
||||
|
||||
return input_ids_with_image
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2VLConfig,
|
||||
@@ -494,6 +469,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
# Use grid_t * grid_w * grid_h to pad tokens for each image
|
||||
# add replaced padding by unique image hash
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
# Get all special token IDs
|
||||
im_start_id: int = image_inputs.im_start_id
|
||||
im_end_id: int = image_inputs.im_end_id
|
||||
|
||||
media_token_pairs = [(im_start_id, im_end_id)]
|
||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||
return pattern.pad_input_tokens(input_ids, image_inputs)
|
||||
|
||||
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
|
||||
@@ -556,12 +542,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||
for i, image in enumerate(forward_batch.image_inputs):
|
||||
if image is None:
|
||||
if image is None or image.pixel_values is None:
|
||||
continue
|
||||
start_idx = extend_start_loc_cpu[i]
|
||||
prefix_len = prefix_lens_cpu[i]
|
||||
pixel_values = image.pixel_values.clone()
|
||||
|
||||
pixel_values = torch.tensor(image.pixel_values, device="cuda")
|
||||
image_grid_thws = torch.tensor(
|
||||
np.array(image.image_grid_thws), device="cuda"
|
||||
)
|
||||
@@ -579,15 +565,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
image_grid_thws[idx]
|
||||
)
|
||||
|
||||
left_idx = start_idx + (image_offset - prefix_len)
|
||||
right_idx = (
|
||||
start_idx + (image_offset - prefix_len) + num_image_tokens
|
||||
)
|
||||
|
||||
left_idx = start_idx + (image_offset - prefix_len + 1)
|
||||
right_idx = left_idx + num_image_tokens
|
||||
inputs_embeds[left_idx:right_idx] = image_embeds[
|
||||
image_embeds_offset : image_embeds_offset + num_image_tokens
|
||||
]
|
||||
image_embeds_offset += num_image_tokens
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
|
||||
Reference in New Issue
Block a user