diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index cc14f691f..e5da78368 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -241,12 +241,13 @@ class BaseMultimodalProcessor(ABC): return_tensors="pt", **kwargs, ) - # move feature tensors to cpu - for feature_name in self.FEATURE_NAMES: - if feature_name in result and isinstance( - result[feature_name], torch.Tensor - ): - result[feature_name] = result[feature_name].to("cpu") + if not self.server_args.keep_mm_feature_on_device: + # move feature tensors to cpu + for feature_name in self.FEATURE_NAMES: + if feature_name in result and isinstance( + result[feature_name], torch.Tensor + ): + result[feature_name] = result[feature_name].to("cpu") return result diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8ee1e8f27..34dae30fb 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -381,6 +381,7 @@ class ServerArgs: disable_shared_experts_fusion: bool = False disable_chunked_prefix_cache: bool = False disable_fast_image_processor: bool = False + keep_mm_feature_on_device: bool = False enable_return_hidden_states: bool = False scheduler_recv_interval: int = 1 numa_node: Optional[List[int]] = None @@ -2213,6 +2214,11 @@ class ServerArgs: action="store_true", help="Adopt base image processor instead of fast image processor.", ) + parser.add_argument( + "--keep-mm-feature-on-device", + action="store_true", + help="Keep multimodal feature tensors on device after processing to save D2H copy.", + ) parser.add_argument( "--enable-return-hidden-states", action="store_true",