feat(draft_model): support draft_model for RemoteModelLoader (#6407)
Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
@@ -34,6 +34,12 @@ parser.add_argument(
|
|||||||
type=str,
|
type=str,
|
||||||
help="remote address to store model weights",
|
help="remote address to store model weights",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--remote-draft-model-save-url",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="remote address to store draft model weights",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@@ -43,7 +49,10 @@ def main(args):
|
|||||||
raise ValueError("model path must be a local directory")
|
raise ValueError("model path must be a local directory")
|
||||||
# Create LLM instance from arguments
|
# Create LLM instance from arguments
|
||||||
llm = Engine(**dataclasses.asdict(engine_args))
|
llm = Engine(**dataclasses.asdict(engine_args))
|
||||||
llm.save_remote_model(url=args.remote_model_save_url)
|
llm.save_remote_model(
|
||||||
|
url=args.remote_model_save_url, draft_url=args.remote_draft_model_save_url
|
||||||
|
)
|
||||||
|
print("save remote (draft) model successfully")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -126,6 +126,14 @@ def get_config(
|
|||||||
kwargs["gguf_file"] = model
|
kwargs["gguf_file"] = model
|
||||||
model = Path(model).parent
|
model = Path(model).parent
|
||||||
|
|
||||||
|
if is_remote_url(model):
|
||||||
|
# BaseConnector implements __del__() to clean up the local dir.
|
||||||
|
# Because the config files need to exist all the time, we DO NOT use
|
||||||
|
# a `with` statement, to avoid closing the client.
|
||||||
|
client = create_remote_connector(model)
|
||||||
|
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
|
||||||
|
model = client.get_local_dir()
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
|
|||||||
url = params["url"]
|
url = params["url"]
|
||||||
|
|
||||||
worker = self.tp_worker.worker
|
worker = self.tp_worker.worker
|
||||||
|
|
||||||
worker.model_runner.save_remote_model(url)
|
worker.model_runner.save_remote_model(url)
|
||||||
|
|
||||||
|
if self.draft_worker is not None:
|
||||||
|
draft_url = params.get("draft_url", None)
|
||||||
|
assert (
|
||||||
|
draft_url is not None
|
||||||
|
), "draft_url must be provided when draft model is enabled"
|
||||||
|
draft_worker = self.draft_worker.worker
|
||||||
|
draft_worker.model_runner.save_remote_model(draft_url)
|
||||||
|
|
||||||
def save_sharded_model(self, params):
|
def save_sharded_model(self, params):
|
||||||
worker = self.tp_worker.worker
|
worker = self.tp_worker.worker
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user