feat(remote_model): support variable remote backend for model loader (#3964)
Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
||||
@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.connector import (
|
||||
ConnectorType,
|
||||
create_remote_connector,
|
||||
get_connector_type,
|
||||
)
|
||||
from sglang.srt.connector.utils import parse_model_name
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
np_cache_weights_iterator,
|
||||
pt_weights_iterator,
|
||||
safetensors_weights_iterator,
|
||||
set_runai_streamer_env,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
Model loader that directly loads each worker's model state dict, which
|
||||
enables a fast load path for large tensor-parallel models where each worker
|
||||
only needs to read its own shard rather than the entire checkpoint. See
|
||||
`examples/save_sharded_state.py` for creating a sharded checkpoint.
|
||||
`examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
|
||||
"""
|
||||
|
||||
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
||||
@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
return model
|
||||
|
||||
|
||||
class RemoteModelLoader(BaseModelLoader):
    """Model loader that loads tensors from a remote backend.

    The backend is selected from the model-weights URL via the connector
    abstraction:

    * ``ConnectorType.KV`` — per-TP-rank tensors stored under
      ``{model_name}/keys/rank_{rank}/{key}``.
    * ``ConnectorType.FS`` — weights streamed from a remote filesystem.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # TODO @DellCurry: move to s3 connector only
        set_runai_streamer_env(load_config)

    def _get_weights_iterator_kv(
        self,
        client,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield (name, tensor) pairs for the current TP rank from a KV backend."""
        assert get_connector_type(client) == ConnectorType.KV
        rank = get_tensor_model_parallel_rank()
        return client.weight_iterator(rank)

    def _get_weights_iterator_fs(
        self,
        client,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield (name, tensor) pairs from a remote-filesystem backend."""
        assert get_connector_type(client) == ConnectorType.FS
        return client.weight_iterator()

    def download_model(self, model_config: ModelConfig) -> None:
        # Weights are streamed directly from the remote backend at load time;
        # there is nothing to pre-download.
        pass

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        model_path: str,
        url: str,
    ) -> None:
        """Upload a sharded model to a remote KV backend.

        Tensors of the current TP rank go under
        ``{model_name}/keys/rank_{rank}/{key}``; auxiliary text files found
        under ``model_path`` (configs, tokenizer files, ...) go under
        ``{model_name}/files/{file_name}``.

        Args:
            model: The (already sharded) model whose state dict to upload.
            model_path: Local checkpoint directory holding metadata files.
            url: Remote backend URL; must resolve to a KV connector.
        """
        with create_remote_connector(url) as client:
            assert get_connector_type(client) == ConnectorType.KV
            model_name = parse_model_name(url)
            rank = get_tensor_model_parallel_rank()
            # _filter_subtensors drops tensors that are views of other tensors,
            # so each distinct piece of storage is uploaded exactly once.
            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
            for key, tensor in state_dict.items():
                r_key = f"{model_name}/keys/rank_{rank}/{key}"
                client.set(r_key, tensor)

            for root, _, files in os.walk(model_path):
                for file_name in files:
                    # ignore hidden files
                    if file_name.startswith("."):
                        continue
                    # Weight files were uploaded tensor-by-tensor above; only
                    # the remaining (non-weight) files are uploaded verbatim.
                    if os.path.splitext(file_name)[1] not in (
                        ".bin",
                        ".pt",
                        ".safetensors",
                    ):
                        file_path = os.path.join(root, file_name)
                        # NOTE(review): assumes non-weight sidecar files are
                        # UTF-8 text; a binary file here raises UnicodeDecodeError.
                        with open(file_path, encoding="utf-8") as file:
                            file_content = file.read()
                            f_key = f"{model_name}/files/{file_name}"
                            client.setstr(f_key, file_content)

    def _load_model_from_remote_kv(self, model: nn.Module, client):
        """Fill ``model``'s parameters in place from a KV backend.

        Raises:
            ValueError: if any parameter in the model's (filtered) state dict
                was not supplied by the remote backend.
        """
        # Run quant post-load hooks first so tensors are copied into the
        # parameters' final (possibly repacked) layout.
        for _, module in model.named_modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                quant_method.process_weights_after_loading(module)
        weights_iterator = self._get_weights_iterator_kv(client)
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        for key, tensor in weights_iterator:
            # If loading with LoRA enabled, additional padding may
            # be added to certain parameters. We only load into a
            # narrowed view of the parameter data.
            param_data = state_dict[key].data
            param_shape = state_dict[key].shape
            for dim, size in enumerate(tensor.shape):
                if size < param_shape[dim]:
                    param_data = param_data.narrow(dim, 0, size)
            if tensor.shape != param_shape:
                logger.warning(
                    "loading tensor of shape %s into " "parameter '%s' of shape %s",
                    tensor.shape,
                    key,
                    param_shape,
                )
            param_data.copy_(tensor)
            # Track coverage: anything left in state_dict at the end was
            # never provided by the backend.
            state_dict.pop(key)
        if state_dict:
            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

    def _load_model_from_remote_fs(
        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
    ) -> nn.Module:
        """Load weights from a remote-filesystem backend and return the model.

        Bug fix: the original was annotated ``-> nn.Module`` but fell off the
        end and returned ``None``; it now returns ``model`` as advertised.
        """
        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            model.load_weights(self._get_weights_iterator_fs(client))

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
        return model

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
    ) -> nn.Module:
        """Initialize the model on the target device and fill its weights
        from the remote backend chosen by the model-weights URL.

        Returns:
            The fully loaded model in eval mode.
        """
        logger.info("Loading weights from remote storage ...")
        start = time.perf_counter()
        load_config = self.load_config

        # Bug fix: the original message interpolated load_format twice
        # ("Model loader X is not supported for load format X"), which was
        # nonsensical; state the actual constraint instead.
        assert load_config.load_format == LoadFormat.REMOTE, (
            f"RemoteModelLoader cannot be used with load format "
            f"{load_config.load_format}; expected {LoadFormat.REMOTE}"
        )

        # Prefer an explicit remote-weights URL when the config provides one;
        # otherwise fall back to the model path.
        model_weights = model_config.model_path
        if hasattr(model_config, "model_weights"):
            model_weights = model_config.model_weights

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config)
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)

            with create_remote_connector(model_weights, device_config.device) as client:
                connector_type = get_connector_type(client)
                if connector_type == ConnectorType.KV:
                    self._load_model_from_remote_kv(model, client)
                elif connector_type == ConnectorType.FS:
                    self._load_model_from_remote_fs(
                        model, client, model_config, device_config
                    )

        end = time.perf_counter()
        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
        return model.eval()
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
"""Get a model loader based on the load format."""
|
||||
|
||||
@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
if load_config.load_format == LoadFormat.LAYERED:
|
||||
return LayeredModelLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.REMOTE:
|
||||
return RemoteModelLoader(load_config)
|
||||
|
||||
return DefaultModelLoader(load_config)
|
||||
|
||||
Reference in New Issue
Block a user