diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index a1b2d4723..120eb04ee 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -136,6 +136,7 @@ def load_model(server_args, port_args, tp_rank):
         context_length=server_args.context_length,
         model_override_args=server_args.json_model_override_args,
         is_embedding=server_args.is_embedding,
+        enable_multimodal=server_args.enable_multimodal,
         dtype=server_args.dtype,
         quantization=server_args.quantization,
     )
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index d17add769..7aaee1547 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -43,10 +43,12 @@ class ModelConfig:
         context_length: Optional[int] = None,
         model_override_args: Optional[str] = None,
         is_embedding: Optional[bool] = None,
+        enable_multimodal: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
     ) -> None:
+
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
@@ -70,14 +72,28 @@ class ModelConfig:
             self.hf_text_config, "attention_chunk_size", None
         )
 
+        if enable_multimodal is None:
+            if "Llama4ForConditionalGeneration" in self.hf_config.architectures:
+                enable_multimodal = False
+            else:
+                enable_multimodal = True
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
         )
-        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
-        self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
-        self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
-        self.is_audio_model = is_audio_model(self.hf_config.architectures)
+        self.is_multimodal = enable_multimodal and is_multimodal_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_image_gen = enable_multimodal and is_image_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_audio_model = enable_multimodal and is_audio_model(
+            self.hf_config.architectures
+        )
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 359573fc6..2c69191c3 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -437,6 +437,7 @@ class Scheduler(
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
         )
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 33afffbd6..9f18ae63c 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -163,6 +163,7 @@ class TokenizerManager:
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
         )
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 174f2e533..a79ea3281 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -68,6 +68,7 @@ class TpModelWorker:
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
         )
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 00244f90f..f49cdae60 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -281,7 +281,6 @@ class ModelRunner:
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
             )
-
             logger.info(
                 "Automatically turn off --chunked-prefill-size for multimodal model."
             )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 28539dcee..6580f7688 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -156,6 +156,7 @@ class ServerArgs:
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    enable_llama4_multimodal: Optional[bool] = None
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -294,6 +295,8 @@
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
 
+        self.enable_multimodal: Optional[bool] = self.enable_llama4_multimodal
+
         # Data parallelism attention
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -974,6 +977,11 @@
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
         )
+        parser.add_argument(
+            "--enable-llama4-multimodal",
+            action="store_true",
+            help="Enable the multimodal functionality for Llama-4.",
+        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
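The ModelConfig hunk above carries the core behavior change: when --enable-llama4-multimodal is not passed, enable_multimodal defaults to off for Llama-4 and on for every other architecture. Below is a minimal standalone sketch of that default-resolution rule; resolve_enable_multimodal is a hypothetical helper written only for illustration (it is not part of the patch), and it assumes hf_config.architectures is a list of architecture class names as in Hugging Face configs.

    from typing import List, Optional

    def resolve_enable_multimodal(
        architectures: List[str], enable_multimodal: Optional[bool]
    ) -> bool:
        """Default: disable multimodal for Llama-4 unless explicitly enabled."""
        if enable_multimodal is None:
            # Llama-4 is served text-only by default; everything else keeps
            # its multimodal capabilities.
            return "Llama4ForConditionalGeneration" not in architectures
        return enable_multimodal

    # Expected behavior under these assumptions:
    assert resolve_enable_multimodal(["Llama4ForConditionalGeneration"], None) is False
    assert resolve_enable_multimodal(["Llama4ForConditionalGeneration"], True) is True
    assert resolve_enable_multimodal(["LlavaLlamaForCausalLM"], None) is True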