diff --git a/examples/runtime/engine/save_remote_state.py b/examples/runtime/engine/save_remote_state.py index 47812695f..89afa5949 100644 --- a/examples/runtime/engine/save_remote_state.py +++ b/examples/runtime/engine/save_remote_state.py @@ -34,6 +34,12 @@ parser.add_argument( type=str, help="remote address to store model weights", ) +parser.add_argument( + "--remote-draft-model-save-url", + default=None, + type=str, + help="remote address to store draft model weights", +) def main(args): @@ -43,7 +49,10 @@ def main(args): raise ValueError("model path must be a local directory") # Create LLM instance from arguments llm = Engine(**dataclasses.asdict(engine_args)) - llm.save_remote_model(url=args.remote_model_save_url) + llm.save_remote_model( + url=args.remote_model_save_url, draft_url=args.remote_draft_model_save_url + ) + print("save remote (draft) model successfully") if __name__ == "__main__": diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 9da66a3ec..0edfa92ae 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -126,6 +126,14 @@ def get_config( kwargs["gguf_file"] = model model = Path(model).parent + if is_remote_url(model): + # BaseConnector implements __del__() to clean up the local dir. + # Since config files need to exist all the time, so we DO NOT use + # with statement to avoid closing the client. + client = create_remote_connector(model) + client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + model = client.get_local_dir() + config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 8da3d07be..fdae2142c 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin: url = params["url"] worker = self.tp_worker.worker - worker.model_runner.save_remote_model(url) + if self.draft_worker is not None: + draft_url = params.get("draft_url", None) + assert ( + draft_url is not None + ), "draft_url must be provided when draft model is enabled" + draft_worker = self.draft_worker.worker + draft_worker.model_runner.save_remote_model(draft_url) + def save_sharded_model(self, params): worker = self.tp_worker.worker