Support weight loading without mmap (#7469)
This commit is contained in:
@@ -101,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"torchao_config",
|
"torchao_config",
|
||||||
"triton_attention_reduce_in_fp32",
|
"triton_attention_reduce_in_fp32",
|
||||||
"num_reserved_decode_tokens",
|
"num_reserved_decode_tokens",
|
||||||
|
"weight_loader_disable_mmap",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
|
|||||||
@@ -337,7 +337,14 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
hf_weights_files,
|
hf_weights_files,
|
||||||
)
|
)
|
||||||
elif use_safetensors:
|
elif use_safetensors:
|
||||||
weights_iterator = safetensors_weights_iterator(hf_weights_files)
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
|
||||||
|
weight_loader_disable_mmap = global_server_args_dict.get(
|
||||||
|
"weight_loader_disable_mmap"
|
||||||
|
)
|
||||||
|
weights_iterator = safetensors_weights_iterator(
|
||||||
|
hf_weights_files, disable_mmap=weight_loader_disable_mmap
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
weights_iterator = pt_weights_iterator(hf_weights_files)
|
weights_iterator = pt_weights_iterator(hf_weights_files)
|
||||||
|
|
||||||
|
|||||||
@@ -422,6 +422,7 @@ def safetensors_weights_iterator(
|
|||||||
hf_weights_files: List[str],
|
hf_weights_files: List[str],
|
||||||
is_all_weights_sharded: bool = False,
|
is_all_weights_sharded: bool = False,
|
||||||
decryption_key: Optional[str] = None,
|
decryption_key: Optional[str] = None,
|
||||||
|
disable_mmap: bool = False,
|
||||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
"""Iterate over the weights in the model safetensor files.
|
"""Iterate over the weights in the model safetensor files.
|
||||||
|
|
||||||
@@ -443,6 +444,10 @@ def safetensors_weights_iterator(
|
|||||||
disable=not enable_tqdm,
|
disable=not enable_tqdm,
|
||||||
bar_format=_BAR_FORMAT,
|
bar_format=_BAR_FORMAT,
|
||||||
):
|
):
|
||||||
|
if disable_mmap:
|
||||||
|
with open(st_file, "rb") as f:
|
||||||
|
result = safetensors.torch.load(f.read())
|
||||||
|
else:
|
||||||
result = safetensors.torch.load_file(st_file, device="cpu")
|
result = safetensors.torch.load_file(st_file, device="cpu")
|
||||||
for name, param in result.items():
|
for name, param in result.items():
|
||||||
yield name, param
|
yield name, param
|
||||||
|
|||||||
@@ -237,6 +237,7 @@ class ServerArgs:
|
|||||||
|
|
||||||
# For model weight update
|
# For model weight update
|
||||||
custom_weight_loader: Optional[List[str]] = None
|
custom_weight_loader: Optional[List[str]] = None
|
||||||
|
weight_loader_disable_mmap: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
@@ -1599,6 +1600,11 @@ class ServerArgs:
|
|||||||
default=None,
|
default=None,
|
||||||
help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
|
help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--weight-loader-disable-mmap",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable mmap while loading weight using safetensors.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
|||||||
Reference in New Issue
Block a user