feat(draft_model): support draft_model for RemoteModelLoader (#6407)

Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
wangyu
2025-08-29 07:09:52 +08:00
committed by GitHub
parent 74dd4249ac
commit a38c149758
3 changed files with 26 additions and 2 deletions

View File

@@ -126,6 +126,14 @@ def get_config(
kwargs["gguf_file"] = model
model = Path(model).parent
if is_remote_url(model):
# BaseConnector implements __del__() to clean up the local dir.
# The config files need to exist all the time, so we do NOT use a
# `with` statement here, to avoid closing the client.
client = create_remote_connector(model)
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
model = client.get_local_dir()
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)

View File

@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
url = params["url"]
worker = self.tp_worker.worker
worker.model_runner.save_remote_model(url)
if self.draft_worker is not None:
draft_url = params.get("draft_url", None)
assert (
draft_url is not None
), "draft_url must be provided when draft model is enabled"
draft_worker = self.draft_worker.worker
draft_worker.model_runner.save_remote_model(draft_url)
def save_sharded_model(self, params):
worker = self.tp_worker.worker