udate weights from disk (#2265)
This commit is contained in:
@@ -352,7 +352,7 @@ class FlushCacheReq:
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightReqInput:
|
||||
class UpdateWeightFromDiskReqInput:
|
||||
# The model path with the new weights
|
||||
model_path: str
|
||||
# The format to load the weights
|
||||
@@ -360,7 +360,7 @@ class UpdateWeightReqInput:
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightReqOutput:
|
||||
class UpdateWeightFromDiskReqOutput:
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@@ -43,8 +43,8 @@ from sglang.srt.managers.io_struct import (
|
||||
ProfileReq,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
UpdateWeightReqInput,
|
||||
UpdateWeightReqOutput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
@@ -506,10 +506,10 @@ class Scheduler:
|
||||
self.flush_cache()
|
||||
elif isinstance(recv_req, AbortReq):
|
||||
self.abort_request(recv_req)
|
||||
elif isinstance(recv_req, UpdateWeightReqInput):
|
||||
success, message = self.update_weights(recv_req)
|
||||
elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
|
||||
success, message = self.update_weights_from_disk(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightReqOutput(success, message)
|
||||
UpdateWeightFromDiskReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
@@ -1363,9 +1363,9 @@ class Scheduler:
|
||||
req.to_abort = True
|
||||
break
|
||||
|
||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||
"""In-place update of the weights."""
|
||||
success, message = self.tp_worker.update_weights(recv_req)
|
||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||
"""In-place update of the weights from disk."""
|
||||
success, message = self.tp_worker.update_weights_from_disk(recv_req)
|
||||
if success:
|
||||
flash_cache_success = self.flush_cache()
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
|
||||
@@ -25,6 +25,7 @@ import uuid
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
import torch
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@@ -50,8 +51,8 @@ from sglang.srt.managers.io_struct import (
|
||||
ProfileReq,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
UpdateWeightReqInput,
|
||||
UpdateWeightReqOutput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
)
|
||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -405,8 +406,10 @@ class TokenizerManager:
|
||||
req = ProfileReq.STOP_PROFILE
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
async def update_weights(
|
||||
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
||||
async def update_weights_from_disk(
|
||||
self,
|
||||
obj: UpdateWeightFromDiskReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
@@ -520,10 +523,13 @@ class TokenizerManager:
|
||||
|
||||
while True:
|
||||
recv_obj: Union[
|
||||
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
|
||||
BatchStrOut,
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||
|
||||
if isinstance(recv_obj, UpdateWeightReqOutput):
|
||||
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.model_update_result.set_result(recv_obj)
|
||||
else: # self.server_args.dp_size > 1
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import Optional
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
||||
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
@@ -155,8 +155,8 @@ class TpModelWorker:
|
||||
embeddings = logits_output.embeddings
|
||||
return embeddings
|
||||
|
||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||
success, message = self.model_runner.update_weights(
|
||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||
success, message = self.model_runner.update_weights_from_disk(
|
||||
recv_req.model_path, recv_req.load_format
|
||||
)
|
||||
return success, message
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import Optional
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
||||
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -204,8 +204,8 @@ class TpModelWorkerClient:
|
||||
) % self.future_token_ids_limit
|
||||
return None, future_next_token_ids
|
||||
|
||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||
success, message = self.worker.update_weights(recv_req)
|
||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||
success, message = self.worker.update_weights_from_disk(recv_req)
|
||||
return success, message
|
||||
|
||||
def __delete__(self):
|
||||
|
||||
@@ -20,10 +20,13 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import pkgutil
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Type
|
||||
from tokenize import tabsize
|
||||
from typing import Any, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from vllm.config import DeviceConfig, LoadConfig
|
||||
from vllm.config import ModelConfig as VllmModelConfig
|
||||
@@ -319,8 +322,8 @@ class ModelRunner:
|
||||
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
|
||||
def update_weights(self, model_path: str, load_format: str):
|
||||
"""Update weights in-place."""
|
||||
def update_weights_from_disk(self, model_path: str, load_format: str):
|
||||
"""Update engine weights online from disk."""
|
||||
from vllm.model_executor.model_loader.loader import (
|
||||
DefaultModelLoader,
|
||||
device_loading_context,
|
||||
@@ -329,7 +332,7 @@ class ModelRunner:
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
|
||||
logger.info(
|
||||
f"Update weights begin. "
|
||||
f"Update engine weights online from disk begin. "
|
||||
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
OpenSessionReqInput,
|
||||
UpdateWeightReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
@@ -192,11 +192,11 @@ async def stop_profile_async():
|
||||
)
|
||||
|
||||
|
||||
@app.post("/update_weights")
|
||||
@app.post("/update_weights_from_disk")
|
||||
@time_func_latency
|
||||
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
||||
"""Update the weights inplace without re-launching the server."""
|
||||
success, message = await tokenizer_manager.update_weights(obj, request)
|
||||
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
||||
"""Update the weights from disk inplace without re-launching the server."""
|
||||
success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
|
||||
content = {"success": success, "message": message}
|
||||
if success:
|
||||
return ORJSONResponse(
|
||||
|
||||
@@ -424,6 +424,7 @@ def popen_launch_server(
|
||||
port,
|
||||
*other_args,
|
||||
]
|
||||
|
||||
if api_key:
|
||||
command += ["--api-key", api_key]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user