Speed up update_weights_from_tensor (#2695)
This commit is contained in:
@@ -426,8 +426,7 @@ class UpdateWeightsFromDistributedReqOutput:
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromTensorReqInput:
|
||||
name: str
|
||||
tensor: torch.Tensor
|
||||
serialized_named_tensors: bytes # indeed Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -30,7 +30,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import broadcast_pyobj, set_random_seed
|
||||
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -197,7 +197,7 @@ class TpModelWorker:
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||
success, message = self.model_runner.update_weights_from_tensor(
|
||||
recv_req.name, recv_req.tensor
|
||||
MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
|
||||
)
|
||||
return success, message
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import gc
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -428,9 +428,9 @@ class ModelRunner:
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
|
||||
self.model.load_weights([(name, tensor)])
|
||||
return True, "Success" # TODO error handling
|
||||
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||
self.model.load_weights(named_tensors)
|
||||
return True, "Success"
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100
|
||||
|
||||
@@ -27,7 +27,9 @@ import signal
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Dict, List, Optional, Union
|
||||
from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
@@ -78,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
|
||||
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
MultiprocessingSerializer,
|
||||
add_api_key_middleware,
|
||||
add_prometheus_middleware,
|
||||
assert_pkg_version,
|
||||
@@ -874,9 +877,11 @@ class Engine:
|
||||
tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
)
|
||||
|
||||
def update_weights_from_tensor(self, name, tensor):
|
||||
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||
"""Update weights from distributed source."""
|
||||
obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
|
||||
obj = UpdateWeightsFromTensorReqInput(
|
||||
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
tokenizer_manager.update_weights_from_tensor(obj, None)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import io
|
||||
import ipaddress
|
||||
import itertools
|
||||
import json
|
||||
@@ -34,6 +35,7 @@ import warnings
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -60,7 +62,6 @@ from triton.runtime.cache import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
show_time_cost = False
|
||||
time_infos = {}
|
||||
|
||||
@@ -1206,7 +1207,6 @@ def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) ->
|
||||
# https://github.com/pytorch/pytorch/blob/
|
||||
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
|
||||
# torch/cuda/__init__.py#L831C1-L831C17
|
||||
import torch.cuda
|
||||
import torch.version
|
||||
|
||||
if not torch.cuda._is_compiled():
|
||||
@@ -1335,3 +1335,16 @@ def parse_tool_response(text, tools, **kwargs):
|
||||
for call_info in call_info_list
|
||||
]
|
||||
return text, call_info_list
|
||||
|
||||
|
||||
class MultiprocessingSerializer:
|
||||
@staticmethod
|
||||
def serialize(obj):
|
||||
buf = io.BytesIO()
|
||||
ForkingPickler(buf).dump(obj)
|
||||
buf.seek(0)
|
||||
return buf.read()
|
||||
|
||||
@staticmethod
|
||||
def deserialize(data):
|
||||
return ForkingPickler.loads(data)
|
||||
|
||||
Reference in New Issue
Block a user