remove redundant pad_input_ids function (#500)
This commit is contained in:
@@ -323,36 +323,6 @@ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
|
|||||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||||
)
|
)
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
|
|
||||||
new_image_feature_len = self.image_feature_len
|
|
||||||
# now only support spatial_unpad + anyres
|
|
||||||
if self.mm_patch_merge_type.startswith("spatial"):
|
|
||||||
height = width = self.num_patches_per_side
|
|
||||||
if pt_shape[0] > 1:
|
|
||||||
if self.image_aspect_ratio == "anyres":
|
|
||||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
|
||||||
image_size,
|
|
||||||
self.image_grid_pinpoints,
|
|
||||||
self.vision_tower.config.image_size,
|
|
||||||
)
|
|
||||||
if "unpad" in self.mm_patch_merge_type:
|
|
||||||
h = num_patch_height * height
|
|
||||||
w = num_patch_width * width
|
|
||||||
new_h, new_w = unpad_image_shape(h, w, image_size)
|
|
||||||
new_image_feature_len += new_h * (new_w + 1)
|
|
||||||
|
|
||||||
pad_ids = pad_value * (
|
|
||||||
(new_image_feature_len + len(pad_value)) // len(pad_value)
|
|
||||||
)
|
|
||||||
offset = input_ids.index(self.config.image_token_index)
|
|
||||||
# old_len + pad_len - 1, because we need to remove image_token_id
|
|
||||||
new_input_ids = (
|
|
||||||
input_ids[:offset]
|
|
||||||
+ pad_ids[:new_image_feature_len]
|
|
||||||
+ input_ids[offset + 1 :]
|
|
||||||
)
|
|
||||||
return new_input_ids, offset
|
|
||||||
|
|
||||||
|
|
||||||
class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
|
class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user