Skip llama4 vision module loading when multimodal disabled (#8272)
Co-authored-by: Mick <mickjagger19@icloud.com>
This commit is contained in:
@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"num_reserved_decode_tokens",
|
"num_reserved_decode_tokens",
|
||||||
"weight_loader_disable_mmap",
|
"weight_loader_disable_mmap",
|
||||||
"enable_triton_kernel_moe",
|
"enable_triton_kernel_moe",
|
||||||
|
"enable_multimodal",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
Modality,
|
Modality,
|
||||||
MultimodalDataItem,
|
MultimodalDataItem,
|
||||||
MultimodalInputs,
|
MultimodalInputs,
|
||||||
|
global_server_args_dict,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
@@ -55,13 +56,17 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
# Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
|
# Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
|
||||||
self.has_vision = self._has_vision_weights(config)
|
self.has_vision_weights = self._has_vision_weights(config)
|
||||||
if not self.has_vision:
|
if not self.has_vision_weights:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No vision weights found in checkpoint. Model will run in text-only mode. "
|
"No vision weights found in checkpoint. Model will run in text-only mode. "
|
||||||
"Multimodal capabilities (image processing) will be unavailable."
|
"Multimodal capabilities (image processing) will be unavailable."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.has_vision = (
|
||||||
|
self.has_vision_weights and global_server_args_dict["enable_multimodal"]
|
||||||
|
)
|
||||||
|
|
||||||
if self.has_vision:
|
if self.has_vision:
|
||||||
self.vision_model = Llama4VisionModel(config.vision_config)
|
self.vision_model = Llama4VisionModel(config.vision_config)
|
||||||
self.multi_modal_projector = Llama4MultiModalProjector(config)
|
self.multi_modal_projector = Llama4MultiModalProjector(config)
|
||||||
@@ -269,7 +274,9 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
def _should_skip_weight(self, name: str) -> bool:
|
def _should_skip_weight(self, name: str) -> bool:
|
||||||
"""Check if we should skip loading this weight."""
|
"""Check if we should skip loading this weight."""
|
||||||
return "vision" in name and not self.has_vision
|
return not self.has_vision and (
|
||||||
|
"vision" in name or "multi_modal_projector" in name
|
||||||
|
)
|
||||||
|
|
||||||
def _transform_weight_name(self, name: str) -> str:
|
def _transform_weight_name(self, name: str) -> str:
|
||||||
"""Transform weight name by adding language_model prefix if needed."""
|
"""Transform weight name by adding language_model prefix if needed."""
|
||||||
|
|||||||
Reference in New Issue
Block a user