feat(remote_model): support variable remote backend for model loader (#3964)

Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
wangyu
2025-03-14 15:40:44 +08:00
committed by GitHub
parent 977d7cd26a
commit 1ce4878d31
22 changed files with 1055 additions and 9 deletions

View File

@@ -585,6 +585,51 @@ def composed_weight_loader(
return composed_loader
def runai_safetensors_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs streamed from safetensor files.

    Each file is fed through the Run:ai Model Streamer; tensors are
    yielded file by file as streaming completes. The progress bar is
    shown only when torch.distributed is uninitialized or on rank 0,
    so multi-rank runs print a single bar.
    """
    # Imported lazily so the dependency is only required when this
    # loader is actually used.
    from runai_model_streamer import SafetensorsStreamer

    show_progress = (
        not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == 0
    )
    with SafetensorsStreamer() as streamer:
        progress = tqdm(
            hf_weights_files,
            desc="Loading safetensors using Runai Model Streamer",
            disable=not show_progress,
            bar_format=_BAR_FORMAT,
        )
        for weights_file in progress:
            streamer.stream_file(weights_file)
            yield from streamer.get_tensors()
def set_runai_streamer_env(load_config: "LoadConfig") -> None:
    """Propagate Run:ai streamer settings from ``load_config`` to env vars.

    Recognized integer keys in ``model_loader_extra_config``:
      * ``concurrency``  -> ``RUNAI_STREAMER_CONCURRENCY``
      * ``memory_limit`` -> ``RUNAI_STREAMER_MEMORY_LIMIT``

    If ``AWS_ENDPOINT_URL`` is set but ``RUNAI_STREAMER_S3_ENDPOINT`` is
    not, the AWS value is copied over so existing AWS endpoint
    configuration keeps working with the streamer.
    """
    extra_config = load_config.model_loader_extra_config or {}
    # Table-driven instead of two copy-pasted branches; each key is
    # looked up once. Only plain ints are forwarded — bool is a
    # subclass of int and would otherwise leak "True"/"False" into
    # the env var, which the streamer cannot parse as a count.
    for key, env_var in (
        ("concurrency", "RUNAI_STREAMER_CONCURRENCY"),
        ("memory_limit", "RUNAI_STREAMER_MEMORY_LIMIT"),
    ):
        value = extra_config.get(key)
        if isinstance(value, int) and not isinstance(value, bool):
            os.environ[env_var] = str(value)

    # Fall back to the generic AWS endpoint only when no
    # streamer-specific endpoint was provided.
    if os.getenv("RUNAI_STREAMER_S3_ENDPOINT") is None:
        aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
        if aws_endpoint_url is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
def initialize_dummy_weights(
model: torch.nn.Module,
low: float = -1e-3,