update weights from disk (#2265)

This commit is contained in:
Chayenne
2024-11-29 17:17:00 -08:00
committed by GitHub
parent b53d6cbda3
commit 7d5d1d3d29
11 changed files with 54 additions and 40 deletions

View File

@@ -352,7 +352,7 @@ class FlushCacheReq:
@dataclass
class UpdateWeightReqInput:
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
# The format to load the weights
@@ -360,7 +360,7 @@ class UpdateWeightReqInput:
@dataclass
class UpdateWeightReqOutput:
class UpdateWeightFromDiskReqOutput:
success: bool
message: str

View File

@@ -43,8 +43,8 @@ from sglang.srt.managers.io_struct import (
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
@@ -506,10 +506,10 @@ class Scheduler:
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req)
elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
success, message = self.update_weights_from_disk(recv_req)
self.send_to_tokenizer.send_pyobj(
UpdateWeightReqOutput(success, message)
UpdateWeightFromDiskReqOutput(success, message)
)
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
@@ -1363,9 +1363,9 @@ class Scheduler:
req.to_abort = True
break
def update_weights(self, recv_req: UpdateWeightReqInput):
"""In-place update of the weights."""
success, message = self.tp_worker.update_weights(recv_req)
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"

View File

@@ -25,6 +25,7 @@ import uuid
from typing import Dict, List, Optional, Tuple, Union
import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio
@@ -50,8 +51,8 @@ from sglang.srt.managers.io_struct import (
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -405,8 +406,10 @@ class TokenizerManager:
req = ProfileReq.STOP_PROFILE
self.send_to_scheduler.send_pyobj(req)
async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
async def update_weights_from_disk(
self,
obj: UpdateWeightFromDiskReqInput,
request: Optional[fastapi.Request] = None,
):
if self.to_create_loop:
self.create_handle_loop()
@@ -520,10 +523,13 @@ class TokenizerManager:
while True:
recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
BatchStrOut,
BatchEmbeddingOut,
BatchTokenIDOut,
UpdateWeightFromDiskReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightReqOutput):
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1

View File

@@ -19,7 +19,7 @@ from typing import Optional
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -155,8 +155,8 @@ class TpModelWorker:
embeddings = logits_output.embeddings
return embeddings
def update_weights(self, recv_req: UpdateWeightReqInput):
success, message = self.model_runner.update_weights(
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.model_runner.update_weights_from_disk(
recv_req.model_path, recv_req.load_format
)
return success, message

View File

@@ -23,7 +23,7 @@ from typing import Optional
import psutil
import torch
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
@@ -204,8 +204,8 @@ class TpModelWorkerClient:
) % self.future_token_ids_limit
return None, future_next_token_ids
def update_weights(self, recv_req: UpdateWeightReqInput):
success, message = self.worker.update_weights(recv_req)
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.worker.update_weights_from_disk(recv_req)
return success, message
def __delete__(self):

View File

@@ -20,10 +20,13 @@ import inspect
import json
import logging
import pkgutil
import time
from functools import lru_cache
from typing import Optional, Type
from tokenize import tabsize
from typing import Any, Optional, Type, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
@@ -319,8 +322,8 @@ class ModelRunner:
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
def update_weights(self, model_path: str, load_format: str):
"""Update weights in-place."""
def update_weights_from_disk(self, model_path: str, load_format: str):
"""Update engine weights online from disk."""
from vllm.model_executor.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
@@ -329,7 +332,7 @@ class ModelRunner:
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
logger.info(
f"Update weights begin. "
f"Update engine weights online from disk begin. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)

View File

@@ -53,7 +53,7 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
OpenSessionReqInput,
UpdateWeightReqInput,
UpdateWeightFromDiskReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -192,11 +192,11 @@ async def stop_profile_async():
)
@app.post("/update_weights")
@app.post("/update_weights_from_disk")
@time_func_latency
async def update_weights(obj: UpdateWeightReqInput, request: Request):
"""Update the weights inplace without re-launching the server."""
success, message = await tokenizer_manager.update_weights(obj, request)
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk inplace without re-launching the server."""
success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(

View File

@@ -424,6 +424,7 @@ def popen_launch_server(
port,
*other_args,
]
if api_key:
command += ["--api-key", api_key]