diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a725670ae..e19707340 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -101,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "torchao_config", "triton_attention_reduce_in_fp32", "num_reserved_decode_tokens", + "weight_loader_disable_mmap", ] # Put some global args for easy access diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 6ba31b515..0aebe2f9f 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -337,7 +337,14 @@ class DefaultModelLoader(BaseModelLoader): hf_weights_files, ) elif use_safetensors: - weights_iterator = safetensors_weights_iterator(hf_weights_files) + from sglang.srt.managers.schedule_batch import global_server_args_dict + + weight_loader_disable_mmap = global_server_args_dict.get( + "weight_loader_disable_mmap" + ) + weights_iterator = safetensors_weights_iterator( + hf_weights_files, disable_mmap=weight_loader_disable_mmap + ) else: weights_iterator = pt_weights_iterator(hf_weights_files) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index e61a521e1..722f8e1d4 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -422,6 +422,7 @@ def safetensors_weights_iterator( hf_weights_files: List[str], is_all_weights_sharded: bool = False, decryption_key: Optional[str] = None, + disable_mmap: bool = False, ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files. 
@@ -443,7 +444,11 @@ def safetensors_weights_iterator( disable=not enable_tqdm, bar_format=_BAR_FORMAT, ): - result = safetensors.torch.load_file(st_file, device="cpu") + if disable_mmap: + with open(st_file, "rb") as f: + result = safetensors.torch.load(f.read()) + else: + result = safetensors.torch.load_file(st_file, device="cpu") for name, param in result.items(): yield name, param diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2f4e08cfe..e5b1c1809 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -237,6 +237,7 @@ class ServerArgs: # For model weight update custom_weight_loader: Optional[List[str]] = None + weight_loader_disable_mmap: bool = False def __post_init__(self): # Expert parallelism @@ -1599,6 +1600,11 @@ class ServerArgs: default=None, help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func", ) + parser.add_argument( + "--weight-loader-disable-mmap", + action="store_true", + help="Disable mmap while loading weights using safetensors.", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace):