Add update_weights_from_tensor (#2631)

This commit is contained in:
fzyzcjy
2024-12-29 05:30:27 +08:00
committed by GitHub
parent 7863e4368a
commit fd28640dc5
10 changed files with 120 additions and 1 deletions

View File

@@ -21,6 +21,8 @@ from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
import torch
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -407,6 +409,18 @@ class UpdateWeightsFromDistributedReqOutput:
message: str
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Request to update one model parameter from an in-memory tensor."""

    # Parameter name, passed through to the model's load_weights.
    name: str
    # New weight values for the named parameter.
    tensor: torch.Tensor
@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result of a tensor-based weight update."""

    # Whether the update (and subsequent cache flush) succeeded.
    success: bool
    # Human-readable status / error description.
    message: str
@dataclass
class InitWeightsUpdateGroupReqInput:
# The master address

View File

@@ -22,7 +22,7 @@ import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import Callable, Dict, List, Optional, Tuple
from typing import Dict, List, Optional
import psutil
import setproctitle
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
@@ -478,6 +480,11 @@ class Scheduler:
self.send_to_tokenizer.send_pyobj(
UpdateWeightsFromDistributedReqOutput(success, message)
)
elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
success, message = self.update_weights_from_tensor(recv_req)
self.send_to_tokenizer.send_pyobj(
UpdateWeightsFromTensorReqOutput(success, message)
)
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
@@ -1458,6 +1465,17 @@ class Scheduler:
logger.error(message)
return success, message
def update_weights_from_tensor(self, recv_req: "UpdateWeightsFromTensorReqInput"):
    """Update the online model parameters from in-memory tensors.

    Delegates the actual load to the TP worker, then flushes the cache on
    success so results computed with the old weights are not reused.

    Returns:
        (success, message) tuple describing the outcome.
    """
    success, message = self.tp_worker.update_weights_from_tensor(recv_req)
    # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
    if success:
        # Fixed local-name typo: was "flash_cache_success".
        flush_cache_success = self.flush_cache()
        assert flush_cache_success, "Cache flush failed after updating weights"
    else:
        logger.error(message)
    return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Fetch a model parameter by name via the TP worker."""
    return self.tp_worker.get_weights_by_name(recv_req)

View File

@@ -59,6 +59,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -179,6 +181,9 @@ class TokenizerManager:
self.update_weights_from_distributed_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_weights_from_tensor_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -515,6 +520,22 @@ class TokenizerManager:
result = (await self.update_weights_from_distributed_communicator(obj))[0]
return result.success, result.message
async def update_weights_from_tensor(
    self,
    obj: "UpdateWeightsFromTensorReqInput",
    request: Optional["fastapi.Request"] = None,
) -> Tuple[bool, str]:
    """Send a tensor-based weight update to the scheduler and await the result.

    Returns:
        (success, message) tuple from the scheduler's reply.
    """
    self.auto_create_handle_loop()
    # Fixed assertion message: was the garbled
    # "dp_size must be for update weights from distributed".
    assert (
        self.server_args.dp_size == 1
    ), "dp_size must be 1 for update weights from tensor"
    # Hold the writer lock so weight sync cannot run while requests are in progress.
    async with self.model_update_lock.writer_lock:
        result = (await self.update_weights_from_tensor_communicator(obj))[0]
        return result.success, result.message
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
@@ -708,6 +729,11 @@ class TokenizerManager:
self.server_args.dp_size == 1
), "dp_size must be 1 for update weights from distributed"
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for update weights from distributed"
self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
self.get_weights_by_name_communicator.handle_recv(recv_obj)
else:

View File

@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -188,6 +189,12 @@ class TpModelWorker:
)
return success, message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Forward a tensor-based weight update to the model runner."""
    return self.model_runner.update_weights_from_tensor(
        recv_req.name, recv_req.tensor
    )
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size

View File

@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
@@ -225,6 +226,10 @@ class TpModelWorkerClient:
success, message = self.worker.update_weights_from_distributed(recv_req)
return success, message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Delegate a tensor-based weight update to the wrapped worker."""
    return self.worker.update_weights_from_tensor(recv_req)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Delegate a parameter lookup to the wrapped worker."""
    parameter = self.worker.get_weights_by_name(recv_req)
    return parameter

View File

@@ -429,6 +429,10 @@ class ModelRunner:
logger.error(error_msg)
return False, error_msg
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
    """Load a single (name, tensor) pair into the model's weights.

    Resolves the previous "TODO error handling": a failed load now reports
    (False, message) instead of propagating, matching the (success, message)
    contract of update_weights_from_distributed.

    Returns:
        (success, message) tuple.
    """
    try:
        self.model.load_weights([(name, tensor)])
    except Exception as e:
        return False, f"Failed to update weights from tensor: {e}"
    return True, "Success"
def get_weights_by_name(
self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:

View File

@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -109,6 +110,7 @@ app.add_middleware(
tokenizer_manager: TokenizerManager = None
scheduler_info: Dict = None
##### Native API endpoints #####
@@ -866,6 +868,14 @@ class Engine:
tokenizer_manager.update_weights_from_distributed(obj, None)
)
def update_weights_from_tensor(self, name, tensor):
    """Update a single model weight from an in-memory tensor.

    Fixed docstring: was the copy-pasted "Update weights from distributed
    source." Blocks on the tokenizer manager's async update and returns its
    (success, message) result.
    """
    obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(
        tokenizer_manager.update_weights_from_tensor(obj, None)
    )
def get_weights_by_name(self, name, truncate_size=100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)