support custom weight loader for model runner (#7122)

Co-authored-by: kavioyu <kavioyu@tencent.com>
This commit is contained in:
KavioYu
2025-06-17 07:28:15 +08:00
committed by GitHub
parent c64290dcb5
commit 873ae12cee
4 changed files with 64 additions and 0 deletions

View File

@@ -234,6 +234,9 @@ class ServerArgs:
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
pdlb_url: Optional[str] = None
# For model weight update
custom_weight_loader: Optional[List[str]] = None
def __post_init__(self):
# Expert parallelism
if self.enable_ep_moe:
@@ -538,6 +541,9 @@ class ServerArgs:
"1" if self.disable_outlines_disk_cache else "0"
)
if self.custom_weight_loader is None:
self.custom_weight_loader = []
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
larger_tp = max(decode_tp, prefill_tp)
smaller_tp = min(decode_tp, prefill_tp)
@@ -1576,6 +1582,13 @@ class ServerArgs:
default=None,
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
)
parser.add_argument(
"--custom-weight-loader",
type=str,
nargs="*",
default=None,
help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):