Speed up update_weights_from_tensor (#2695)
This commit is contained in:
@@ -426,8 +426,7 @@ class UpdateWeightsFromDistributedReqOutput:
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromTensorReqInput:
|
||||
name: str
|
||||
tensor: torch.Tensor
|
||||
serialized_named_tensors: bytes # indeed Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -30,7 +30,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import broadcast_pyobj, set_random_seed
|
||||
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -197,7 +197,7 @@ class TpModelWorker:
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||
success, message = self.model_runner.update_weights_from_tensor(
|
||||
recv_req.name, recv_req.tensor
|
||||
MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
|
||||
)
|
||||
return success, message
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import gc
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -428,9 +428,9 @@ class ModelRunner:
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
|
||||
self.model.load_weights([(name, tensor)])
|
||||
return True, "Success" # TODO error handling
|
||||
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||
self.model.load_weights(named_tensors)
|
||||
return True, "Success"
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100
|
||||
|
||||
@@ -27,7 +27,9 @@ import signal
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Dict, List, Optional, Union
|
||||
from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
@@ -78,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
|
||||
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
MultiprocessingSerializer,
|
||||
add_api_key_middleware,
|
||||
add_prometheus_middleware,
|
||||
assert_pkg_version,
|
||||
@@ -874,9 +877,11 @@ class Engine:
|
||||
tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
)
|
||||
|
||||
def update_weights_from_tensor(self, name, tensor):
|
||||
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||
"""Update weights from distributed source."""
|
||||
obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
|
||||
obj = UpdateWeightsFromTensorReqInput(
|
||||
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
tokenizer_manager.update_weights_from_tensor(obj, None)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import io
|
||||
import ipaddress
|
||||
import itertools
|
||||
import json
|
||||
@@ -34,6 +35,7 @@ import warnings
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -60,7 +62,6 @@ from triton.runtime.cache import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
show_time_cost = False
|
||||
time_infos = {}
|
||||
|
||||
@@ -1206,7 +1207,6 @@ def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) ->
|
||||
# https://github.com/pytorch/pytorch/blob/
|
||||
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
|
||||
# torch/cuda/__init__.py#L831C1-L831C17
|
||||
import torch.cuda
|
||||
import torch.version
|
||||
|
||||
if not torch.cuda._is_compiled():
|
||||
@@ -1335,3 +1335,16 @@ def parse_tool_response(text, tools, **kwargs):
|
||||
for call_info in call_info_list
|
||||
]
|
||||
return text, call_info_list
|
||||
|
||||
|
||||
class MultiprocessingSerializer:
    """Namespace for pickling objects to bytes with ``ForkingPickler``.

    ``serialize`` and ``deserialize`` are inverse operations: whatever
    ``serialize`` emits can be handed to ``deserialize`` to rebuild the
    original object.
    """

    @staticmethod
    def serialize(obj):
        """Pickle ``obj`` with ``ForkingPickler`` and return the payload.

        Args:
            obj: Any object that ``ForkingPickler`` is able to pickle.

        Returns:
            The pickled representation of ``obj`` as ``bytes``.
        """
        buf = io.BytesIO()
        ForkingPickler(buf).dump(obj)
        # getvalue() returns the entire buffer regardless of the current
        # stream position, so no seek(0) + read() round trip is required.
        return buf.getvalue()

    @staticmethod
    def deserialize(data):
        """Rebuild the object from bytes produced by ``serialize``.

        NOTE(security): this unpickles ``data``. Unpickling can execute
        arbitrary code, so only payloads from a trusted producer must ever
        be passed in.

        Args:
            data: Bytes previously returned by ``serialize``.

        Returns:
            The reconstructed object.
        """
        return ForkingPickler.loads(data)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
@@ -6,27 +7,32 @@ import sglang as sgl
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
|
||||
class TestReleaseGPUOccupation(unittest.TestCase):
|
||||
def test_release_and_resume_occupation(self):
|
||||
class TestUpdateWeightsFromTensor(unittest.TestCase):
|
||||
def test_update_weights_from_tensor(self):
|
||||
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
param_name = "model.layers.2.self_attn.k_proj.weight"
|
||||
param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)]
|
||||
|
||||
def _check_param(expect_values):
|
||||
actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
|
||||
assert torch.allclose(
|
||||
actual_values, torch.tensor(expect_values), atol=0.001
|
||||
), f"{actual_values=}"
|
||||
_check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])
|
||||
|
||||
_check_param([0.0571, -0.0114, 0.0444, 0.0215, -0.0149])
|
||||
new_tensor = torch.full((16384, 2048), 1.5)
|
||||
|
||||
new_tensor = torch.full((3072, 2048), 1.5)
|
||||
engine.update_weights_from_tensor(param_name, new_tensor)
|
||||
time_start = time.time()
|
||||
engine.update_weights_from_tensor([(x, new_tensor) for x in param_names])
|
||||
print(f"Time delta: {time.time() - time_start:.03f}")
|
||||
|
||||
_check_param([1.5] * 5)
|
||||
for param_name in param_names[:3]:
|
||||
_check_param(engine, param_name, [1.5] * 5)
|
||||
|
||||
engine.shutdown()
|
||||
|
||||
|
||||
def _check_param(engine, param_name, expect_values):
    """Assert that the leading values of a parameter match ``expect_values``.

    Fetches the parameter via ``engine.get_weights_by_name(param_name)``,
    takes the first five entries of its first row, and compares them with
    ``expect_values`` using an absolute tolerance of 0.002.

    Raises:
        AssertionError: If the values differ by more than the tolerance; the
            message shows the values that were actually read.
    """
    fetched = engine.get_weights_by_name(param_name)
    # Only the first five entries of row 0 are inspected.
    actual_values = torch.tensor(fetched)[0, :5]
    expected = torch.tensor(expect_values)
    is_close = torch.allclose(actual_values, expected, atol=0.002)
    assert is_close, f"{actual_values=}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user