udate weights from disk (#2265)

This commit is contained in:
Chayenne
2024-11-29 17:17:00 -08:00
committed by GitHub
parent b53d6cbda3
commit 7d5d1d3d29
11 changed files with 54 additions and 40 deletions

View File

@@ -25,6 +25,7 @@ import uuid
from typing import Dict, List, Optional, Tuple, Union
import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio
@@ -50,8 +51,8 @@ from sglang.srt.managers.io_struct import (
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -405,8 +406,10 @@ class TokenizerManager:
req = ProfileReq.STOP_PROFILE
self.send_to_scheduler.send_pyobj(req)
async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
async def update_weights_from_disk(
self,
obj: UpdateWeightFromDiskReqInput,
request: Optional[fastapi.Request] = None,
):
if self.to_create_loop:
self.create_handle_loop()
@@ -520,10 +523,13 @@ class TokenizerManager:
while True:
recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
BatchStrOut,
BatchEmbeddingOut,
BatchTokenIDOut,
UpdateWeightFromDiskReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightReqOutput):
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1