Fix llama4 vision (#7840)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -935,6 +935,19 @@ register_conv_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="llama_4_vision",
|
||||||
|
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
|
||||||
|
system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
|
||||||
|
roles=("user", "assistant"),
|
||||||
|
sep_style=SeparatorStyle.LLAMA4,
|
||||||
|
sep="",
|
||||||
|
stop_str="<|eot|>",
|
||||||
|
image_token="<|image|>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_conv_template_matching_function
|
@register_conv_template_matching_function
|
||||||
def match_internvl(model_path: str):
|
def match_internvl(model_path: str):
|
||||||
@@ -943,9 +956,11 @@ def match_internvl(model_path: str):
|
|||||||
|
|
||||||
|
|
||||||
@register_conv_template_matching_function
|
@register_conv_template_matching_function
|
||||||
def match_llama_3_vision(model_path: str):
|
def match_llama_vision(model_path: str):
|
||||||
if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
|
if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
|
||||||
return "llama_3_vision"
|
return "llama_3_vision"
|
||||||
|
if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
|
||||||
|
return "llama_4_vision"
|
||||||
|
|
||||||
|
|
||||||
@register_conv_template_matching_function
|
@register_conv_template_matching_function
|
||||||
|
|||||||
@@ -248,7 +248,9 @@ def _get_chunked_prefill_embedding(
|
|||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
# Calculate embedding for each request, try to get it from cache to avoid repeated calculation
|
# Calculate embedding for each request, try to get it from cache to avoid repeated calculation
|
||||||
embedding_list = []
|
embedding_list = []
|
||||||
for i in range(len(items_size) - 1):
|
# FIXME(Xinyuan): temporary workaround for eagle3, which may have len(items_size) > len(prefix_length)
|
||||||
|
max_iterations = min(len(items_size) - 1, len(prefix_length))
|
||||||
|
for i in range(max_iterations):
|
||||||
if items_size[i] == items_size[i + 1]:
|
if items_size[i] == items_size[i + 1]:
|
||||||
continue
|
continue
|
||||||
embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
|
embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
|
||||||
@@ -269,7 +271,7 @@ def _get_chunked_prefill_embedding(
|
|||||||
embedding_per_req_chunk, _, end_index = get_embedding_chunk(
|
embedding_per_req_chunk, _, end_index = get_embedding_chunk(
|
||||||
embedding=embedding_per_req,
|
embedding=embedding_per_req,
|
||||||
extend_prefix_len=prefix_length[i],
|
extend_prefix_len=prefix_length[i],
|
||||||
extend_seq_len=extend_length[i],
|
extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
|
||||||
items_offset=items_offset,
|
items_offset=items_offset,
|
||||||
)
|
)
|
||||||
# remove this item from cache if chunk reaches to the end
|
# remove this item from cache if chunk reaches to the end
|
||||||
|
|||||||
@@ -60,70 +60,72 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Handle image resolutions and aspect ratios
|
# Handle image resolutions and aspect ratios
|
||||||
if "pixel_values" in processor_output:
|
if "pixel_values" not in processor_output: # no image processed
|
||||||
image_processor = processor.image_processor
|
return None
|
||||||
tokenizer = self._processor.tokenizer
|
|
||||||
|
|
||||||
# Calculate tile size and find supported resolutions
|
image_processor = processor.image_processor
|
||||||
tile_size = self.vision_config.image_size
|
tokenizer = self._processor.tokenizer
|
||||||
max_num_tiles = getattr(self.vision_config, "max_patches", 1)
|
|
||||||
|
|
||||||
possible_resolutions = find_supported_resolutions(
|
# Calculate tile size and find supported resolutions
|
||||||
max_num_chunks=max_num_tiles,
|
tile_size = self.vision_config.image_size
|
||||||
patch_size=SizeDict(height=tile_size, width=tile_size),
|
max_num_tiles = getattr(self.vision_config, "max_patches", 1)
|
||||||
|
|
||||||
|
possible_resolutions = find_supported_resolutions(
|
||||||
|
max_num_chunks=max_num_tiles,
|
||||||
|
patch_size=SizeDict(height=tile_size, width=tile_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find best fit for each image
|
||||||
|
best_fit_sizes = [
|
||||||
|
get_best_fit(
|
||||||
|
(image.size[1], image.size[0]), # (height, width)
|
||||||
|
torch.tensor(possible_resolutions),
|
||||||
|
resize_to_max_canvas=image_processor.resize_to_max_canvas,
|
||||||
)
|
)
|
||||||
|
for image in processed_data.images
|
||||||
|
]
|
||||||
|
|
||||||
# Find best fit for each image
|
# Calculate aspect ratios and patches per image
|
||||||
best_fit_sizes = [
|
aspect_ratios = [
|
||||||
get_best_fit(
|
(image_size[0] // tile_size, image_size[1] // tile_size)
|
||||||
(image.size[1], image.size[0]), # (height, width)
|
for image_size in best_fit_sizes
|
||||||
torch.tensor(possible_resolutions),
|
]
|
||||||
resize_to_max_canvas=image_processor.resize_to_max_canvas,
|
|
||||||
)
|
|
||||||
for image in processed_data.images
|
|
||||||
]
|
|
||||||
|
|
||||||
# Calculate aspect ratios and patches per image
|
patches_per_image = [
|
||||||
aspect_ratios = [
|
1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
|
||||||
(image_size[0] // tile_size, image_size[1] // tile_size)
|
]
|
||||||
for image_size in best_fit_sizes
|
|
||||||
]
|
|
||||||
|
|
||||||
patches_per_image = [
|
# Add to image_inputs
|
||||||
1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
|
processor_output["aspect_ratios"] = aspect_ratios
|
||||||
]
|
processor_output["patches_per_image"] = torch.tensor(patches_per_image)
|
||||||
|
|
||||||
# Add to image_inputs
|
# Process embed_is_patch
|
||||||
processor_output["aspect_ratios"] = aspect_ratios
|
vocab = tokenizer.get_vocab()
|
||||||
processor_output["patches_per_image"] = torch.tensor(patches_per_image)
|
patch_id = vocab.get(processor.img_patch_token, -1)
|
||||||
|
image_end_id = vocab.get(processor.end_of_img_token, -1)
|
||||||
|
|
||||||
# Process embed_is_patch
|
if patch_id != -1 and image_end_id != -1:
|
||||||
vocab = tokenizer.get_vocab()
|
input_ids = processor_output["input_ids"].view(-1)
|
||||||
patch_id = vocab.get(processor.img_patch_token, -1)
|
|
||||||
image_end_id = vocab.get(processor.end_of_img_token, -1)
|
|
||||||
|
|
||||||
if patch_id != -1 and image_end_id != -1:
|
# Remove BOS token if present
|
||||||
input_ids = processor_output["input_ids"].view(-1)
|
if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
|
||||||
|
input_ids = input_ids[1:]
|
||||||
|
|
||||||
# Remove BOS token if present
|
# Find image end indices and split input_ids
|
||||||
if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
|
image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
|
||||||
input_ids = input_ids[1:]
|
|
||||||
|
|
||||||
# Find image end indices and split input_ids
|
if image_end_indices.size(0) > 0:
|
||||||
image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
|
# Split at image boundaries
|
||||||
|
split_indices = (image_end_indices + 1)[:-1]
|
||||||
|
split_input_ids = torch.tensor_split(input_ids, split_indices)
|
||||||
|
split_input_ids = [x for x in split_input_ids if x.numel() > 0]
|
||||||
|
|
||||||
if image_end_indices.size(0) > 0:
|
# Create embed_is_patch for each image
|
||||||
# Split at image boundaries
|
embed_is_patch = []
|
||||||
split_indices = (image_end_indices + 1)[:-1]
|
for per_image_input_ids in split_input_ids:
|
||||||
split_input_ids = torch.tensor_split(input_ids, split_indices)
|
embed_is_patch.append(per_image_input_ids == patch_id)
|
||||||
split_input_ids = [x for x in split_input_ids if x.numel() > 0]
|
|
||||||
|
|
||||||
# Create embed_is_patch for each image
|
processor_output["embed_is_patch"] = embed_is_patch
|
||||||
embed_is_patch = []
|
|
||||||
for per_image_input_ids in split_input_ids:
|
|
||||||
embed_is_patch.append(per_image_input_ids == patch_id)
|
|
||||||
|
|
||||||
processor_output["embed_is_patch"] = embed_is_patch
|
|
||||||
|
|
||||||
# Convert to the format expected by SGLang
|
# Convert to the format expected by SGLang
|
||||||
processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
|
processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user