feat(remote_model): support variable remote backend for model loader (#3964)
Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
@@ -585,6 +585,51 @@ def composed_weight_loader(
|
||||
return composed_loader
|
||||
|
||||
|
||||
def runai_safetensors_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Stream (name, tensor) pairs from safetensor files via Run:ai Model Streamer.

    Files are streamed one at a time; each file's tensors are yielded as soon
    as the streamer makes them available.
    """
    from runai_model_streamer import SafetensorsStreamer

    # Show the progress bar only on rank 0 (or when not running distributed).
    show_progress = (
        not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == 0
    )

    with SafetensorsStreamer() as streamer:
        progress = tqdm(
            hf_weights_files,
            desc="Loading safetensors using Runai Model Streamer",
            disable=not show_progress,
            bar_format=_BAR_FORMAT,
        )
        for safetensors_path in progress:
            streamer.stream_file(safetensors_path)
            yield from streamer.get_tensors()
|
||||
|
||||
|
||||
def set_runai_streamer_env(load_config: LoadConfig):
    """Export Run:ai streamer settings from the load config and AWS env vars.

    Integer ``concurrency`` / ``memory_limit`` entries in
    ``load_config.model_loader_extra_config`` are written to the matching
    ``RUNAI_STREAMER_*`` environment variables; non-integer values are
    ignored. ``AWS_ENDPOINT_URL`` is used as a fallback for
    ``RUNAI_STREAMER_S3_ENDPOINT`` only when the latter is not already set.
    """
    extra_config = load_config.model_loader_extra_config
    if extra_config:
        # Table-driven export: each recognized option maps to one env var.
        for option, env_name in (
            ("concurrency", "RUNAI_STREAMER_CONCURRENCY"),
            ("memory_limit", "RUNAI_STREAMER_MEMORY_LIMIT"),
        ):
            value = extra_config.get(option)
            if isinstance(value, int):
                os.environ[env_name] = str(value)

    # Fall back to the generic AWS endpoint when no streamer-specific one exists.
    if os.getenv("RUNAI_STREAMER_S3_ENDPOINT") is None:
        aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
        if aws_endpoint_url is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
|
||||
|
||||
|
||||
def initialize_dummy_weights(
|
||||
model: torch.nn.Module,
|
||||
low: float = -1e-3,
|
||||
|
||||
Reference in New Issue
Block a user