commit 72387e4fa8 (parent 8082d5f4b2)
root
2026-04-09 11:23:47 +08:00
1885 changed files with 611521 additions and 1 deletion

vllm/distributed/weight_transfer/__init__.py View File

@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Weight transfer engines for syncing model weights from trainers
to inference workers.
"""
from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory

__all__ = [
    "WeightTransferEngineFactory",
]

vllm/distributed/weight_transfer/base.py View File

@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for weight transfer engines."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Generic, TypeVar

import torch

from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig

TInitInfo = TypeVar("TInitInfo", bound="WeightTransferInitInfo")
TUpdateInfo = TypeVar("TUpdateInfo", bound="WeightTransferUpdateInfo")


# Base protocols for backend-specific dataclasses
@dataclass
class WeightTransferInitInfo(ABC):  # noqa: B024
    """Base class for backend-specific initialization info."""

    pass


@dataclass
class WeightTransferUpdateInfo(ABC):  # noqa: B024
    """Base class for backend-specific weight update info."""

    _: KW_ONLY
    is_checkpoint_format: bool = True
    """Set to True if weights are in checkpoint/original model format and need
    layerwise processing. Set to False if weights have already been processed
    into kernel format (repacking, renaming, etc.)."""


# API-level request classes (accept dicts for backend-agnostic serialization)
@dataclass
class WeightTransferInitRequest:
    """API-level weight transfer initialization request."""

    init_info: dict[str, Any] = field(default_factory=dict)


@dataclass
class WeightTransferUpdateRequest:
    """API-level weight update request."""

    update_info: dict[str, Any] = field(default_factory=dict)
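

# Illustrative flow (not part of this commit): the API layer carries plain
# dicts, which an engine validates into its typed dataclass via
# parse_init_info/parse_update_info. The keys shown here are those of the
# NCCL backend's update info; the values are made up.
#
#     request = WeightTransferUpdateRequest(
#         update_info={
#             "names": ["layer.weight"],
#             "dtype_names": ["bfloat16"],
#             "shapes": [[4096, 4096]],
#         }
#     )
#     typed = engine.parse_update_info(request.update_info)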


class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
    """
    Base class for weight transfer engines that handle transport of model
    weights from a trainer to inference workers.

    This abstraction separates weight transfer transport logic from the
    worker implementation, allowing different backends (NCCL, CUDA IPC[TODO],
    RDMA[TODO]) to be plugged in.

    Subclasses should define:
        init_info_cls: Type of backend-specific initialization info
        update_info_cls: Type of backend-specific update info
    """

    # Subclasses should override these class attributes
    init_info_cls: type[TInitInfo]
    update_info_cls: type[TUpdateInfo]

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Initialize the weight transfer engine.

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
        """
        self.config = config
        self.parallel_config = parallel_config

    def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
        """
        Construct typed init info from dict with validation.

        Args:
            init_dict: Dictionary containing backend-specific initialization
                parameters

        Returns:
            Typed backend-specific init info dataclass

        Raises:
            ValueError: If init_dict is invalid for this backend
        """
        try:
            return self.init_info_cls(**init_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid init_info for {self.__class__.__name__}: {e}"
            ) from e

    def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
        """
        Construct typed update info from dict with validation.

        Args:
            update_dict: Dictionary containing backend-specific update
                parameters

        Returns:
            Typed backend-specific update info dataclass

        Raises:
            ValueError: If update_dict is invalid for this backend
        """
        try:
            return self.update_info_cls(**update_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid update_info for {self.__class__.__name__}: {e}"
            ) from e

    @abstractmethod
    def init_transfer_engine(self, init_info: TInitInfo) -> None:
        """
        Initialize the weight transfer mechanism.

        This is called once at the beginning of training.

        Args:
            init_info: Backend-specific initialization info
        """
        raise NotImplementedError

    @abstractmethod
    def receive_weights(
        self,
        update_info: TUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from the trainer and load them incrementally.

        Args:
            update_info: Backend-specific update info containing parameter
                metadata and any backend-specific data
            load_weights: Callable that loads weights into the model. Called
                incrementally for each weight to avoid OOM.
        """
        raise NotImplementedError

    @abstractmethod
    def shutdown(self) -> None:
        """
        Shutdown the weight transfer engine.

        This should be called when the worker is shutting down.
        """
        raise NotImplementedError
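
The pieces above compose as follows: a backend defines its two info dataclasses, points init_info_cls / update_info_cls at them, and implements the three abstract methods, while the base class handles dict validation. A minimal sketch assuming only this file; the Echo* names are invented for illustration:

    from dataclasses import dataclass

    import torch

    from vllm.distributed.weight_transfer.base import (
        WeightTransferEngine,
        WeightTransferInitInfo,
        WeightTransferUpdateInfo,
    )

    @dataclass
    class EchoInitInfo(WeightTransferInitInfo):
        pass

    @dataclass
    class EchoUpdateInfo(WeightTransferUpdateInfo):
        names: list[str]

    class EchoEngine(WeightTransferEngine[EchoInitInfo, EchoUpdateInfo]):
        init_info_cls = EchoInitInfo
        update_info_cls = EchoUpdateInfo

        def init_transfer_engine(self, init_info: EchoInitInfo) -> None:
            pass  # nothing to set up in this toy backend

        def receive_weights(self, update_info, load_weights) -> None:
            # Drive the same incremental-load path a real backend would use.
            for name in update_info.names:
                load_weights([(name, torch.zeros(1))])

        def shutdown(self) -> None:
            pass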

vllm/distributed/weight_transfer/factory.py View File

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Factory for weight transfer engines with lazy loading."""
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING

from vllm.distributed.weight_transfer.base import WeightTransferEngine
from vllm.logger import init_logger

if TYPE_CHECKING:
    from vllm.config.parallel import ParallelConfig
    from vllm.config.weight_transfer import WeightTransferConfig

logger = init_logger(__name__)


class WeightTransferEngineFactory:
    """Factory for creating weight transfer engines with lazy loading.

    This factory implements a registry pattern that supports:
    - Lazy loading: Engine modules are only imported when actually needed
    - Extensibility: Custom engines can be registered at runtime
    - Centralized registration: All built-in engines registered in one place
    """

    _registry: dict[str, Callable[[], type[WeightTransferEngine]]] = {}

    @classmethod
    def register_engine(
        cls,
        name: str,
        module_path_or_cls: str | type[WeightTransferEngine],
        class_name: str | None = None,
    ) -> None:
        """Register an engine with lazy-loading or direct class reference.

        Supports two calling conventions:
        1. Lazy loading: register_engine(name, module_path, class_name)
        2. Direct class: register_engine(name, engine_cls)

        Args:
            name: The name to register the engine under (e.g., "nccl")
            module_path_or_cls: Either a module path string for lazy loading,
                or the engine class directly
            class_name: Name of the engine class (required if module_path is
                a string)

        Raises:
            ValueError: If an engine with the same name is already registered
        """
        if name in cls._registry:
            raise ValueError(
                f"Weight transfer engine '{name}' is already registered."
            )

        if isinstance(module_path_or_cls, str):
            # Lazy loading path
            module_path = module_path_or_cls
            if class_name is None:
                raise ValueError(
                    "class_name is required when registering with module path"
                )

            def loader() -> type[WeightTransferEngine]:
                module = importlib.import_module(module_path)
                return getattr(module, class_name)

            cls._registry[name] = loader
        else:
            # Direct class registration
            engine_cls = module_path_or_cls
            cls._registry[name] = lambda: engine_cls

    @classmethod
    def create_engine(
        cls,
        config: "WeightTransferConfig",
        parallel_config: "ParallelConfig",
    ) -> WeightTransferEngine:
        """Create a weight transfer engine instance.

        Args:
            config: Weight transfer configuration containing the backend name
            parallel_config: Parallel configuration for the engine

        Returns:
            An initialized weight transfer engine instance

        Raises:
            ValueError: If the backend is not registered
        """
        backend = config.backend
        if backend not in cls._registry:
            available = list(cls._registry.keys())
            raise ValueError(
                f"Invalid weight transfer backend: {backend}. "
                f"Available engines: {available}"
            )

        engine_cls = cls._registry[backend]()
        logger.info(
            "Creating weight transfer engine: %s",
            engine_cls.__name__,
        )
        return engine_cls(config, parallel_config)


# Register built-in weight transfer engines here.
# Registration should be centralized to ensure lazy loading -
# engine modules are only imported when actually used.
WeightTransferEngineFactory.register_engine(
    "nccl",
    "vllm.distributed.weight_transfer.nccl_engine",
    "NCCLWeightTransferEngine",
)
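
Third-party backends register through the same entry point. A hedged sketch; the "my_rdma" name, the my_pkg.engines module, and the MyRDMAEngine class are invented for illustration:

    from vllm.distributed.weight_transfer.factory import (
        WeightTransferEngineFactory,
    )

    # Lazy form: my_pkg.engines is imported only once "my_rdma" is selected.
    WeightTransferEngineFactory.register_engine(
        "my_rdma", "my_pkg.engines", "MyRDMAEngine"
    )

    # Direct form, for an already-imported class:
    #     WeightTransferEngineFactory.register_engine("my_rdma", MyRDMAEngine)

    # A config whose backend field is "my_rdma" then resolves to the engine:
    #     engine = WeightTransferEngineFactory.create_engine(config, parallel_config)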

vllm/distributed/weight_transfer/nccl_engine.py View File

@@ -0,0 +1,315 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NCCL-based weight transfer engine."""
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch

if TYPE_CHECKING:
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
    WeightTransferEngine,
    WeightTransferInitInfo,
    WeightTransferUpdateInfo,
)
from vllm.distributed.weight_transfer.packed_tensor import (
    DEFAULT_PACKED_BUFFER_SIZE_BYTES,
    DEFAULT_PACKED_NUM_BUFFERS,
    packed_broadcast_consumer,
)


@dataclass
class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
    """Initialization info for the NCCL weight transfer backend."""

    master_address: str
    master_port: int
    rank_offset: int
    world_size: int


@dataclass
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
    """Update info for the NCCL weight transfer backend."""

    names: list[str]
    dtype_names: list[str]
    shapes: list[list[int]]
    packed: bool = False
    """Whether to use packed tensor broadcasting for efficiency.
    When True, multiple tensors are batched together before broadcasting
    to reduce NCCL communication overhead."""
    packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
    """Size in bytes of each packed tensor buffer. Default is 1GB.
    Both producer and consumer must use the same value."""
    packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
    """Number of buffers for double/triple buffering during packed transfer.
    Both producer and consumer must use the same value."""

    def __post_init__(self):
        """Validate that all lists have the same length."""
        num_params = len(self.names)
        if len(self.dtype_names) != num_params:
            raise ValueError(
                f"`dtype_names` should be of the same size as `names`: "
                f"got {len(self.dtype_names)} and {len(self.names)}"
            )
        if len(self.shapes) != num_params:
            raise ValueError(
                f"`shapes` should be of the same size as `names`: "
                f"got {len(self.shapes)} and {len(self.names)}"
            )


class NCCLWeightTransferEngine(
    WeightTransferEngine[NCCLWeightTransferInitInfo, NCCLWeightTransferUpdateInfo]
):
    """
    Weight transfer engine using NCCL for communication between the trainer
    and the workers.

    This implementation uses NCCL broadcast operations to transfer weights
    from the trainer (rank 0) to all inference workers in a process group.
    """

    # Define backend-specific dataclass types
    init_info_cls = NCCLWeightTransferInitInfo
    update_info_cls = NCCLWeightTransferUpdateInfo

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Initialize the NCCL weight transfer engine.

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
        """
        super().__init__(config, parallel_config)
        # Annotation is quoted because PyNcclCommunicator is imported only
        # under TYPE_CHECKING; an unquoted annotation on an attribute target
        # would be evaluated at runtime and raise NameError.
        self.model_update_group: "PyNcclCommunicator | None" = None

    def init_transfer_engine(self, init_info: NCCLWeightTransferInitInfo) -> None:
        """
        Initialize the NCCL process group shared with the trainer.

        Args:
            init_info: NCCL initialization info containing the master address
                and port, rank offset, and world size
        """
        # Calculate the global rank in the trainer-worker process group.
        # Must account for data parallelism to get unique ranks across all
        # workers.
        dp_rank = self.parallel_config.data_parallel_rank
        world_size_per_dp = self.parallel_config.world_size  # TP * PP
        rank_within_dp = self.parallel_config.rank
        # Unique rank across all DP groups
        worker_rank = dp_rank * world_size_per_dp + rank_within_dp
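        # Worked example (illustrative numbers): with world_size_per_dp = 4
        # (TP * PP) and DP = 2, the worker at dp_rank=1, rank_within_dp=2 gets
        # worker_rank = 1 * 4 + 2 = 6; with rank_offset=1 (the trainer holding
        # rank 0), its global rank below is 7 in a group of size 9.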
        rank = worker_rank + init_info.rank_offset

        # Create a stateless process group
        self.model_update_group = (
            NCCLWeightTransferEngine._stateless_init_process_group(
                init_info.master_address,
                init_info.master_port,
                rank,
                init_info.world_size,
                torch.cuda.current_device(),
            )
        )

    def receive_weights(
        self,
        update_info: NCCLWeightTransferUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from the trainer via NCCL broadcast and load them
        incrementally.

        If update_info.packed is True, uses packed tensor broadcasting for
        efficient transfer of multiple weights in batches. Otherwise, uses
        simple one-by-one broadcasting.

        Args:
            update_info: NCCL update info containing parameter names, dtypes,
                shapes, and the packed flag
            load_weights: Callable that loads weights into the model. Called
                incrementally for each batch of weights to avoid OOM.
        """
        if self.model_update_group is None:
            raise RuntimeError(
                "NCCL weight transfer not initialized. "
                "Call init_transfer_engine() first."
            )

        if update_info.packed:
            # Build an iterator of (name, (shape, dtype)) from update_info
            def state_dict_info_iterator():
                for name, dtype_name, shape in zip(
                    update_info.names, update_info.dtype_names, update_info.shapes
                ):
                    dtype = getattr(torch, dtype_name)
                    yield (name, (shape, dtype))

            packed_broadcast_consumer(
                iterator=state_dict_info_iterator(),
                group=self.model_update_group,
                src=0,
                post_unpack_func=load_weights,
                buffer_size_bytes=update_info.packed_buffer_size_bytes,
                num_buffers=update_info.packed_num_buffers,
            )
        else:
            # Use simple one-by-one broadcasting
            for name, dtype_name, shape in zip(
                update_info.names, update_info.dtype_names, update_info.shapes
            ):
                dtype = getattr(torch, dtype_name)
                weight = torch.empty(shape, dtype=dtype, device="cuda")
                self.model_update_group.broadcast(
                    weight, src=0, stream=torch.cuda.current_stream()
                )
                load_weights([(name, weight)])
                del weight

    def shutdown(self) -> None:
        if self.model_update_group is not None:
            # Clean up the communicator by dropping the reference
            self.model_update_group = None

    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        group: Any,
        src: int = 0,
        post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
        | None = None,
        packed: bool = False,
        stream: torch.cuda.Stream | None = None,
        packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
        packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
    ) -> None:
        """Broadcast weights from the trainer to vLLM workers.

        Args:
            iterator: Iterator of model parameters that yields (name, tensor)
                tuples
            group: Process group (PyNcclCommunicator)
            src: Source rank (default 0; the trainer is typically rank 0)
            post_iter_func: Optional function to apply to each (name, tensor)
                pair before broadcasting. If None, extracts just the tensor.
            packed: Whether to use packed tensor broadcasting for efficiency.
                When True, multiple tensors are batched together before
                broadcasting to reduce NCCL communication overhead.
            stream: CUDA stream to use for broadcasting if packed is False.
                If packed is True, new streams are created for each buffer.
            packed_buffer_size_bytes: Size in bytes of each packed tensor
                buffer. Must match the value used in
                NCCLWeightTransferUpdateInfo.
            packed_num_buffers: Number of buffers for double/triple buffering.
                Must match the value used in NCCLWeightTransferUpdateInfo.

        Example:
            >>> from vllm.distributed.weight_transfer.nccl_engine import (
            ...     NCCLWeightTransferEngine,
            ... )
            >>> param_iter = ((n, p) for n, p in model.named_parameters())
            >>> NCCLWeightTransferEngine.trainer_send_weights(
            ...     param_iter, group, packed=True
            ... )
        """
        if post_iter_func is None:
            # Default: extract just the tensor from the (name, tensor) tuple
            post_iter_func = lambda x: x[1]
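        # Illustrative alternative (not part of this commit): a trainer that
        # keeps fp32 master weights could downcast on the fly, e.g.
        #     post_iter_func=lambda item: item[1].to(torch.bfloat16)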
        if packed:
            # Use packed tensor broadcasting for efficiency
            from vllm.distributed.weight_transfer.packed_tensor import (
                packed_broadcast_producer,
            )

            packed_broadcast_producer(
                iterator=iterator,
                group=group,
                src=src,
                post_iter_func=post_iter_func,
                buffer_size_bytes=packed_buffer_size_bytes,
                num_buffers=packed_num_buffers,
            )
        else:
            # Use simple one-by-one broadcasting
            for item in iterator:
                tensor = post_iter_func(item)
                group.broadcast(
                    tensor, src=src, stream=stream or torch.cuda.current_stream()
                )

    @staticmethod
    def trainer_init(
        init_info: NCCLWeightTransferInitInfo | dict,
    ) -> "PyNcclCommunicator":
        """
        Initialize the NCCL process group for trainer-side weight transfer.

        The trainer is always rank 0 in the process group. Uses the current
        CUDA device (torch.cuda.current_device()).

        Args:
            init_info: Either an NCCLWeightTransferInitInfo object or a dict
                with keys:
                - master_address: str
                - master_port: int
                - world_size: int

        Returns:
            PyNcclCommunicator for weight transfer.

        Example:
            >>> from vllm.distributed.weight_transfer.nccl_engine import (
            ...     NCCLWeightTransferEngine,
            ... )
            >>> group = NCCLWeightTransferEngine.trainer_init(
            ...     dict(
            ...         master_address=master_address,
            ...         master_port=master_port,
            ...         world_size=world_size,
            ...     ),
            ... )
        """
        if isinstance(init_info, dict):
            master_address = init_info["master_address"]
            master_port = init_info["master_port"]
            world_size = init_info["world_size"]
        else:
            # NCCLWeightTransferInitInfo object
            master_address = init_info.master_address
            master_port = init_info.master_port
            world_size = init_info.world_size

        # The trainer is always rank 0
        return NCCLWeightTransferEngine._stateless_init_process_group(
            master_address, master_port, 0, world_size, torch.cuda.current_device()
        )

    @staticmethod
    def _stateless_init_process_group(
        master_address, master_port, rank, world_size, device
    ):
        """
        vLLM provides `StatelessProcessGroup` to create a process group
        without touching the global process group in torch.distributed.
        It is recommended to create a `StatelessProcessGroup` first and then
        initialize the data-plane (NCCL) communication between the external
        training processes and the vLLM workers.
        """
        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
        from vllm.distributed.utils import StatelessProcessGroup

        pg = StatelessProcessGroup.create(
            host=master_address, port=master_port, rank=rank, world_size=world_size
        )
        pynccl = PyNcclCommunicator(pg, device=device)
        return pynccl
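
One update cycle pairs the trainer-side statics above with the worker-side engine. A sketch of the trainer process under stated assumptions: model, master_address, master_port, and num_workers are placeholders, and the update metadata is assumed to reach the workers through the serving API's WeightTransferUpdateRequest so that they run receive_weights() concurrently:

    from vllm.distributed.weight_transfer.nccl_engine import (
        NCCLWeightTransferEngine,
    )

    # The trainer occupies rank 0 of a group spanning it and all workers.
    group = NCCLWeightTransferEngine.trainer_init(
        dict(
            master_address=master_address,
            master_port=master_port,
            world_size=1 + num_workers,
        )
    )

    # Metadata mirroring NCCLWeightTransferUpdateInfo; the workers use it to
    # post matching receives.
    params = list(model.named_parameters())
    update_info = dict(
        names=[n for n, _ in params],
        dtype_names=[str(p.dtype).removeprefix("torch.") for _, p in params],
        shapes=[list(p.shape) for _, p in params],
        packed=True,
    )

    # Broadcast the weights while the workers run receive_weights(...).
    NCCLWeightTransferEngine.trainer_send_weights(iter(params), group, packed=True)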

vllm/distributed/weight_transfer/packed_tensor.py View File

@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Packed tensor utilities for efficient weight transfer."""
import math
from collections.abc import Callable, Iterator
from typing import Any

import torch

# Default values for packed tensor configuration.
# These are imported by NCCLWeightTransferUpdateInfo and trainer_send_weights.
DEFAULT_PACKED_BUFFER_SIZE_BYTES = 1024 * 1024 * 1024  # 1GB
DEFAULT_PACKED_NUM_BUFFERS = 2


def packed_broadcast_producer(
    iterator: Iterator[tuple[str, torch.Tensor]],
    group: Any,
    src: int,
    post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor],
    buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
    num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
) -> None:
    """Broadcast tensors in a packed manner from the trainer to the workers.

    Args:
        iterator: Iterator of model parameters that yields (name, tensor)
            tuples
        group: Process group (PyNcclCommunicator)
        src: Source rank (0 in the current implementation)
        post_iter_func: Function to apply to each (name, tensor) pair before
            packing; should return a tensor
        buffer_size_bytes: Size in bytes of each packed tensor buffer.
            Both producer and consumer must use the same value.
        num_buffers: Number of buffers for double/triple buffering.
            Both producer and consumer must use the same value.
    """
    target_packed_tensor_size = buffer_size_bytes
    streams = [torch.cuda.Stream() for _ in range(num_buffers)]
    buffer_idx = 0
    packing_tensor_list: list[list[torch.Tensor]] = [[] for _ in range(num_buffers)]
    packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)]
    packed_tensors: list[torch.Tensor] = [
        torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers)
    ]

    while True:
        # Wait for the previous transfer on this buffer's stream to finish
        streams[buffer_idx].synchronize()
        # Start tasks for the new buffer in its stream
        with torch.cuda.stream(streams[buffer_idx]):
            try:
                # Initialize the packing tensor list and sizes
                packing_tensor_list[buffer_idx] = []
                packing_tensor_sizes[buffer_idx] = 0
                # Pack tensors until the buffer exceeds the target size
                while True:
                    # Apply post-processing and flatten to a uint8 view
                    tensor = (
                        post_iter_func(next(iterator))
                        .contiguous()
                        .view(torch.uint8)
                        .view(-1)
                    )
                    packing_tensor_list[buffer_idx].append(tensor)
                    packing_tensor_sizes[buffer_idx] += tensor.numel()
                    if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size:
                        break
                # Pack the tensors and call the broadcast collective
                packed_tensors[buffer_idx] = torch.cat(
                    packing_tensor_list[buffer_idx], dim=0
                )
                group.broadcast(packed_tensors[buffer_idx], src=src)
                # Move to the next buffer
                buffer_idx = (buffer_idx + 1) % num_buffers
            except StopIteration:
                # Do a last broadcast if there are remaining tensors
                if len(packing_tensor_list[buffer_idx]) > 0:
                    packed_tensors[buffer_idx] = torch.cat(
                        packing_tensor_list[buffer_idx], dim=0
                    )
                    group.broadcast(packed_tensors[buffer_idx], src=src)
                break


def packed_broadcast_consumer(
    iterator: Iterator[tuple[str, tuple[list[int], torch.dtype]]],
    group: Any,
    src: int,
    post_unpack_func: Callable[[list[tuple[str, torch.Tensor]]], None],
    buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
    num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
) -> None:
    """Consume packed tensors and unpack them into lists of named tensors.

    Args:
        iterator: Iterator of parameter metadata that yields
            (name, (shape, dtype)) tuples
        group: Process group (PyNcclCommunicator)
        src: Source rank (0 in the current implementation)
        post_unpack_func: Function applied to each list of (name, tensor)
            pairs after unpacking
        buffer_size_bytes: Size in bytes of each packed tensor buffer.
            Both producer and consumer must use the same value.
        num_buffers: Number of buffers for double/triple buffering.
            Both producer and consumer must use the same value.
    """

    def unpack_tensor(
        packed_tensor: torch.Tensor,
        names: list[str],
        shapes: list[list[int]],
        dtypes: list[torch.dtype],
        tensor_sizes: list[int],
    ) -> list[tuple[str, torch.Tensor]]:
        """Unpack a single packed tensor into a list of named tensors.

        Args:
            packed_tensor: The packed torch.uint8 tensor to unpack
            names: List of tensor names
            shapes: List of tensor shapes
            dtypes: List of tensor dtypes
            tensor_sizes: List of tensor sizes in bytes

        Returns:
            The unpacked list of (name, tensor) tuples
        """
        unpacked_tensors = packed_tensor.split(tensor_sizes)
        unpacked_list = [
            (name, tensor.contiguous().view(dtype).view(*shape))
            for name, shape, dtype, tensor in zip(
                names, shapes, dtypes, unpacked_tensors
            )
        ]
        return unpacked_list

    target_packed_tensor_size = buffer_size_bytes
    streams = [torch.cuda.Stream() for _ in range(num_buffers)]
    buffer_idx = 0
    packing_tensor_meta_data: list[list[tuple[str, list[int], torch.dtype, int]]] = [
        [] for _ in range(num_buffers)
    ]
    packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)]
    packed_tensors: list[torch.Tensor] = [
        torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers)
    ]

    while True:
        # Wait for the previous transfer on this buffer's stream to finish
        streams[buffer_idx].synchronize()
        with torch.cuda.stream(streams[buffer_idx]):
            # Initialize the packing tensor metadata
            packing_tensor_meta_data[buffer_idx] = []
            packing_tensor_sizes[buffer_idx] = 0
            try:
                # Accumulate metadata for one packed tensor
                while True:
                    name, (shape, dtype) = next(iterator)
                    tensor_size = math.prod(shape) * dtype.itemsize
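                    # Byte-size arithmetic (illustrative): a [4096, 4096]
                    # bfloat16 weight contributes 4096 * 4096 * 2 =
                    # 33,554,432 bytes to the packed uint8 buffer.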
                    packing_tensor_meta_data[buffer_idx].append(
                        (name, shape, dtype, tensor_size)
                    )
                    packing_tensor_sizes[buffer_idx] += tensor_size
                    if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size:
                        break
                # Create a packed tensor and receive the broadcast into it
                packed_tensors[buffer_idx] = torch.empty(
                    packing_tensor_sizes[buffer_idx], dtype=torch.uint8, device="cuda"
                )
                group.broadcast(packed_tensors[buffer_idx], src=src)
                # Unpack and load the tensors into the model
                names, shapes, dtypes, tensor_sizes = zip(
                    *packing_tensor_meta_data[buffer_idx]
                )
                post_unpack_func(
                    unpack_tensor(
                        packed_tensors[buffer_idx],
                        list(names),
                        list(shapes),
                        list(dtypes),
                        list(tensor_sizes),
                    )
                )
                # Move to the next buffer
                buffer_idx = (buffer_idx + 1) % num_buffers
            except StopIteration:
                # Do a last receive if there are remaining tensors
                if len(packing_tensor_meta_data[buffer_idx]) > 0:
                    # Create a packed tensor and receive the broadcast into it
                    packed_tensors[buffer_idx] = torch.empty(
                        packing_tensor_sizes[buffer_idx],
                        dtype=torch.uint8,
                        device="cuda",
                    )
                    group.broadcast(packed_tensors[buffer_idx], src=src)
                    # Unpack and load the tensors into the model
                    names, shapes, dtypes, tensor_sizes = zip(
                        *packing_tensor_meta_data[buffer_idx]
                    )
                    post_unpack_func(
                        unpack_tensor(
                            packed_tensors[buffer_idx],
                            list(names),
                            list(shapes),
                            list(dtypes),
                            list(tensor_sizes),
                        )
                    )
                break
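
The producer and consumer walk the same parameter sequence with the same buffer settings, so each torch.cat on the trainer lines up byte-for-byte with one torch.empty receive on the worker. A self-contained sketch of the round-trip bookkeeping (CPU-only, no process group; the state list is made up purely to illustrate the pack/unpack math):

    import math

    import torch

    # Two illustrative "weights" standing in for model parameters.
    state = [
        ("w1", torch.arange(6, dtype=torch.float32).reshape(2, 3)),
        ("w2", torch.ones(4, dtype=torch.float16)),
    ]

    # Producer side: flatten each tensor to a uint8 view and concatenate,
    # as packed_broadcast_producer does before group.broadcast().
    packed = torch.cat(
        [t.contiguous().view(torch.uint8).view(-1) for _, t in state]
    )

    # Consumer side: byte counts come from metadata alone, never the tensors.
    meta = [(n, (list(t.shape), t.dtype)) for n, t in state]
    sizes = [math.prod(s) * d.itemsize for _, (s, d) in meta]

    # Unpacking mirrors unpack_tensor(): split by byte counts, reinterpret
    # the dtype, then reshape.
    restored = [
        (n, chunk.view(d).view(*s))
        for (n, (s, d)), chunk in zip(meta, packed.split(sizes))
    ]
    assert all(torch.equal(t, r) for (_, t), (_, r) in zip(state, restored))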