diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index 4dd368a15..c2608f1f1 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -935,6 +935,19 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="llama_4_vision",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA4,
+        sep="",
+        stop_str="<|eot|>",
+        image_token="<|image|>",
+    )
+)
+
 
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
@@ -943,9 +956,11 @@ def match_internvl(model_path: str):
 
 
 @register_conv_template_matching_function
-def match_llama_3_vision(model_path: str):
+def match_llama_vision(model_path: str):
     if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
         return "llama_3_vision"
+    if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
+        return "llama_4_vision"
 
 
 @register_conv_template_matching_function
diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py
index 94abc80df..fb2f4b97b 100644
--- a/python/sglang/srt/managers/mm_utils.py
+++ b/python/sglang/srt/managers/mm_utils.py
@@ -248,7 +248,9 @@ def _get_chunked_prefill_embedding(
 ) -> Optional[torch.Tensor]:
     # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
     embedding_list = []
-    for i in range(len(items_size) - 1):
+    # FIXME(Xinyuan): temporary workaround for eagle3, which may have len(items_size) > len(prefix_length)
+    max_iterations = min(len(items_size) - 1, len(prefix_length))
+    for i in range(max_iterations):
         if items_size[i] == items_size[i + 1]:
             continue
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
@@ -269,7 +271,7 @@ def _get_chunked_prefill_embedding(
         embedding_per_req_chunk, _, end_index = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
-            extend_seq_len=extend_length[i],
+            extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
         # remove this item from cache if chunk reaches to the end
diff --git a/python/sglang/srt/multimodal/processors/mllama4.py b/python/sglang/srt/multimodal/processors/mllama4.py
index 5360b2144..ff3ebe2d1 100644
--- a/python/sglang/srt/multimodal/processors/mllama4.py
+++ b/python/sglang/srt/multimodal/processors/mllama4.py
@@ -60,70 +60,72 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         )
 
         # Handle image resolutions and aspect ratios
-        if "pixel_values" in processor_output:
-            image_processor = processor.image_processor
-            tokenizer = self._processor.tokenizer
+        if "pixel_values" not in processor_output:  # no image processed
+            return None
 
-            # Calculate tile size and find supported resolutions
-            tile_size = self.vision_config.image_size
-            max_num_tiles = getattr(self.vision_config, "max_patches", 1)
+        image_processor = processor.image_processor
+        tokenizer = self._processor.tokenizer
 
-            possible_resolutions = find_supported_resolutions(
-                max_num_chunks=max_num_tiles,
-                patch_size=SizeDict(height=tile_size, width=tile_size),
+        # Calculate tile size and find supported resolutions
+        tile_size = self.vision_config.image_size
+        max_num_tiles = getattr(self.vision_config, "max_patches", 1)
+
+        possible_resolutions = find_supported_resolutions(
+            max_num_chunks=max_num_tiles,
+            patch_size=SizeDict(height=tile_size, width=tile_size),
+        )
+
+        # Find best fit for each image
+        best_fit_sizes = [
+            get_best_fit(
+                (image.size[1], image.size[0]),  # (height, width)
+                torch.tensor(possible_resolutions),
+                resize_to_max_canvas=image_processor.resize_to_max_canvas,
             )
+            for image in processed_data.images
+        ]
 
-            # Find best fit for each image
-            best_fit_sizes = [
-                get_best_fit(
-                    (image.size[1], image.size[0]),  # (height, width)
-                    torch.tensor(possible_resolutions),
-                    resize_to_max_canvas=image_processor.resize_to_max_canvas,
-                )
-                for image in processed_data.images
-            ]
+        # Calculate aspect ratios and patches per image
+        aspect_ratios = [
+            (image_size[0] // tile_size, image_size[1] // tile_size)
+            for image_size in best_fit_sizes
+        ]
 
-            # Calculate aspect ratios and patches per image
-            aspect_ratios = [
-                (image_size[0] // tile_size, image_size[1] // tile_size)
-                for image_size in best_fit_sizes
-            ]
+        patches_per_image = [
+            1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
+        ]
 
-            patches_per_image = [
-                1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
-            ]
+        # Add to image_inputs
+        processor_output["aspect_ratios"] = aspect_ratios
+        processor_output["patches_per_image"] = torch.tensor(patches_per_image)
 
-            # Add to image_inputs
-            processor_output["aspect_ratios"] = aspect_ratios
-            processor_output["patches_per_image"] = torch.tensor(patches_per_image)
+        # Process embed_is_patch
+        vocab = tokenizer.get_vocab()
+        patch_id = vocab.get(processor.img_patch_token, -1)
+        image_end_id = vocab.get(processor.end_of_img_token, -1)
 
-            # Process embed_is_patch
-            vocab = tokenizer.get_vocab()
-            patch_id = vocab.get(processor.img_patch_token, -1)
-            image_end_id = vocab.get(processor.end_of_img_token, -1)
+        if patch_id != -1 and image_end_id != -1:
+            input_ids = processor_output["input_ids"].view(-1)
 
-            if patch_id != -1 and image_end_id != -1:
-                input_ids = processor_output["input_ids"].view(-1)
+            # Remove BOS token if present
+            if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
+                input_ids = input_ids[1:]
 
-                # Remove BOS token if present
-                if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
-                    input_ids = input_ids[1:]
+            # Find image end indices and split input_ids
+            image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
 
-                # Find image end indices and split input_ids
-                image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
+            if image_end_indices.size(0) > 0:
+                # Split at image boundaries
+                split_indices = (image_end_indices + 1)[:-1]
+                split_input_ids = torch.tensor_split(input_ids, split_indices)
+                split_input_ids = [x for x in split_input_ids if x.numel() > 0]
 
-                if image_end_indices.size(0) > 0:
-                    # Split at image boundaries
-                    split_indices = (image_end_indices + 1)[:-1]
-                    split_input_ids = torch.tensor_split(input_ids, split_indices)
-                    split_input_ids = [x for x in split_input_ids if x.numel() > 0]
+                # Create embed_is_patch for each image
+                embed_is_patch = []
+                for per_image_input_ids in split_input_ids:
+                    embed_is_patch.append(per_image_input_ids == patch_id)
 
-                    # Create embed_is_patch for each image
-                    embed_is_patch = []
-                    for per_image_input_ids in split_input_ids:
-                        embed_is_patch.append(per_image_input_ids == patch_id)
-
-                    processor_output["embed_is_patch"] = embed_is_patch
+                processor_output["embed_is_patch"] = embed_is_patch
 
         # Convert to the format expected by SGLang
         processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
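
Note on the conversation.py change — a quick standalone sanity check of the renamed matcher (a sketch; the model paths are only illustrative Hugging Face-style names, not taken from this patch). The specific Llama 3.2 vision pattern must stay first, since the broad llama.*4 pattern would also claim any path containing a "4" after "llama":

import re

def match_llama_vision(model_path: str):
    # Order matters: the specific 3.2-vision pattern is tried before the
    # broad Llama 4 pattern.
    if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
        return "llama_3_vision"
    if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
        return "llama_4_vision"

assert match_llama_vision("meta-llama/Llama-3.2-11B-Vision-Instruct") == "llama_3_vision"
assert match_llama_vision("meta-llama/Llama-4-Scout-17B-16E-Instruct") == "llama_4_vision"
assert match_llama_vision("meta-llama/Meta-Llama-3-8B") is None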
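
Note on the mm_utils.py workaround — the guard only clamps the iteration count; a minimal numeric sketch with made-up sizes for the case the FIXME describes (eagle3 handing over more item boundaries than prefix entries):

# Hypothetical shapes: 4 boundaries describe 3 requests, but only 2
# prefix lengths arrived, so the loop is clamped to 2 iterations.
items_size = [0, 2, 2, 5]
prefix_length = [16, 32]
max_iterations = min(len(items_size) - 1, len(prefix_length))  # 2, not 3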
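
Note on the embed_is_patch logic in mllama4.py — a toy walk-through with made-up token ids (in the real code patch_id and image_end_id come from the tokenizer vocabulary):

import torch

patch_id, image_end_id = 7, 9  # hypothetical <|patch|> / end-of-image ids
input_ids = torch.tensor([1, 7, 7, 9, 2, 7, 9, 3])  # two images in one prompt

image_end_indices = (input_ids == image_end_id).nonzero().view(-1)  # tensor([3, 6])
split_indices = (image_end_indices + 1)[:-1]                        # tensor([4])
split_input_ids = [
    x for x in torch.tensor_split(input_ids, split_indices) if x.numel() > 0
]
# -> [tensor([1, 7, 7, 9]), tensor([2, 7, 9, 3])]

embed_is_patch = [chunk == patch_id for chunk in split_input_ids]
# one boolean mask per image marking the patch-embedding positions:
# -> [tensor([False, True, True, False]), tensor([False, True, False, False])]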