From fbebcb7aa4aa7c7c0d6bab4d915756d616318de1 Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 10 Apr 2025 00:28:44 +0800 Subject: [PATCH] model: support mllama4 (#5144) --- python/sglang/srt/configs/model_config.py | 2 +- python/sglang/srt/managers/mm_utils.py | 5 +- .../managers/multimodal_processors/mllama4.py | 57 +++++++---------- python/sglang/srt/managers/schedule_batch.py | 56 ++++++++++++----- python/sglang/srt/models/llama4.py | 3 + python/sglang/srt/models/mllama4.py | 61 +++++++++++++++---- test/srt/test_vision_openai_server.py | 26 +++++++- 7 files changed, 145 insertions(+), 65 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 5f4e2f0dd..d17add769 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -486,8 +486,8 @@ multimodal_model_archs = [ "Gemma3ForConditionalGeneration", "Grok1VForCausalLM", "Grok1AForCausalLM", - # TODO: add multimodal support for "Llama4ForConditionalGeneration", "LlavaLlamaForCausalLM", + "Llama4ForConditionalGeneration", "LlavaMistralForCausalLM", "LlavaQwenForCausalLM", "LlavaVidForCausalLM", diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 6d1a33455..045b4c13a 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -148,7 +148,8 @@ def get_embedding_and_mask( placeholder_tensor, ).unsqueeze(-1) - num_mm_tokens_in_input_ids = special_multimodal_mask.sum() + num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item() + if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding: logger.warning( f"Number of tokens in multimodal embedding does not match those in the input text." @@ -172,7 +173,7 @@ def get_embedding_and_mask( embedding = embedding[-num_multimodal:, :] else: raise RuntimeError( - "Insufficient multimodal embedding length. This is an internal error" + f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error" ) return embedding, special_multimodal_mask diff --git a/python/sglang/srt/managers/multimodal_processors/mllama4.py b/python/sglang/srt/managers/multimodal_processors/mllama4.py index 41b6f3835..0d3c289b5 100644 --- a/python/sglang/srt/managers/multimodal_processors/mllama4.py +++ b/python/sglang/srt/managers/multimodal_processors/mllama4.py @@ -1,10 +1,8 @@ -from typing import List, Mapping, Optional, Tuple, Union +from typing import List, Union import torch -from PIL import Image -from transformers import Llama4Processor from transformers.image_utils import SizeDict -from transformers.models.llama4.image_processing_llama4 import ( +from transformers.models.llama4.image_processing_llama4_fast import ( find_supported_resolutions, get_best_fit, ) @@ -15,7 +13,6 @@ from sglang.srt.managers.multimodal_processors.base_processor import ( ) from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration -from sglang.srt.utils import load_image class Mllama4ImageProcessor(BaseMultimodalProcessor): @@ -25,6 +22,9 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor): super().__init__(hf_config, server_args, _processor) self.vision_config = hf_config.vision_config self.text_config = hf_config.text_config + self.boi_token_index = hf_config.boi_token_index + self.eoi_token_index = hf_config.eoi_token_index + self.image_token_index = hf_config.image_token_index self.multimodal_tokens = MultimodalSpecialTokens( image_token=_processor.image_token ) @@ -54,19 +54,16 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor): ) # Process the images using the processor - processor = Llama4Processor.from_pretrained( - self.server_args.model_path, **kwargs - ) + processor = self._processor # Process the prompt and images - image_inputs = processor( - text=processed_data.input_text, + processor_output = self.process_mm_data( + input_text=processed_data.input_text, images=processed_data.images, - return_tensors="pt", ) # Handle image resolutions and aspect ratios - if "pixel_values" in image_inputs: + if "pixel_values" in processor_output: image_processor = processor.image_processor tokenizer = self._processor.tokenizer @@ -100,8 +97,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor): ] # Add to image_inputs - image_inputs["aspect_ratios"] = aspect_ratios - image_inputs["patches_per_image"] = torch.tensor(patches_per_image) + processor_output["aspect_ratios"] = aspect_ratios + processor_output["patches_per_image"] = torch.tensor(patches_per_image) # Process embed_is_patch vocab = tokenizer.get_vocab() @@ -109,7 +106,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor): image_end_id = vocab.get(processor.end_of_img_token, -1) if patch_id != -1 and image_end_id != -1: - input_ids = image_inputs["input_ids"].view(-1) + input_ids = processor_output["input_ids"].view(-1) # Remove BOS token if present if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id: @@ -129,33 +126,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor): for per_image_input_ids in split_input_ids: embed_is_patch.append(per_image_input_ids == patch_id) - image_inputs["embed_is_patch"] = embed_is_patch + processor_output["embed_is_patch"] = embed_is_patch # Convert to the format expected by SGLang - image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] + processor_output["input_ids"] = processor_output["input_ids"].tolist()[0] + + processor_output["im_start_id"] = self.boi_token_index + processor_output["im_end_id"] = self.eoi_token_index + processor_output["im_token_id"] = self.image_token_index # Add metadata for image processing - image_inputs["mm_items"] = [ + processor_output["mm_items"] = [ MultimodalDataItem( - pixel_values=image_inputs["pixel_values"], + pixel_values=processor_output["pixel_values"], modality=Modality.IMAGE, - # Add additional metadata needed for Llama4 vision processing - embed_is_patch=image_inputs.get("embed_is_patch", None), - aspect_ratios=image_inputs.get("aspect_ratios", None), - patches_per_image=image_inputs.get("patches_per_image", None), ) ] - return image_inputs - - def get_patch_per_chunk(self): - """Calculate patches per chunk based on vision config""" - image_size = self.vision_config.image_size - patch_size = self.vision_config.patch_size - - assert ( - image_size % patch_size == 0 - ), f"chunk size {image_size} should be multiple of patch_size {patch_size}" - - ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2))) - return (image_size // patch_size) ** 2 // ds_ratio + return processor_output diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 337bcdb06..768a9176c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib from enum import Enum, auto # Copyright 2023-2024 SGLang Team @@ -157,7 +158,7 @@ class Modality(Enum): @dataclasses.dataclass class MultimodalDataItem: """ - A single multimodal data, from a single image/video/audio or other + A single multimodal data, from a single image/video/audio or others """ modality: Modality @@ -195,25 +196,54 @@ class MultimodalDataItem: def set_pad_value(self): """ - Set the pad value after first hashign the data + Set the pad value after first hashing the data """ - def tensor_hash(f): - f_list = flatten_nested_list(f) - f_list = [x.flatten() if isinstance(x, torch.Tensor) else x for x in f_list] - f_cat = torch.concat(f_list).contiguous().numpy().tobytes() - return hash(f_cat) + def data_hash(data) -> int: + hash_bytes = hashlib.sha256(data).digest()[:8] + return int.from_bytes(hash_bytes, byteorder="big", signed=False) + + def tensor_hash(tensor_list) -> int: + """ + hash a tensor or a tensor list + """ + tensor = tensor_list + if isinstance(tensor_list, list): + tensor_list = flatten_nested_list(tensor_list) + tensor_list = [ + x.flatten() if isinstance(x, torch.Tensor) else x + for x in tensor_list + ] + tensor = torch.concat(tensor_list) + + tensor = tensor.detach().contiguous() + + if tensor.dtype == torch.bfloat16: + # memoryview() doesn't support PyTorch's BFloat16 dtype + tensor = tensor.float() + + if tensor.is_cuda: + tensor_cpu = torch.frombuffer( + tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel() + ).clone() + else: + tensor_cpu = tensor + + mv = memoryview(tensor_cpu.numpy()) + return data_hash(mv.tobytes()) def hash_feature(f): if isinstance(f, list): if isinstance(f[0], torch.Tensor): return tensor_hash(f) - return hash(tuple(flatten_nested_list(f))) + return data_hash(tuple(flatten_nested_list(f))) elif isinstance(f, np.ndarray): arr = np.ascontiguousarray(f) arr_bytes = arr.tobytes() - return hash(arr_bytes) - return hash(f) + return data_hash(arr_bytes) + elif isinstance(f, torch.Tensor): + return tensor_hash([f]) + return data_hash(f) if self.is_audio(): self.hash = hash_feature(self.audio_features) @@ -256,7 +286,7 @@ class MultimodalInputs: mrope_position_delta: Optional[torch.Tensor] = None # image - im_token_id: Optional[torch.Tensor] = None + im_token_id: Optional[int] = None im_start_id: Optional[int] = None im_end_id: Optional[int] = None slice_start_id: Optional[int] = None @@ -330,10 +360,8 @@ class MultimodalInputs: # args needed to be merged optional_args = [ - "items", - "image_offsets", + "mm_items", "image_pad_len", - # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images ] for arg in optional_args: self_arg = getattr(self, arg, None) diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index 0a46305b5..d933f27ae 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -466,6 +466,9 @@ class Llama4ForCausalLM(LlamaForCausalLM): ): super().__init__(config, quant_config, prefix) + def get_input_embeddings(self): + return self.model.embed_tokens + def _init_model( self, config: Llama4TextConfig, diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index 98fc80686..e492089da 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -1,14 +1,19 @@ -# TODO: add Aapted from vllm/mllama4.py from collections.abc import Iterable -from typing import Optional, Set, Tuple +from typing import List, Optional, Set, Tuple import torch from torch import nn -from transformers import Llama4Config +from transformers import Llama4Config, Llama4VisionModel +from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternImageTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix @@ -30,6 +35,9 @@ class Llama4ForConditionalGeneration(nn.Module): self.config = config self.quant_config = quant_config + self.vision_model = Llama4VisionModel(config.vision_config) + self.multi_modal_projector = Llama4MultiModalProjector(config) + # Initialize the language model from sglang.srt.models.llama4 import Llama4ForCausalLM @@ -41,6 +49,29 @@ class Llama4ForConditionalGeneration(nn.Module): self.logits_processor = LogitsProcessor(config.text_config) + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + # Get all special token IDs + im_token_id: int = mm_inputs.im_token_id + + pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id)) + return pattern.pad_input_tokens(input_ids, mm_inputs) + + def get_image_feature( + self, + items: List[MultimodalDataItem], + ) -> torch.Tensor: + pixel_values = ( + torch.concat([item.pixel_values for item in items]) + .to(next(self.vision_model.parameters()).device) + .type(next(self.vision_model.parameters()).dtype) + ) + + image_outputs = self.vision_model(pixel_values, output_hidden_states=False) + image_features = image_outputs.last_hidden_state + vision_flat = image_features.view(-1, image_features.size(-1)) + projected_vision_flat = self.multi_modal_projector(vision_flat) + return projected_vision_flat + def forward( self, input_ids: torch.Tensor, @@ -49,7 +80,15 @@ class Llama4ForConditionalGeneration(nn.Module): **kwargs: object, ) -> torch.Tensor: - return self.language_model(input_ids, positions, forward_batch) + hs = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.language_model, + image_data_embedding_func=self.get_image_feature, + positions=positions, + ) + + return hs def permute_qk_weight_for_rotary( self, @@ -108,17 +147,17 @@ class Llama4ForConditionalGeneration(nn.Module): ) for name, loaded_weight in weights: - - if name.startswith("vision_model") or name.startswith( - "multi_modal_projector" - ): - continue - - name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight) + if not "vision" in name: + name, loaded_weight = self.permute_qk_weight_for_rotary( + name, loaded_weight + ) for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue + + if "vision" in name: + continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index d5a22a550..40e0de5d2 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -307,7 +307,6 @@ class TestOpenAIVisionServer(CustomTestCase): self.assertGreater(len(video_response), 0) def test_regex(self): - return client = openai.Client(api_key=self.api_key, base_url=self.base_url) regex = ( @@ -683,6 +682,31 @@ class TestJanusProServer(TestOpenAIVisionServer): pass +class TestLlama4Server(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--chat-template", + "llama-4", + "--mem-fraction-static", + "0.8", + "--tp-size=8", + "--context-length=8192", + ], + ) + cls.base_url += "/v1" + + def test_video_chat_completion(self): + pass + + class TestGemma3itServer(TestOpenAIVisionServer): @classmethod def setUpClass(cls):