Update weights from disk (#2265)
This commit is contained in:
@@ -82,7 +82,8 @@
|
|||||||
"Get the information of the model.\n",
|
"Get the information of the model.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"- `model_path`: The path/name of the model.\n",
|
"- `model_path`: The path/name of the model.\n",
|
||||||
"- `is_generation`: Whether the model is used as generation model or embedding model."
|
"- `is_generation`: Whether the model is used as a generation model or an embedding model.\n",
|
||||||
|
"- `tokenizer_path`: The path/name of the tokenizer."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -98,7 +99,8 @@
|
|||||||
"print_highlight(response_json)\n",
|
"print_highlight(response_json)\n",
|
||||||
"assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
|
"assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
|
||||||
"assert response_json[\"is_generation\"] is True\n",
|
"assert response_json[\"is_generation\"] is True\n",
|
||||||
"assert response_json.keys() == {\"model_path\", \"is_generation\"}"
|
"assert response_json[\"tokenizer_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
|
||||||
|
"assert response_json.keys() == {\"model_path\", \"is_generation\", \"tokenizer_path\"}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -187,9 +189,11 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Update Weights\n",
|
"## Update Weights From Disk\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Update model weights without restarting the server. Use for continuous evaluation during training. Only applicable for models with the same architecture and parameter size."
|
"Update model weights from disk without restarting the server. Only applicable for models with the same architecture and parameter size.\n",
|
||||||
|
"\n",
|
||||||
|
"SGLang supports the `update_weights_from_disk` API for continuous evaluation during training (save a checkpoint to disk, then update the weights from disk).\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -200,7 +204,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# successful update with same architecture and size\n",
|
"# successful update with same architecture and size\n",
|
||||||
"\n",
|
"\n",
|
||||||
"url = \"http://localhost:30010/update_weights\"\n",
|
"url = \"http://localhost:30010/update_weights_from_disk\"\n",
|
||||||
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n",
|
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"response = requests.post(url, json=data)\n",
|
"response = requests.post(url, json=data)\n",
|
||||||
@@ -218,7 +222,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# failed update with different parameter size\n",
|
"# failed update with different parameter size\n",
|
||||||
"\n",
|
"\n",
|
||||||
"url = \"http://localhost:30010/update_weights\"\n",
|
"url = \"http://localhost:30010/update_weights_from_disk\"\n",
|
||||||
"data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n",
|
"data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"response = requests.post(url, json=data)\n",
|
"response = requests.post(url, json=data)\n",
|
||||||
|
|||||||
@@ -352,7 +352,7 @@ class FlushCacheReq:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightReqInput:
|
class UpdateWeightFromDiskReqInput:
|
||||||
# The model path with the new weights
|
# The model path with the new weights
|
||||||
model_path: str
|
model_path: str
|
||||||
# The format to load the weights
|
# The format to load the weights
|
||||||
@@ -360,7 +360,7 @@ class UpdateWeightReqInput:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightReqOutput:
|
class UpdateWeightFromDiskReqOutput:
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|||||||
@@ -43,8 +43,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ProfileReq,
|
ProfileReq,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
UpdateWeightReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
FINISH_ABORT,
|
FINISH_ABORT,
|
||||||
@@ -506,10 +506,10 @@ class Scheduler:
|
|||||||
self.flush_cache()
|
self.flush_cache()
|
||||||
elif isinstance(recv_req, AbortReq):
|
elif isinstance(recv_req, AbortReq):
|
||||||
self.abort_request(recv_req)
|
self.abort_request(recv_req)
|
||||||
elif isinstance(recv_req, UpdateWeightReqInput):
|
elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
|
||||||
success, message = self.update_weights(recv_req)
|
success, message = self.update_weights_from_disk(recv_req)
|
||||||
self.send_to_tokenizer.send_pyobj(
|
self.send_to_tokenizer.send_pyobj(
|
||||||
UpdateWeightReqOutput(success, message)
|
UpdateWeightFromDiskReqOutput(success, message)
|
||||||
)
|
)
|
||||||
elif isinstance(recv_req, ProfileReq):
|
elif isinstance(recv_req, ProfileReq):
|
||||||
if recv_req == ProfileReq.START_PROFILE:
|
if recv_req == ProfileReq.START_PROFILE:
|
||||||
@@ -1363,9 +1363,9 @@ class Scheduler:
|
|||||||
req.to_abort = True
|
req.to_abort = True
|
||||||
break
|
break
|
||||||
|
|
||||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||||
"""In-place update of the weights."""
|
"""In-place update of the weights from disk."""
|
||||||
success, message = self.tp_worker.update_weights(recv_req)
|
success, message = self.tp_worker.update_weights_from_disk(recv_req)
|
||||||
if success:
|
if success:
|
||||||
flash_cache_success = self.flush_cache()
|
flash_cache_success = self.flush_cache()
|
||||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import uuid
|
|||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
|
import torch
|
||||||
import uvloop
|
import uvloop
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
@@ -50,8 +51,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ProfileReq,
|
ProfileReq,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
UpdateWeightReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
@@ -405,8 +406,10 @@ class TokenizerManager:
|
|||||||
req = ProfileReq.STOP_PROFILE
|
req = ProfileReq.STOP_PROFILE
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
|
||||||
async def update_weights(
|
async def update_weights_from_disk(
|
||||||
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
self,
|
||||||
|
obj: UpdateWeightFromDiskReqInput,
|
||||||
|
request: Optional[fastapi.Request] = None,
|
||||||
):
|
):
|
||||||
if self.to_create_loop:
|
if self.to_create_loop:
|
||||||
self.create_handle_loop()
|
self.create_handle_loop()
|
||||||
@@ -520,10 +523,13 @@ class TokenizerManager:
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
recv_obj: Union[
|
recv_obj: Union[
|
||||||
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
|
BatchStrOut,
|
||||||
|
BatchEmbeddingOut,
|
||||||
|
BatchTokenIDOut,
|
||||||
|
UpdateWeightFromDiskReqOutput,
|
||||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||||
|
|
||||||
if isinstance(recv_obj, UpdateWeightReqOutput):
|
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||||
if self.server_args.dp_size == 1:
|
if self.server_args.dp_size == 1:
|
||||||
self.model_update_result.set_result(recv_obj)
|
self.model_update_result.set_result(recv_obj)
|
||||||
else: # self.server_args.dp_size > 1
|
else: # self.server_args.dp_size > 1
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
@@ -155,8 +155,8 @@ class TpModelWorker:
|
|||||||
embeddings = logits_output.embeddings
|
embeddings = logits_output.embeddings
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||||
success, message = self.model_runner.update_weights(
|
success, message = self.model_runner.update_weights_from_disk(
|
||||||
recv_req.model_path, recv_req.load_format
|
recv_req.model_path, recv_req.load_format
|
||||||
)
|
)
|
||||||
return success, message
|
return success, message
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from typing import Optional
|
|||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -204,8 +204,8 @@ class TpModelWorkerClient:
|
|||||||
) % self.future_token_ids_limit
|
) % self.future_token_ids_limit
|
||||||
return None, future_next_token_ids
|
return None, future_next_token_ids
|
||||||
|
|
||||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||||
success, message = self.worker.update_weights(recv_req)
|
success, message = self.worker.update_weights_from_disk(recv_req)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
def __delete__(self):
|
def __delete__(self):
|
||||||
|
|||||||
@@ -20,10 +20,13 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
import time
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Optional, Type
|
from tokenize import tabsize
|
||||||
|
from typing import Any, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm.config import DeviceConfig, LoadConfig
|
from vllm.config import DeviceConfig, LoadConfig
|
||||||
from vllm.config import ModelConfig as VllmModelConfig
|
from vllm.config import ModelConfig as VllmModelConfig
|
||||||
@@ -319,8 +322,8 @@ class ModelRunner:
|
|||||||
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_weights(self, model_path: str, load_format: str):
|
def update_weights_from_disk(self, model_path: str, load_format: str):
|
||||||
"""Update weights in-place."""
|
"""Update engine weights online from disk."""
|
||||||
from vllm.model_executor.model_loader.loader import (
|
from vllm.model_executor.model_loader.loader import (
|
||||||
DefaultModelLoader,
|
DefaultModelLoader,
|
||||||
device_loading_context,
|
device_loading_context,
|
||||||
@@ -329,7 +332,7 @@ class ModelRunner:
|
|||||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Update weights begin. "
|
f"Update engine weights online from disk begin. "
|
||||||
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
UpdateWeightReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
@@ -192,11 +192,11 @@ async def stop_profile_async():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/update_weights")
|
@app.post("/update_weights_from_disk")
|
||||||
@time_func_latency
|
@time_func_latency
|
||||||
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
||||||
"""Update the weights inplace without re-launching the server."""
|
"""Update the weights from disk inplace without re-launching the server."""
|
||||||
success, message = await tokenizer_manager.update_weights(obj, request)
|
success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
|
||||||
content = {"success": success, "message": message}
|
content = {"success": success, "message": message}
|
||||||
if success:
|
if success:
|
||||||
return ORJSONResponse(
|
return ORJSONResponse(
|
||||||
|
|||||||
@@ -424,6 +424,7 @@ def popen_launch_server(
|
|||||||
port,
|
port,
|
||||||
*other_args,
|
*other_args,
|
||||||
]
|
]
|
||||||
|
|
||||||
if api_key:
|
if api_key:
|
||||||
command += ["--api-key", api_key]
|
command += ["--api-key", api_key]
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class TestDataParallelism(unittest.TestCase):
|
|||||||
|
|
||||||
def test_update_weight(self):
|
def test_update_weight(self):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url + "/update_weights",
|
self.base_url + "/update_weights_from_disk",
|
||||||
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
|
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ class TestDataParallelism(unittest.TestCase):
|
|||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url + "/update_weights",
|
self.base_url + "/update_weights_from_disk",
|
||||||
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
|
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ class TestUpdateWeights(unittest.TestCase):
|
|||||||
|
|
||||||
def run_update_weights(self, model_path):
|
def run_update_weights(self, model_path):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url + "/update_weights",
|
self.base_url + "/update_weights_from_disk",
|
||||||
json={
|
json={
|
||||||
"model_path": model_path,
|
"model_path": model_path,
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user