feat(remote_model): support variable remote backend for model loader (#3964)

Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
wangyu
2025-03-14 15:40:44 +08:00
committed by GitHub
parent 977d7cd26a
commit 1ce4878d31
22 changed files with 1055 additions and 9 deletions

View File

@@ -22,6 +22,7 @@ class LoadFormat(str, enum.Enum):
MISTRAL = "mistral"
LAYERED = "layered"
JAX = "jax"
REMOTE = "remote"
@dataclass

View File

@@ -51,13 +51,14 @@ class ModelConfig:
self.quantization = quantization
# Parse args
self.maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
kwargs = {}
if override_config_file and override_config_file.strip():
kwargs["_configuration_file"] = override_config_file.strip()
self.hf_config = get_config(
model_path,
self.model_path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
@@ -318,6 +319,29 @@ class ModelConfig:
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
return eos_ids
def maybe_pull_model_tokenizer_from_remote(self) -> None:
    """Pull model config files to a local temporary directory when
    ``self.model_path`` is a remote URL.

    If the path is remote, the ``*config.json`` files are downloaded via a
    remote connector, ``self.model_weights`` is set to the original remote
    location (for later weight loading), and ``self.model_path`` is
    redirected to the local directory holding the pulled files. If the path
    is not remote, this is a no-op.
    """
    from sglang.srt.connector import create_remote_connector
    from sglang.srt.utils import is_remote_url

    # Guard clause: nothing to do for local paths.
    if not is_remote_url(self.model_path):
        return

    logger.info("Pulling model configs from remote...")
    # BaseConnector implements __del__() to clean up the local dir.
    # Since config files need to exist all the time, we DO NOT use a
    # with statement, to avoid closing the client prematurely.
    client = create_remote_connector(self.model_path)
    # NOTE: the original code re-checked is_remote_url() here; that check
    # was redundant (the branch is only reached for remote URLs) and has
    # been removed.
    client.pull_files(allow_pattern=["*config.json"])
    # Keep the remote URL for weight loading, point model_path at the
    # local copy so downstream HF config loading reads local files.
    self.model_weights = self.model_path
    self.model_path = client.get_local_dir()
def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.