Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -35,6 +35,32 @@ class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
|
||||
world_size: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class NCCLTrainerSendWeightsArgs:
|
||||
"""Arguments for NCCL trainer_send_weights method."""
|
||||
|
||||
group: Any
|
||||
"""Process group (PyNcclCommunicator) for NCCL communication."""
|
||||
src: int = 0
|
||||
"""Source rank (default 0, trainer is typically rank 0)."""
|
||||
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None
|
||||
"""Optional function to apply to each (name, tensor) pair before broadcasting.
|
||||
If None, extracts just the tensor."""
|
||||
packed: bool = False
|
||||
"""Whether to use packed tensor broadcasting for efficiency.
|
||||
When True, multiple tensors are batched together before broadcasting
|
||||
to reduce NCCL communication overhead."""
|
||||
stream: torch.cuda.Stream | None = None
|
||||
"""CUDA stream to use for broadcasting if packed is False.
|
||||
If packed is True, new streams will be created for each buffer."""
|
||||
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
|
||||
"""Size in bytes for each packed tensor buffer.
|
||||
Must match the value used in NCCLWeightTransferUpdateInfo."""
|
||||
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
|
||||
"""Number of buffers for double/triple buffering during packed transfer.
|
||||
Must match the value used in NCCLWeightTransferUpdateInfo."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
"""Update info for NCCL weight transfer backend."""
|
||||
@@ -47,7 +73,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
When True, multiple tensors are batched together before broadcasting
|
||||
to reduce NCCL communication overhead."""
|
||||
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
|
||||
"""Size in bytes for each packed tensor buffer. Default is 1GB.
|
||||
"""Size in bytes for each packed tensor buffer.
|
||||
Both producer and consumer must use the same value."""
|
||||
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
|
||||
"""Number of buffers for double/triple buffering during packed transfer.
|
||||
@@ -186,47 +212,38 @@ class NCCLWeightTransferEngine(
|
||||
@staticmethod
|
||||
def trainer_send_weights(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
group: Any,
|
||||
src: int = 0,
|
||||
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
|
||||
| None = None,
|
||||
packed: bool = False,
|
||||
stream: torch.cuda.Stream | None = None,
|
||||
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
|
||||
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
|
||||
trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs,
|
||||
) -> None:
|
||||
"""Broadcast weights from trainer to vLLM workers.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns (name, tensor) tuples
|
||||
group: Process group (PyNcclCommunicator)
|
||||
src: Source rank (default 0, trainer is typically rank 0)
|
||||
post_iter_func: Optional function to apply to each (name, tensor) pair
|
||||
before broadcasting. If None, extracts just the tensor.
|
||||
packed: Whether to use packed tensor broadcasting for efficiency.
|
||||
When True, multiple tensors are batched together before
|
||||
broadcasting to reduce NCCL communication overhead.
|
||||
stream: CUDA stream to use for broadcasting if packed is False.
|
||||
If packed is True, new streams will be created for each buffer.
|
||||
packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
|
||||
Must match the value used in NCCLWeightTransferUpdateInfo.
|
||||
packed_num_buffers: Number of buffers for double/triple buffering.
|
||||
Must match the value used in NCCLWeightTransferUpdateInfo.
|
||||
trainer_args: Dictionary or NCCLTrainerSendWeightsArgs instance containing
|
||||
NCCL-specific arguments. If a dict, should contain keys from
|
||||
NCCLTrainerSendWeightsArgs.
|
||||
|
||||
Example:
|
||||
>>> from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
... NCCLWeightTransferEngine,
|
||||
... NCCLTrainerSendWeightsArgs,
|
||||
... )
|
||||
>>> param_iter = ((n, p) for n, p in model.named_parameters())
|
||||
>>> NCCLWeightTransferEngine.trainer_send_weights(
|
||||
... param_iter, group, packed=True
|
||||
... )
|
||||
>>> args = NCCLTrainerSendWeightsArgs(group=group, packed=True)
|
||||
>>> NCCLWeightTransferEngine.trainer_send_weights(param_iter, args)
|
||||
"""
|
||||
if post_iter_func is None:
|
||||
# Parse trainer args - accept either dict or dataclass instance
|
||||
if isinstance(trainer_args, dict):
|
||||
args = NCCLTrainerSendWeightsArgs(**trainer_args)
|
||||
else:
|
||||
args = trainer_args
|
||||
|
||||
if args.post_iter_func is None:
|
||||
# Default: extract just the tensor from (name, tensor) tuple
|
||||
post_iter_func = lambda x: x[1]
|
||||
else:
|
||||
post_iter_func = args.post_iter_func
|
||||
|
||||
if packed:
|
||||
if args.packed:
|
||||
# Use packed tensor broadcasting for efficiency
|
||||
from vllm.distributed.weight_transfer.packed_tensor import (
|
||||
packed_broadcast_producer,
|
||||
@@ -234,18 +251,20 @@ class NCCLWeightTransferEngine(
|
||||
|
||||
packed_broadcast_producer(
|
||||
iterator=iterator,
|
||||
group=group,
|
||||
src=src,
|
||||
group=args.group,
|
||||
src=args.src,
|
||||
post_iter_func=post_iter_func,
|
||||
buffer_size_bytes=packed_buffer_size_bytes,
|
||||
num_buffers=packed_num_buffers,
|
||||
buffer_size_bytes=args.packed_buffer_size_bytes,
|
||||
num_buffers=args.packed_num_buffers,
|
||||
)
|
||||
else:
|
||||
# Use simple one-by-one broadcasting
|
||||
for item in iterator:
|
||||
tensor = post_iter_func(item)
|
||||
group.broadcast(
|
||||
tensor, src=src, stream=stream or torch.cuda.current_stream()
|
||||
args.group.broadcast(
|
||||
tensor,
|
||||
src=args.src,
|
||||
stream=args.stream or torch.cuda.current_stream(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user