diff --git a/.gitignore b/.gitignore index 6d0987f27..73fd52992 100644 --- a/.gitignore +++ b/.gitignore @@ -220,3 +220,5 @@ work_dirs/ *.app compile_commands.json + +*.iml diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index c5884b5f0..202112cc8 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -21,6 +21,8 @@ from dataclasses import dataclass from enum import Enum from typing import Dict, List, Optional, Tuple, Union +import torch + from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling.sampling_params import SamplingParams @@ -407,6 +409,18 @@ class UpdateWeightsFromDistributedReqOutput: message: str +@dataclass +class UpdateWeightsFromTensorReqInput: + name: str + tensor: torch.Tensor + + +@dataclass +class UpdateWeightsFromTensorReqOutput: + success: bool + message: str + + @dataclass class InitWeightsUpdateGroupReqInput: # The master address diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 7feaaedb8..c70c61e4c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -22,7 +22,7 @@ import warnings from collections import deque from concurrent import futures from types import SimpleNamespace -from typing import Callable, Dict, List, Optional, Tuple +from typing import Dict, List, Optional import psutil import setproctitle @@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import ( UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, + UpdateWeightsFromTensorReqInput, + UpdateWeightsFromTensorReqOutput, ) from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, @@ -478,6 +480,11 @@ class Scheduler: self.send_to_tokenizer.send_pyobj( UpdateWeightsFromDistributedReqOutput(success, message) ) + elif isinstance(recv_req, UpdateWeightsFromTensorReqInput): + success, message = 
self.update_weights_from_tensor(recv_req) + self.send_to_tokenizer.send_pyobj( + UpdateWeightsFromTensorReqOutput(success, message) + ) elif isinstance(recv_req, GetWeightsByNameReqInput): parameter = self.get_weights_by_name(recv_req) self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter)) @@ -1458,6 +1465,17 @@ class Scheduler: logger.error(message) return success, message + def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): + """Update the online model parameter from tensors.""" + success, message = self.tp_worker.update_weights_from_tensor(recv_req) + # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later + if success: + flash_cache_success = self.flush_cache() + assert flash_cache_success, "Cache flush failed after updating weights" + else: + logger.error(message) + return success, message + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) return parameter diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 020e96e65..b98bd09fc 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -59,6 +59,8 @@ from sglang.srt.managers.io_struct import ( UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, + UpdateWeightsFromTensorReqInput, + UpdateWeightsFromTensorReqOutput, ) from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams @@ -179,6 +181,9 @@ class TokenizerManager: self.update_weights_from_distributed_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.update_weights_from_tensor_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.get_weights_by_name_communicator = _Communicator( 
self.send_to_scheduler, server_args.dp_size ) @@ -515,6 +520,22 @@ class TokenizerManager: result = (await self.update_weights_from_distributed_communicator(obj))[0] return result.success, result.message + async def update_weights_from_tensor( + self, + obj: UpdateWeightsFromTensorReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for update weights from tensor" + + # This means that weight sync + # cannot run while requests are in progress. + async with self.model_update_lock.writer_lock: + result = (await self.update_weights_from_tensor_communicator(obj))[0] + return result.success, result.message + async def get_weights_by_name( self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None ): @@ -708,6 +729,11 @@ class TokenizerManager: self.server_args.dp_size == 1 ), "dp_size must be 1 for update weights from distributed" self.update_weights_from_distributed_communicator.handle_recv(recv_obj) + elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput): + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for update weights from tensor" + self.update_weights_from_tensor_communicator.handle_recv(recv_obj) elif isinstance(recv_obj, GetWeightsByNameReqOutput): self.get_weights_by_name_communicator.handle_recv(recv_obj) else: diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 3aa06b4b8..ce284c04a 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import ( InitWeightsUpdateGroupReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@
-188,6 +189,12 @@ class TpModelWorker: ) return success, message + def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): + success, message = self.model_runner.update_weights_from_tensor( + recv_req.name, recv_req.tensor + ) + return success, message + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.model_runner.get_weights_by_name( recv_req.name, recv_req.truncate_size diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index a9db18783..4600bf99a 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import ( InitWeightsUpdateGroupReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.tp_worker import TpModelWorker @@ -225,6 +226,10 @@ class TpModelWorkerClient: success, message = self.worker.update_weights_from_distributed(recv_req) return success, message + def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): + success, message = self.worker.update_weights_from_tensor(recv_req) + return success, message + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): return self.worker.get_weights_by_name(recv_req) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 2612f8840..7cb7d5da7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -429,6 +429,10 @@ class ModelRunner: logger.error(error_msg) return False, error_msg + def update_weights_from_tensor(self, name, tensor: torch.Tensor): + self.model.load_weights([(name, tensor)]) + return True, "Success" # TODO error handling + def 
get_weights_by_name( self, name: str, truncate_size: int = 100 ) -> Optional[torch.Tensor]: diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 0b51a0636..92df9b8bf 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import ( OpenSessionReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -109,6 +110,7 @@ app.add_middleware( tokenizer_manager: TokenizerManager = None scheduler_info: Dict = None + ##### Native API endpoints ##### @@ -866,6 +868,14 @@ class Engine: tokenizer_manager.update_weights_from_distributed(obj, None) ) + def update_weights_from_tensor(self, name, tensor): + """Update weights from an in-memory tensor.""" + obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + tokenizer_manager.update_weights_from_tensor(obj, None) + ) + def get_weights_by_name(self, name, truncate_size=100): """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index df0d41476..b48ee7b23 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -40,6 +40,7 @@ suites = { "test_triton_attention_kernels.py", "test_triton_attention_backend.py", "test_update_weights_from_disk.py", + "test_update_weights_from_tensor.py", "test_vision_chunked_prefill.py", "test_vision_openai_server.py", "test_session_control.py", diff --git a/test/srt/test_update_weights_from_tensor.py b/test/srt/test_update_weights_from_tensor.py new file mode 100644 index 000000000..7cca98a0f --- /dev/null +++ b/test/srt/test_update_weights_from_tensor.py @@ -0,0 +1,32 @@ +import unittest + +import torch + +import sglang as sgl
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + + +class TestUpdateWeightsFromTensor(unittest.TestCase): + def test_update_weights_from_tensor(self): + engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST) + + param_name = "model.layers.2.self_attn.k_proj.weight" + + def _check_param(expect_values): + actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5] + assert torch.allclose( + actual_values, torch.tensor(expect_values), atol=0.001 + ), f"{actual_values=}" + + _check_param([0.0571, -0.0114, 0.0444, 0.0215, -0.0149]) + + new_tensor = torch.full((3072, 2048), 1.5) + engine.update_weights_from_tensor(param_name, new_tensor) + + _check_param([1.5] * 5) + + engine.shutdown() + + +if __name__ == "__main__": + unittest.main()