# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NCCL-based weight transfer engine."""

from collections.abc import Callable, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch

if TYPE_CHECKING:
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
    WeightTransferEngine,
    WeightTransferInitInfo,
    WeightTransferUpdateInfo,
)
from vllm.distributed.weight_transfer.packed_tensor import (
    DEFAULT_PACKED_BUFFER_SIZE_BYTES,
    DEFAULT_PACKED_NUM_BUFFERS,
    packed_broadcast_consumer,
)


@dataclass
class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
    """Initialization info for NCCL weight transfer backend."""

    master_address: str
    master_port: int
    rank_offset: int
    world_size: int


@dataclass
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
    """Update info for NCCL weight transfer backend."""

    names: list[str]
    dtype_names: list[str]
    shapes: list[list[int]]
    packed: bool = False
    """Whether to use packed tensor broadcasting for efficiency.

    When True, multiple tensors are batched together before broadcasting
    to reduce NCCL communication overhead."""

    packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
    """Size in bytes for each packed tensor buffer. Default is 1 GB.

    Both producer and consumer must use the same value."""

    packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
    """Number of buffers for double/triple buffering during packed transfer.

    Both producer and consumer must use the same value."""

    def __post_init__(self):
        """Validate that all lists have the same length."""
        num_params = len(self.names)
        if len(self.dtype_names) != num_params:
            raise ValueError(
                f"`dtype_names` should be of the same size as `names`: "
                f"got {len(self.dtype_names)} and {len(self.names)}"
            )
        if len(self.shapes) != num_params:
            raise ValueError(
                f"`shapes` should be of the same size as `names`: "
                f"got {len(self.shapes)} and {len(self.names)}"
            )


class NCCLWeightTransferEngine(
    WeightTransferEngine[NCCLWeightTransferInitInfo, NCCLWeightTransferUpdateInfo]
):
    """
    Weight transfer engine using NCCL for communication between trainer and workers.

    This implementation uses NCCL broadcast operations to transfer weights from
    the trainer (rank 0) to all inference workers in a process group.
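
    Example:
        A minimal sketch of the trainer side of a weight sync (addresses and
        sizes below are illustrative); the worker side is driven through
        init_transfer_engine() and receive_weights() with matching init and
        update info:

        >>> group = NCCLWeightTransferEngine.trainer_init(
        ...     dict(
        ...         master_address="10.0.0.1",
        ...         master_port=29500,
        ...         world_size=5,
        ...     )
        ... )
        >>> NCCLWeightTransferEngine.trainer_send_weights(
        ...     model.named_parameters(), group, packed=True
        ... )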
    """

    # Define backend-specific dataclass types
    init_info_cls = NCCLWeightTransferInitInfo
    update_info_cls = NCCLWeightTransferUpdateInfo

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Initialize the NCCL weight transfer engine.

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
        """
        super().__init__(config, parallel_config)
        self.model_update_group: PyNcclCommunicator | None = None

    def init_transfer_engine(self, init_info: NCCLWeightTransferInitInfo) -> None:
        """
        Initialize NCCL process group with the trainer.

        Args:
            init_info: NCCL initialization info containing master address, port,
                rank offset, and world size
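
        Example:
            An illustrative layout (not taken from a specific deployment): with
            2 data-parallel groups of TP * PP = 2 workers each and a single
            trainer process at rank 0 (rank_offset=1), the workers occupy ranks
            1-4 and init_info.world_size is 5.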
        """
        # Calculate the global rank in the trainer-worker process group
        # Must account for data parallel to get unique ranks across all workers
        dp_rank = self.parallel_config.data_parallel_rank
        world_size_per_dp = self.parallel_config.world_size  # TP * PP
        rank_within_dp = self.parallel_config.rank

        # Unique rank across all DP groups
        worker_rank = dp_rank * world_size_per_dp + rank_within_dp
        rank = worker_rank + init_info.rank_offset

        # Create stateless process group
        self.model_update_group = (
            NCCLWeightTransferEngine._stateless_init_process_group(
                init_info.master_address,
                init_info.master_port,
                rank,
                init_info.world_size,
                torch.cuda.current_device(),
            )
        )

    def receive_weights(
        self,
        update_info: NCCLWeightTransferUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from trainer via NCCL broadcast and load them incrementally.

        If update_info.packed is True, uses packed tensor broadcasting for
        efficient transfer of multiple weights in batches. Otherwise, uses simple
        one-by-one broadcasting.

        Args:
            update_info: NCCL update info containing parameter names, dtypes, shapes,
                and packed flag
            load_weights: Callable that loads weights into the model. Called
                incrementally for each batch of weights to avoid OOM.
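
        Example:
            A minimal worker-side sketch, assuming this engine has already been
            initialized via init_transfer_engine() and that load_fn is a
            weight-loading callable provided by the caller (names, shapes, and
            dtypes below are illustrative):

            >>> update_info = NCCLWeightTransferUpdateInfo(
            ...     names=["model.layers.0.mlp.gate_proj.weight"],
            ...     dtype_names=["bfloat16"],
            ...     shapes=[[11008, 4096]],
            ...     packed=False,
            ... )
            >>> engine.receive_weights(update_info, load_fn)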
        """
        if self.model_update_group is None:
            raise RuntimeError(
                "NCCL weight transfer not initialized. "
                "Call init_transfer_engine() first."
            )

        if update_info.packed:
            # Build iterator of (name, (shape, dtype)) from update_info
            def state_dict_info_iterator():
                for name, dtype_name, shape in zip(
                    update_info.names, update_info.dtype_names, update_info.shapes
                ):
                    dtype = getattr(torch, dtype_name)
                    yield (name, (shape, dtype))

            packed_broadcast_consumer(
                iterator=state_dict_info_iterator(),
                group=self.model_update_group,
                src=0,
                post_unpack_func=load_weights,
                buffer_size_bytes=update_info.packed_buffer_size_bytes,
                num_buffers=update_info.packed_num_buffers,
            )
        else:
            # Use simple one-by-one broadcasting
            for name, dtype_name, shape in zip(
                update_info.names, update_info.dtype_names, update_info.shapes
            ):
                dtype = getattr(torch, dtype_name)
                weight = torch.empty(shape, dtype=dtype, device="cuda")
                self.model_update_group.broadcast(
                    weight, src=0, stream=torch.cuda.current_stream()
                )
                load_weights([(name, weight)])
                del weight

    def shutdown(self) -> None:
        """Release the NCCL communicator used for weight transfer."""
        if self.model_update_group is not None:
            # Clean up the communicator by removing the reference
            self.model_update_group = None

    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        group: Any,
        src: int = 0,
        post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
        | None = None,
        packed: bool = False,
        stream: torch.cuda.Stream | None = None,
        packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
        packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
    ) -> None:
        """Broadcast weights from trainer to vLLM workers.

        Args:
            iterator: Iterator of model parameters. Returns (name, tensor) tuples
            group: Process group (PyNcclCommunicator)
            src: Source rank (default 0, trainer is typically rank 0)
            post_iter_func: Optional function to apply to each (name, tensor) pair
                before broadcasting. If None, extracts just the tensor.
            packed: Whether to use packed tensor broadcasting for efficiency.
                When True, multiple tensors are batched together before
                broadcasting to reduce NCCL communication overhead.
            stream: CUDA stream to use for broadcasting if packed is False.
                If packed is True, new streams will be created for each buffer.
            packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
                Must match the value used in NCCLWeightTransferUpdateInfo.
            packed_num_buffers: Number of buffers for double/triple buffering.
                Must match the value used in NCCLWeightTransferUpdateInfo.

        Example:
            >>> from vllm.distributed.weight_transfer.nccl_engine import (
            ...     NCCLWeightTransferEngine,
            ... )
            >>> param_iter = ((n, p) for n, p in model.named_parameters())
            >>> NCCLWeightTransferEngine.trainer_send_weights(
            ...     param_iter, group, packed=True
            ... )
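
            A sketch of using post_iter_func to transform each parameter before
            it is broadcast, here casting to bfloat16 (illustrative; the workers
            must be told the matching dtype via
            NCCLWeightTransferUpdateInfo.dtype_names):

            >>> NCCLWeightTransferEngine.trainer_send_weights(
            ...     param_iter,
            ...     group,
            ...     post_iter_func=lambda item: item[1].to(torch.bfloat16),
            ... )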
        """
        if post_iter_func is None:
            # Default: extract just the tensor from (name, tensor) tuple
            post_iter_func = lambda x: x[1]

        if packed:
            # Use packed tensor broadcasting for efficiency
            from vllm.distributed.weight_transfer.packed_tensor import (
                packed_broadcast_producer,
            )

            packed_broadcast_producer(
                iterator=iterator,
                group=group,
                src=src,
                post_iter_func=post_iter_func,
                buffer_size_bytes=packed_buffer_size_bytes,
                num_buffers=packed_num_buffers,
            )
        else:
            # Use simple one-by-one broadcasting
            for item in iterator:
                tensor = post_iter_func(item)
                group.broadcast(
                    tensor, src=src, stream=stream or torch.cuda.current_stream()
                )

    @staticmethod
    def trainer_init(
        init_info: NCCLWeightTransferInitInfo | dict,
    ) -> "PyNcclCommunicator":
        """
        Initialize NCCL process group for trainer-side weight transfer.

        The trainer is always rank 0 in the process group. Uses the current
        CUDA device (torch.cuda.current_device()).

        Args:
            init_info: Either an NCCLWeightTransferInitInfo object or a dict with keys:
                - master_address: str
                - master_port: int
                - world_size: int

        Returns:
            PyNcclCommunicator for weight transfer.

        Example:
            >>> from vllm.distributed.weight_transfer.nccl_engine import (
            ...     NCCLWeightTransferEngine,
            ... )
            >>> group = NCCLWeightTransferEngine.trainer_init(
            ...     dict(
            ...         master_address=master_address,
            ...         master_port=master_port,
            ...         world_size=world_size,
            ...     ),
            ... )
        """
        if isinstance(init_info, dict):
            master_address = init_info["master_address"]
            master_port = init_info["master_port"]
            world_size = init_info["world_size"]
        else:
            # NCCLWeightTransferInitInfo object
            master_address = init_info.master_address
            master_port = init_info.master_port
            world_size = init_info.world_size

        # Trainer is always rank 0
        return NCCLWeightTransferEngine._stateless_init_process_group(
            master_address, master_port, 0, world_size, torch.cuda.current_device()
        )

    @staticmethod
    def _stateless_init_process_group(
        master_address, master_port, rank, world_size, device
    ):
        """
        vLLM provides `StatelessProcessGroup` to create a process group
        without considering the global process group in torch.distributed.
        It is recommended to create a `StatelessProcessGroup` first, and then
        initialize the data-plane communication (NCCL) between external
        (training) processes and vLLM workers.
        """
        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
        from vllm.distributed.utils import StatelessProcessGroup

        pg = StatelessProcessGroup.create(
            host=master_address, port=master_port, rank=rank, world_size=world_size
        )
        pynccl = PyNcclCommunicator(pg, device=device)
        return pynccl