Support --enable-llama4-multimodal (#5254)
This commit is contained in:
@@ -136,6 +136,7 @@ def load_model(server_args, port_args, tp_rank):
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
)
|
||||
|
||||
@@ -43,10 +43,12 @@ class ModelConfig:
|
||||
context_length: Optional[int] = None,
|
||||
model_override_args: Optional[str] = None,
|
||||
is_embedding: Optional[bool] = None,
|
||||
enable_multimodal: Optional[bool] = None,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
override_config_file: Optional[str] = None,
|
||||
) -> None:
|
||||
|
||||
self.model_path = model_path
|
||||
self.revision = revision
|
||||
self.quantization = quantization
|
||||
@@ -70,14 +72,28 @@ class ModelConfig:
|
||||
self.hf_text_config, "attention_chunk_size", None
|
||||
)
|
||||
|
||||
if enable_multimodal is None:
|
||||
if self.hf_config.architectures == "Llama4ForConditionalGeneration":
|
||||
enable_multimodal = False
|
||||
else:
|
||||
enable_multimodal = True
|
||||
|
||||
# Check model type
|
||||
self.is_generation = is_generation_model(
|
||||
self.hf_config.architectures, is_embedding
|
||||
)
|
||||
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
|
||||
self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
|
||||
self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
|
||||
self.is_audio_model = is_audio_model(self.hf_config.architectures)
|
||||
self.is_multimodal = enable_multimodal and is_multimodal_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_image_gen = enable_multimodal and is_image_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_audio_model = enable_multimodal and is_audio_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
|
||||
@@ -437,6 +437,7 @@ class Scheduler(
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
)
|
||||
|
||||
@@ -163,6 +163,7 @@ class TokenizerManager:
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
)
|
||||
|
||||
@@ -68,6 +68,7 @@ class TpModelWorker:
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
)
|
||||
|
||||
@@ -281,7 +281,6 @@ class ModelRunner:
|
||||
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
|
||||
f"because this is a multimodal model."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Automatically turn off --chunked-prefill-size for multimodal model."
|
||||
)
|
||||
|
||||
@@ -156,6 +156,7 @@ class ServerArgs:
|
||||
disable_outlines_disk_cache: bool = False
|
||||
disable_custom_all_reduce: bool = False
|
||||
disable_mla: bool = False
|
||||
enable_llama4_multimodal: Optional[bool] = None
|
||||
disable_overlap_schedule: bool = False
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_dp_attention: bool = False
|
||||
@@ -294,6 +295,8 @@ class ServerArgs:
|
||||
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
||||
)
|
||||
|
||||
self.enable_multimodal: Optional[bool] = self.enable_llama4_multimodal
|
||||
|
||||
# Data parallelism attention
|
||||
if self.enable_dp_attention:
|
||||
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
||||
@@ -974,6 +977,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-llama4-multimodal",
|
||||
action="store_true",
|
||||
help="Enable the multimodal functionality for Llama-4.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-overlap-schedule",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user