Add update_weights_from_tensor (#2631)
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -220,3 +220,5 @@ work_dirs/
|
|||||||
*.app
|
*.app
|
||||||
|
|
||||||
compile_commands.json
|
compile_commands.json
|
||||||
|
|
||||||
|
*.iml
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ from dataclasses import dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
|
|
||||||
@@ -407,6 +409,18 @@ class UpdateWeightsFromDistributedReqOutput:
|
|||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Request to replace one named model weight with an in-memory tensor."""

    # Name of the parameter to update.
    name: str
    # New value for that parameter.
    tensor: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result sent back after an update-weights-from-tensor request."""

    # Whether the weight update succeeded.
    success: bool
    # Message accompanying the result.
    message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InitWeightsUpdateGroupReqInput:
|
class InitWeightsUpdateGroupReqInput:
|
||||||
# The master address
|
# The master address
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import warnings
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import setproctitle
|
import setproctitle
|
||||||
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightFromDiskReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromDistributedReqOutput,
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
|
UpdateWeightsFromTensorReqInput,
|
||||||
|
UpdateWeightsFromTensorReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
FINISH_ABORT,
|
FINISH_ABORT,
|
||||||
@@ -478,6 +480,11 @@ class Scheduler:
|
|||||||
self.send_to_tokenizer.send_pyobj(
|
self.send_to_tokenizer.send_pyobj(
|
||||||
UpdateWeightsFromDistributedReqOutput(success, message)
|
UpdateWeightsFromDistributedReqOutput(success, message)
|
||||||
)
|
)
|
||||||
|
elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
|
||||||
|
success, message = self.update_weights_from_tensor(recv_req)
|
||||||
|
self.send_to_tokenizer.send_pyobj(
|
||||||
|
UpdateWeightsFromTensorReqOutput(success, message)
|
||||||
|
)
|
||||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||||
parameter = self.get_weights_by_name(recv_req)
|
parameter = self.get_weights_by_name(recv_req)
|
||||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||||
@@ -1458,6 +1465,17 @@ class Scheduler:
|
|||||||
logger.error(message)
|
logger.error(message)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Apply a tensor weight update through the TP worker.

    On success the cache is flushed and the flush is asserted to have
    succeeded; on failure the worker's message is logged as an error.

    Returns:
        The ``(success, message)`` pair produced by the TP worker.
    """
    # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
    success, message = self.tp_worker.update_weights_from_tensor(recv_req)
    if not success:
        logger.error(message)
        return success, message

    cache_flushed = self.flush_cache()
    assert cache_flushed, "Cache flush failed after updating weights"
    return success, message
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||||
return parameter
|
return parameter
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightFromDiskReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromDistributedReqOutput,
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
|
UpdateWeightsFromTensorReqInput,
|
||||||
|
UpdateWeightsFromTensorReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
@@ -179,6 +181,9 @@ class TokenizerManager:
|
|||||||
self.update_weights_from_distributed_communicator = _Communicator(
|
self.update_weights_from_distributed_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
|
self.update_weights_from_tensor_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
self.get_weights_by_name_communicator = _Communicator(
|
self.get_weights_by_name_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
@@ -515,6 +520,22 @@ class TokenizerManager:
|
|||||||
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
||||||
return result.success, result.message
|
return result.success, result.message
|
||||||
|
|
||||||
|
async def update_weights_from_tensor(
    self,
    obj: UpdateWeightsFromTensorReqInput,
    request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
    """Send a tensor weight update through the communicator and await the result.

    Args:
        obj: The update request passed to
            ``update_weights_from_tensor_communicator``.
        request: Not used by this method.

    Returns:
        ``(success, message)`` taken from the first response.

    Raises:
        AssertionError: If ``server_args.dp_size`` is not 1.
    """
    self.auto_create_handle_loop()
    assert (
        self.server_args.dp_size == 1
    ), "dp_size must be 1 for update weights from tensor"

    # Hold the writer lock for the whole round trip: weight sync
    # cannot run while requests are in progress.
    async with self.model_update_lock.writer_lock:
        result = (await self.update_weights_from_tensor_communicator(obj))[0]
        return result.success, result.message
|
||||||
|
|
||||||
async def get_weights_by_name(
|
async def get_weights_by_name(
|
||||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||||
):
|
):
|
||||||
@@ -708,6 +729,11 @@ class TokenizerManager:
|
|||||||
self.server_args.dp_size == 1
|
self.server_args.dp_size == 1
|
||||||
), "dp_size must be 1 for update weights from distributed"
|
), "dp_size must be 1 for update weights from distributed"
|
||||||
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
|
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
|
||||||
|
elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
|
||||||
|
assert (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
), "dp_size must be 1 for update weights from distributed"
|
||||||
|
self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
|
||||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||||
self.get_weights_by_name_communicator.handle_recv(recv_obj)
|
self.get_weights_by_name_communicator.handle_recv(recv_obj)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromTensorReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
@@ -188,6 +189,12 @@ class TpModelWorker:
|
|||||||
)
|
)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Forward a tensor weight update to the model runner.

    Unpacks ``name`` and ``tensor`` from the request and returns the
    runner's ``(success, message)`` pair.
    """
    param_name, new_value = recv_req.name, recv_req.tensor
    ok, msg = self.model_runner.update_weights_from_tensor(param_name, new_value)
    return ok, msg
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
parameter = self.model_runner.get_weights_by_name(
|
parameter = self.model_runner.get_weights_by_name(
|
||||||
recv_req.name, recv_req.truncate_size
|
recv_req.name, recv_req.truncate_size
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromTensorReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
@@ -225,6 +226,10 @@ class TpModelWorkerClient:
|
|||||||
success, message = self.worker.update_weights_from_distributed(recv_req)
|
success, message = self.worker.update_weights_from_distributed(recv_req)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Delegate a tensor weight update to the wrapped worker.

    Returns the worker's ``(success, message)`` pair.
    """
    ok, msg = self.worker.update_weights_from_tensor(recv_req)
    return ok, msg
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
return self.worker.get_weights_by_name(recv_req)
|
return self.worker.get_weights_by_name(recv_req)
|
||||||
|
|
||||||
|
|||||||
@@ -429,6 +429,10 @@ class ModelRunner:
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return False, error_msg
|
return False, error_msg
|
||||||
|
|
||||||
|
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
    """Load a single named tensor into the model.

    Args:
        name: Parameter name handed to ``self.model.load_weights``.
        tensor: New value for that parameter.

    Returns:
        ``(True, "Success")`` when loading completes, otherwise
        ``(False, error_msg)`` describing the failure.
    """
    try:
        self.model.load_weights([(name, tensor)])
    except Exception as e:
        # Report the failure to the caller instead of raising, matching the
        # ``(False, error_msg)`` convention used elsewhere in this class.
        error_msg = f"Failed to update parameter {name} from tensor: {e}"
        logger.error(error_msg)
        return False, error_msg
    return True, "Success"
|
||||||
|
|
||||||
def get_weights_by_name(
|
def get_weights_by_name(
|
||||||
self, name: str, truncate_size: int = 100
|
self, name: str, truncate_size: int = 100
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromTensorReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
@@ -109,6 +110,7 @@ app.add_middleware(
|
|||||||
tokenizer_manager: TokenizerManager = None
|
tokenizer_manager: TokenizerManager = None
|
||||||
scheduler_info: Dict = None
|
scheduler_info: Dict = None
|
||||||
|
|
||||||
|
|
||||||
##### Native API endpoints #####
|
##### Native API endpoints #####
|
||||||
|
|
||||||
|
|
||||||
@@ -866,6 +868,14 @@ class Engine:
|
|||||||
tokenizer_manager.update_weights_from_distributed(obj, None)
|
tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_weights_from_tensor(self, name, tensor):
    """Update a single named weight from an in-memory tensor.

    Builds the request, runs the tokenizer manager's coroutine to completion
    on the current event loop, and returns its result.
    """
    request_obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
    event_loop = asyncio.get_event_loop()
    pending = tokenizer_manager.update_weights_from_tensor(request_obj, None)
    return event_loop.run_until_complete(pending)
|
||||||
|
|
||||||
def get_weights_by_name(self, name, truncate_size=100):
|
def get_weights_by_name(self, name, truncate_size=100):
|
||||||
"""Get weights by parameter name."""
|
"""Get weights by parameter name."""
|
||||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ suites = {
|
|||||||
"test_triton_attention_kernels.py",
|
"test_triton_attention_kernels.py",
|
||||||
"test_triton_attention_backend.py",
|
"test_triton_attention_backend.py",
|
||||||
"test_update_weights_from_disk.py",
|
"test_update_weights_from_disk.py",
|
||||||
|
"test_update_weights_from_tensor.py",
|
||||||
"test_vision_chunked_prefill.py",
|
"test_vision_chunked_prefill.py",
|
||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
"test_session_control.py",
|
"test_session_control.py",
|
||||||
|
|||||||
32
test/srt/test_update_weights_from_tensor.py
Normal file
32
test/srt/test_update_weights_from_tensor.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
|
||||||
|
class TestReleaseGPUOccupation(unittest.TestCase):
    """Checks that ``Engine.update_weights_from_tensor`` overwrites a weight.

    NOTE(review): the class and method names do not mention weight updates;
    they look copied from another test and are kept unchanged here.
    """

    def test_release_and_resume_occupation(self):
        """Replace one weight with a constant tensor and read it back."""
        engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        # Always shut the engine down, even when a check fails, so a failing
        # test does not leave the engine running.
        try:
            param_name = "model.layers.2.self_attn.k_proj.weight"

            def _check_param(expect_values):
                # Compare the first five values of the first row only.
                actual_values = torch.tensor(engine.get_weights_by_name(param_name))[
                    0, :5
                ]
                assert torch.allclose(
                    actual_values, torch.tensor(expect_values), atol=0.001
                ), f"{actual_values=}"

            # Values expected before the update.
            _check_param([0.0571, -0.0114, 0.0444, 0.0215, -0.0149])

            new_tensor = torch.full((3072, 2048), 1.5)
            engine.update_weights_from_tensor(param_name, new_tensor)

            # After the update every inspected value should be the fill value.
            _check_param([1.5] * 5)
        finally:
            engine.shutdown()


if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user