feat(remote_model): support variable remote backend for model loader (#3964)

Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
wangyu
2025-03-14 15:40:44 +08:00
committed by GitHub
parent 977d7cd26a
commit 1ce4878d31
22 changed files with 1055 additions and 9 deletions

View File

@@ -9,6 +9,7 @@ import json
import logging
import math
import os
import time
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.connector import (
ConnectorType,
create_remote_connector,
get_connector_type,
)
from sglang.srt.connector.utils import parse_model_name
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
np_cache_weights_iterator,
pt_weights_iterator,
safetensors_weights_iterator,
set_runai_streamer_env,
)
from sglang.srt.utils import (
get_bool_env_var,
@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/save_sharded_state.py` for creating a sharded checkpoint.
`examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
"""
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
return model
class RemoteModelLoader(BaseModelLoader):
    """Model loader that can load Tensors from remote database.

    Two connector families are supported (dispatched in :meth:`load_model`):

    * ``ConnectorType.KV`` -- tensors previously uploaded per tensor-parallel
      rank by :meth:`save_model`, streamed back via ``client.weight_iterator(rank)``.
    * ``ConnectorType.FS`` -- checkpoint files streamed via
      ``client.weight_iterator()`` and fed to the model's own ``load_weights``.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # TODO @DellCurry: move to s3 connector only
        set_runai_streamer_env(load_config)

    def _get_weights_iterator_kv(
        self,
        client,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator over this TP rank's weights from a KV-style remote store."""
        assert get_connector_type(client) == ConnectorType.KV
        rank = get_tensor_model_parallel_rank()
        return client.weight_iterator(rank)

    def _get_weights_iterator_fs(
        self,
        client,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator over the model weights from an FS-style remote store."""
        assert get_connector_type(client) == ConnectorType.FS
        return client.weight_iterator()

    def download_model(self, model_config: ModelConfig) -> None:
        # Remote loaders stream weights on demand; there is nothing to
        # pre-download to local disk.
        pass

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        model_path: str,
        url: str,
    ) -> None:
        """Upload a model's (sharded) state dict and auxiliary files to a KV store.

        Tensors are keyed as ``{model_name}/keys/rank_{rank}/{param_name}``.
        Non-weight text files found under ``model_path`` (configs, tokenizer
        assets, ...) are uploaded as strings under ``{model_name}/files/``.

        Args:
            model: The model whose state dict to upload.
            model_path: Local directory scanned for auxiliary (non-weight) files.
            url: Remote connector URL; must resolve to a KV connector.
        """
        with create_remote_connector(url) as client:
            assert get_connector_type(client) == ConnectorType.KV
            model_name = parse_model_name(url)
            rank = get_tensor_model_parallel_rank()
            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
            for key, tensor in state_dict.items():
                r_key = f"{model_name}/keys/rank_{rank}/{key}"
                client.set(r_key, tensor)

            for root, _, files in os.walk(model_path):
                for file_name in files:
                    # ignore hidden files
                    if file_name.startswith("."):
                        continue
                    # Weight files were already uploaded as tensors above;
                    # only upload everything else as text.
                    if os.path.splitext(file_name)[1] not in (
                        ".bin",
                        ".pt",
                        ".safetensors",
                    ):
                        file_path = os.path.join(root, file_name)
                        with open(file_path, encoding="utf-8") as file:
                            file_content = file.read()
                        f_key = f"{model_name}/files/{file_name}"
                        client.setstr(f_key, file_content)

    def _load_model_from_remote_kv(self, model: nn.Module, client):
        """Copy per-rank tensors from a KV connector into the model in place.

        Raises:
            ValueError: If the remote store is missing keys that the model's
                filtered state dict expects.
        """
        # NOTE(review): quant post-processing runs *before* the copy here,
        # presumably because tensors saved by save_model() were captured after
        # post-processing -- confirm against the save path.
        for _, module in model.named_modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                quant_method.process_weights_after_loading(module)
        weights_iterator = self._get_weights_iterator_kv(client)
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        for key, tensor in weights_iterator:
            # If loading with LoRA enabled, additional padding may
            # be added to certain parameters. We only load into a
            # narrowed view of the parameter data.
            param_data = state_dict[key].data
            param_shape = state_dict[key].shape
            for dim, size in enumerate(tensor.shape):
                if size < param_shape[dim]:
                    param_data = param_data.narrow(dim, 0, size)
            if tensor.shape != param_shape:
                logger.warning(
                    "loading tensor of shape %s into " "parameter '%s' of shape %s",
                    tensor.shape,
                    key,
                    param_shape,
                )
            param_data.copy_(tensor)
            state_dict.pop(key)
        if state_dict:
            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

    def _load_model_from_remote_fs(
        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
    ) -> nn.Module:
        """Load weights through the model's own ``load_weights`` from an FS connector."""
        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            model.load_weights(self._get_weights_iterator_fs(client))

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
    ) -> nn.Module:
        """Initialize the model on the target device and fill its weights remotely.

        Args:
            model_config: Model configuration; ``model_weights`` (if present)
                overrides ``model_path`` as the remote connector URL.
            device_config: Target device for initialization and loading.

        Returns:
            The loaded model in eval mode.
        """
        logger.info("Loading weights from remote storage ...")
        start = time.perf_counter()
        load_config = self.load_config
        # Bug fix: the original message interpolated the same load_format
        # twice, yielding "loader X is not supported for load format X".
        assert load_config.load_format == LoadFormat.REMOTE, (
            f"RemoteModelLoader requires load format {LoadFormat.REMOTE}, "
            f"got {load_config.load_format}"
        )
        model_weights = model_config.model_path
        if hasattr(model_config, "model_weights"):
            model_weights = model_config.model_weights
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config)
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)

            with create_remote_connector(model_weights, device_config.device) as client:
                connector_type = get_connector_type(client)
                if connector_type == ConnectorType.KV:
                    self._load_model_from_remote_kv(model, client)
                elif connector_type == ConnectorType.FS:
                    self._load_model_from_remote_fs(
                        model, client, model_config, device_config
                    )

        end = time.perf_counter()
        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
        return model.eval()
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.LAYERED:
return LayeredModelLoader(load_config)
if load_config.load_format == LoadFormat.REMOTE:
return RemoteModelLoader(load_config)
return DefaultModelLoader(load_config)