diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index 0b43c6a5a..7207259ea 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -82,7 +82,8 @@ "Get the information of the model.\n", "\n", "- `model_path`: The path/name of the model.\n", - "- `is_generation`: Whether the model is used as generation model or embedding model." + "- `is_generation`: Whether the model is used as generation model or embedding model.\n", + "- `tokenizer_path`: The path/name of the tokenizer." ] }, { @@ -98,7 +99,8 @@ "print_highlight(response_json)\n", "assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n", "assert response_json[\"is_generation\"] is True\n", - "assert response_json.keys() == {\"model_path\", \"is_generation\"}" + "assert response_json[\"tokenizer_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n", + "assert response_json.keys() == {\"model_path\", \"is_generation\", \"tokenizer_path\"}" ] }, { @@ -187,9 +189,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Update Weights\n", + "## Update Weights From Disk\n", "\n", - "Update model weights without restarting the server. Use for continuous evaluation during training. Only applicable for models with the same architecture and parameter size." + "Update model weights from disk without restarting the server. 
Only applicable for models with the same architecture and parameter size.\n", + "\n", + "SGLang supports the `update_weights_from_disk` API for continuous evaluation during training (save checkpoint to disk and update weights from disk).\n" ] }, { @@ -200,7 +204,7 @@ "source": [ "# successful update with same architecture and size\n", "\n", - "url = \"http://localhost:30010/update_weights\"\n", + "url = \"http://localhost:30010/update_weights_from_disk\"\n", "data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n", "\n", "response = requests.post(url, json=data)\n", @@ -218,7 +222,7 @@ "source": [ "# failed update with different parameter size\n", "\n", - "url = \"http://localhost:30010/update_weights\"\n", + "url = \"http://localhost:30010/update_weights_from_disk\"\n", "data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n", "\n", "response = requests.post(url, json=data)\n", diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 25cf459af..5fb8c6e0e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -352,7 +352,7 @@ class FlushCacheReq: @dataclass -class UpdateWeightReqInput: +class UpdateWeightFromDiskReqInput: # The model path with the new weights model_path: str # The format to load the weights @@ -360,7 +360,7 @@ @dataclass -class UpdateWeightReqOutput: +class UpdateWeightFromDiskReqOutput: success: bool message: str diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1957eeb99..04c18e2e0 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -43,8 +43,8 @@ from sglang.srt.managers.io_struct import ( ProfileReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, - UpdateWeightReqInput, - UpdateWeightReqOutput, + UpdateWeightFromDiskReqInput, + UpdateWeightFromDiskReqOutput, ) from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, @@ 
-506,10 +506,10 @@ class Scheduler: self.flush_cache() elif isinstance(recv_req, AbortReq): self.abort_request(recv_req) - elif isinstance(recv_req, UpdateWeightReqInput): - success, message = self.update_weights(recv_req) + elif isinstance(recv_req, UpdateWeightFromDiskReqInput): + success, message = self.update_weights_from_disk(recv_req) self.send_to_tokenizer.send_pyobj( - UpdateWeightReqOutput(success, message) + UpdateWeightFromDiskReqOutput(success, message) ) elif isinstance(recv_req, ProfileReq): if recv_req == ProfileReq.START_PROFILE: @@ -1363,9 +1363,9 @@ class Scheduler: req.to_abort = True break - def update_weights(self, recv_req: UpdateWeightReqInput): - """In-place update of the weights.""" - success, message = self.tp_worker.update_weights(recv_req) + def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): + """In-place update of the weights from disk.""" + success, message = self.tp_worker.update_weights_from_disk(recv_req) if success: flash_cache_success = self.flush_cache() assert flash_cache_success, "Cache flush failed after updating weights" diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 77bc91218..9c1c591dd 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -25,6 +25,7 @@ import uuid from typing import Dict, List, Optional, Tuple, Union import fastapi +import torch import uvloop import zmq import zmq.asyncio @@ -50,8 +51,8 @@ from sglang.srt.managers.io_struct import ( ProfileReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, - UpdateWeightReqInput, - UpdateWeightReqOutput, + UpdateWeightFromDiskReqInput, + UpdateWeightFromDiskReqOutput, ) from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams @@ -405,8 +406,10 @@ class TokenizerManager: req = ProfileReq.STOP_PROFILE 
self.send_to_scheduler.send_pyobj(req) - async def update_weights( - self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None + async def update_weights_from_disk( + self, + obj: UpdateWeightFromDiskReqInput, + request: Optional[fastapi.Request] = None, ): if self.to_create_loop: self.create_handle_loop() @@ -520,10 +523,13 @@ class TokenizerManager: while True: recv_obj: Union[ - BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput + BatchStrOut, + BatchEmbeddingOut, + BatchTokenIDOut, + UpdateWeightFromDiskReqOutput, ] = await self.recv_from_detokenizer.recv_pyobj() - if isinstance(recv_obj, UpdateWeightReqOutput): + if isinstance(recv_obj, UpdateWeightFromDiskReqOutput): if self.server_args.dp_size == 1: self.model_update_result.set_result(recv_obj) else: # self.server_args.dp_size > 1 diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index a5d694e77..bdbf58ba7 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -19,7 +19,7 @@ from typing import Optional from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.managers.io_struct import UpdateWeightReqInput +from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner @@ -155,8 +155,8 @@ class TpModelWorker: embeddings = logits_output.embeddings return embeddings - def update_weights(self, recv_req: UpdateWeightReqInput): - success, message = self.model_runner.update_weights( + def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): + success, message = self.model_runner.update_weights_from_disk( recv_req.model_path, recv_req.load_format ) return 
success, message diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index a5412094c..786656271 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -23,7 +23,7 @@ from typing import Optional import psutil import torch -from sglang.srt.managers.io_struct import UpdateWeightReqInput +from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.server_args import ServerArgs @@ -204,8 +204,8 @@ ) % self.future_token_ids_limit return None, future_next_token_ids - def update_weights(self, recv_req: UpdateWeightReqInput): - success, message = self.worker.update_weights(recv_req) + def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): + success, message = self.worker.update_weights_from_disk(recv_req) return success, message def __delete__(self): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 18bdf3edc..da311c7ec 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -20,10 +20,12 @@ import inspect import json import logging import pkgutil +import time from functools import lru_cache -from typing import Optional, Type +from typing import Any, Optional, Type, Union import torch +import torch.distributed as dist import torch.nn as nn from vllm.config import DeviceConfig, LoadConfig from vllm.config import ModelConfig as VllmModelConfig @@ -319,8 +322,8 @@ class ModelRunner: f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) - def update_weights(self, model_path: str, load_format: str): - """Update weights in-place.""" + def 
update_weights_from_disk(self, model_path: str, load_format: str): + """Update engine weights online from disk.""" from vllm.model_executor.model_loader.loader import ( DefaultModelLoader, device_loading_context, @@ -329,7 +332,7 @@ class ModelRunner: from vllm.model_executor.model_loader.utils import set_default_torch_dtype logger.info( - f"Update weights begin. " + f"Update engine weights online from disk begin. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 32523bb9d..7eec7cd1f 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -53,7 +53,7 @@ from sglang.srt.managers.io_struct import ( EmbeddingReqInput, GenerateReqInput, OpenSessionReqInput, - UpdateWeightReqInput, + UpdateWeightFromDiskReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -192,11 +192,11 @@ async def stop_profile_async(): ) -@app.post("/update_weights") +@app.post("/update_weights_from_disk") @time_func_latency -async def update_weights(obj: UpdateWeightReqInput, request: Request): - """Update the weights inplace without re-launching the server.""" - success, message = await tokenizer_manager.update_weights(obj, request) +async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): + """Update the weights from disk inplace without re-launching the server.""" + success, message = await tokenizer_manager.update_weights_from_disk(obj, request) content = {"success": success, "message": message} if success: return ORJSONResponse( diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 355294602..86eb6ff4e 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -424,6 +424,7 @@ def popen_launch_server( port, *other_args, ] + if api_key: command += ["--api-key", api_key] diff --git 
a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py index 22d000664..1998fee2f 100644 --- a/test/srt/test_data_parallelism.py +++ b/test/srt/test_data_parallelism.py @@ -44,7 +44,7 @@ class TestDataParallelism(unittest.TestCase): def test_update_weight(self): response = requests.post( - self.base_url + "/update_weights", + self.base_url + "/update_weights_from_disk", json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST}, ) @@ -55,7 +55,7 @@ class TestDataParallelism(unittest.TestCase): time.sleep(5) response = requests.post( - self.base_url + "/update_weights", + self.base_url + "/update_weights_from_disk", json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST}, ) diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py index ddb5a5e08..3b2dc0f6f 100644 --- a/test/srt/test_update_weights.py +++ b/test/srt/test_update_weights.py @@ -49,7 +49,7 @@ class TestUpdateWeights(unittest.TestCase): def run_update_weights(self, model_path): response = requests.post( - self.base_url + "/update_weights", + self.base_url + "/update_weights_from_disk", json={ "model_path": model_path, },