Add update_weights_from_tensor (#2631)
This commit is contained in:
@@ -21,6 +21,8 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
|
||||
@@ -407,6 +409,18 @@ class UpdateWeightsFromDistributedReqOutput:
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Request to update one model parameter from an in-memory tensor."""

    # Parameter name as it appears in the model's weight list
    # (passed straight to model.load_weights).
    name: str
    # New weight values for the parameter.
    tensor: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Scheduler's reply to an UpdateWeightsFromTensorReqInput."""

    # Whether the weight update (and the follow-up cache flush) succeeded.
    success: bool
    # Human-readable status or error description.
    message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitWeightsUpdateGroupReqInput:
|
||||
# The master address
|
||||
|
||||
@@ -22,7 +22,7 @@ import warnings
|
||||
from collections import deque
|
||||
from concurrent import futures
|
||||
from types import SimpleNamespace
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import psutil
|
||||
import setproctitle
|
||||
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
@@ -478,6 +480,11 @@ class Scheduler:
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightsFromDistributedReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
|
||||
success, message = self.update_weights_from_tensor(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightsFromTensorReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||
parameter = self.get_weights_by_name(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||
@@ -1458,6 +1465,17 @@ class Scheduler:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: "UpdateWeightsFromTensorReqInput"):
    """Update the online model parameter from tensors."""
    # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
    success, message = self.tp_worker.update_weights_from_tensor(recv_req)
    if not success:
        # Surface the worker-reported failure; the caller still gets the tuple.
        logger.error(message)
        return success, message
    # The cached KV entries were computed with the old weights, so drop them.
    cache_flushed = self.flush_cache()
    assert cache_flushed, "Cache flush failed after updating weights"
    return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: "GetWeightsByNameReqInput"):
    """Fetch a model parameter by name; the TP worker does the lookup."""
    return self.tp_worker.get_weights_by_name(recv_req)
|
||||
|
||||
@@ -59,6 +59,8 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
)
|
||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -179,6 +181,9 @@ class TokenizerManager:
|
||||
self.update_weights_from_distributed_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.update_weights_from_tensor_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.get_weights_by_name_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
@@ -515,6 +520,22 @@ class TokenizerManager:
|
||||
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
async def update_weights_from_tensor(
    self,
    obj: "UpdateWeightsFromTensorReqInput",
    request: Optional["fastapi.Request"] = None,
) -> Tuple[bool, str]:
    """Update one model parameter from an in-memory tensor.

    Forwards the request to the scheduler(s) and waits for the reply.

    Returns:
        (success, message) as reported by the scheduler.
    """
    self.auto_create_handle_loop()
    # Fixed assert message: it previously read "dp_size must be for update
    # weights from distributed" (missing the "1", copy-pasted from the
    # distributed variant).
    assert (
        self.server_args.dp_size == 1
    ), "dp_size must be 1 for update weights from tensor"

    # Hold the writer lock so that a weight sync
    # cannot run while requests are in progress.
    async with self.model_update_lock.writer_lock:
        result = (await self.update_weights_from_tensor_communicator(obj))[0]
        return result.success, result.message
|
||||
|
||||
async def get_weights_by_name(
|
||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
@@ -708,6 +729,11 @@ class TokenizerManager:
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for update weights from distributed"
|
||||
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
|
||||
elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for update weights from distributed"
|
||||
self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
|
||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||
self.get_weights_by_name_communicator.handle_recv(recv_obj)
|
||||
else:
|
||||
|
||||
@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -188,6 +189,12 @@ class TpModelWorker:
|
||||
)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: "UpdateWeightsFromTensorReqInput"):
    """Unpack the request and delegate the weight update to the model runner."""
    return self.model_runner.update_weights_from_tensor(
        recv_req.name, recv_req.tensor
    )
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.model_runner.get_weights_by_name(
|
||||
recv_req.name, recv_req.truncate_size
|
||||
|
||||
@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
@@ -225,6 +226,10 @@ class TpModelWorkerClient:
|
||||
success, message = self.worker.update_weights_from_distributed(recv_req)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: "UpdateWeightsFromTensorReqInput"):
    """Forward the tensor-based weight update to the wrapped worker."""
    return self.worker.update_weights_from_tensor(recv_req)
|
||||
|
||||
def get_weights_by_name(self, recv_req: "GetWeightsByNameReqInput"):
    """Look up a parameter by name via the wrapped worker."""
    parameter = self.worker.get_weights_by_name(recv_req)
    return parameter
|
||||
|
||||
|
||||
@@ -429,6 +429,10 @@ class ModelRunner:
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
    """Load `tensor` into the model parameter called `name`.

    Resolves the previous `# TODO error handling`: a failing
    `load_weights` call (bad name, shape mismatch, ...) is now reported
    through the (success, message) contract used by the sibling
    update-weights methods instead of propagating an exception.

    Returns:
        (True, "Success") on success, otherwise (False, <error description>).
    """
    try:
        self.model.load_weights([(name, tensor)])
    except Exception as e:
        return False, f"Failed to update weights from tensor: {e}"
    return True, "Success"
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
|
||||
OpenSessionReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
@@ -109,6 +110,7 @@ app.add_middleware(
|
||||
tokenizer_manager: TokenizerManager = None
|
||||
scheduler_info: Dict = None
|
||||
|
||||
|
||||
##### Native API endpoints #####
|
||||
|
||||
|
||||
@@ -866,6 +868,14 @@ class Engine:
|
||||
tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
)
|
||||
|
||||
def update_weights_from_tensor(self, name, tensor):
    """Update a named model weight from an in-memory tensor.

    (Docstring fixed: it previously said "from distributed source",
    copy-pasted from the sibling method — this entry point takes the
    tensor directly.)

    Args:
        name: Parameter name as known to the model.
        tensor: New weight values for that parameter.

    Returns:
        (success, message) as reported by the scheduler.
    """
    obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
    # Mirror update_weights_from_distributed: drive the tokenizer
    # manager's async API from this synchronous entry point.
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(
        tokenizer_manager.update_weights_from_tensor(obj, None)
    )
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
"""Get weights by parameter name."""
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
|
||||
Reference in New Issue
Block a user