Sync from v0.13
This commit is contained in:
508
benchmarks/kernels/benchmark_device_communicators.py
Normal file
508
benchmarks/kernels/benchmark_device_communicators.py
Normal file
@@ -0,0 +1,508 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Benchmark script for device communicators:
|
||||
CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
|
||||
and SymmMemCommunicator (multimem, two-shot).
|
||||
|
||||
for NCCL symmetric memory you need to set the environment variables
|
||||
NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1, otherwise NCCL does
|
||||
not use fast NVLS implementation for all reduce.
|
||||
|
||||
Usage:
|
||||
torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]
|
||||
|
||||
Example:
|
||||
torchrun --nproc_per_node=2 benchmark_device_communicators.py
|
||||
--sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
|
||||
from vllm.distributed.device_communicators.pynccl import (
|
||||
PyNcclCommunicator,
|
||||
register_nccl_symmetric_ops,
|
||||
)
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id,
|
||||
)
|
||||
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Default sequence lengths to benchmark
|
||||
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
|
||||
|
||||
# Fixed hidden size and dtype for all benchmarks
|
||||
HIDDEN_SIZE = 8192
|
||||
BENCHMARK_DTYPE = torch.bfloat16
|
||||
|
||||
# CUDA graph settings
|
||||
CUDA_GRAPH_CAPTURE_CYCLES = 10
|
||||
|
||||
|
||||
class CommunicatorBenchmark:
|
||||
"""Benchmark class for testing device communicators."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
cpu_group: ProcessGroup,
|
||||
sequence_lengths: list[int],
|
||||
):
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.device = device
|
||||
self.cpu_group = cpu_group
|
||||
|
||||
# Calculate max_size_override based on largest sequence length
|
||||
max_seq_len = max(sequence_lengths)
|
||||
max_tensor_elements = max_seq_len * HIDDEN_SIZE
|
||||
self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1
|
||||
|
||||
# Initialize communicators
|
||||
self.custom_allreduce = None
|
||||
self.pynccl_comm = None
|
||||
self.symm_mem_comm = None
|
||||
self.symm_mem_comm_multimem = None
|
||||
self.symm_mem_comm_two_shot = None
|
||||
|
||||
self._init_communicators()
|
||||
|
||||
def _init_communicators(self):
|
||||
"""Initialize all available communicators."""
|
||||
try:
|
||||
self.custom_allreduce = CustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
max_size=self.max_size_override,
|
||||
)
|
||||
if not self.custom_allreduce.disabled:
|
||||
logger.info("Rank %s: CustomAllreduce initialized", self.rank)
|
||||
else:
|
||||
logger.info("Rank %s: CustomAllreduce disabled", self.rank)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e
|
||||
)
|
||||
self.custom_allreduce = None
|
||||
|
||||
try:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group, device=self.device
|
||||
)
|
||||
if not self.pynccl_comm.disabled:
|
||||
logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
|
||||
register_nccl_symmetric_ops(self.pynccl_comm)
|
||||
else:
|
||||
logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
|
||||
self.pynccl_comm = None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e
|
||||
)
|
||||
self.pynccl_comm = None
|
||||
|
||||
# Initialize variants for SymmMemCommunicator
|
||||
try:
|
||||
self.symm_mem_comm_multimem = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
force_multimem=True,
|
||||
max_size_override=self.max_size_override,
|
||||
)
|
||||
if not self.symm_mem_comm_multimem.disabled:
|
||||
logger.info(
|
||||
"Rank %s: SymmMemCommunicator (multimem) initialized", self.rank
|
||||
)
|
||||
else:
|
||||
self.symm_mem_comm_multimem = None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s",
|
||||
self.rank,
|
||||
e,
|
||||
)
|
||||
self.symm_mem_comm_multimem = None
|
||||
|
||||
try:
|
||||
self.symm_mem_comm_two_shot = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
force_multimem=False,
|
||||
max_size_override=self.max_size_override,
|
||||
)
|
||||
if not self.symm_mem_comm_two_shot.disabled:
|
||||
logger.info(
|
||||
"Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank
|
||||
)
|
||||
else:
|
||||
self.symm_mem_comm_two_shot = None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s",
|
||||
self.rank,
|
||||
e,
|
||||
)
|
||||
self.symm_mem_comm_two_shot = None
|
||||
|
||||
def benchmark_allreduce(
|
||||
self, sequence_length: int, num_warmup: int, num_trials: int
|
||||
) -> dict[str, float]:
|
||||
"""Benchmark allreduce operations for all available communicators."""
|
||||
|
||||
results = {}
|
||||
|
||||
# Define communicators with their benchmark functions
|
||||
communicators = []
|
||||
|
||||
if self.custom_allreduce is not None:
|
||||
comm = self.custom_allreduce
|
||||
# CustomAllreduce one-shot
|
||||
communicators.append(
|
||||
(
|
||||
"ca_1stage",
|
||||
lambda t, c=comm: c.custom_all_reduce(t),
|
||||
lambda t, c=comm: c.should_custom_ar(t),
|
||||
comm.capture(),
|
||||
"1stage", # env variable value
|
||||
)
|
||||
)
|
||||
# CustomAllreduce two-shot
|
||||
communicators.append(
|
||||
(
|
||||
"ca_2stage",
|
||||
lambda t, c=comm: c.custom_all_reduce(t),
|
||||
lambda t, c=comm: c.should_custom_ar(t),
|
||||
comm.capture(),
|
||||
"2stage", # env variable value
|
||||
)
|
||||
)
|
||||
|
||||
if self.pynccl_comm is not None:
|
||||
comm = self.pynccl_comm
|
||||
communicators.append(
|
||||
(
|
||||
"pynccl",
|
||||
lambda t, c=comm: c.all_reduce(t),
|
||||
lambda t: True, # Always available if initialized
|
||||
nullcontext(),
|
||||
None, # no env variable needed
|
||||
)
|
||||
)
|
||||
communicators.append(
|
||||
(
|
||||
"pynccl-symm",
|
||||
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
|
||||
lambda t: True, # Always available if initialized
|
||||
nullcontext(),
|
||||
None, # no env variable needed
|
||||
)
|
||||
)
|
||||
|
||||
if self.symm_mem_comm_multimem is not None:
|
||||
comm = self.symm_mem_comm_multimem
|
||||
communicators.append(
|
||||
(
|
||||
"symm_mem_multimem",
|
||||
lambda t, c=comm: c.all_reduce(t),
|
||||
lambda t, c=comm: c.should_use_symm_mem(t),
|
||||
nullcontext(),
|
||||
None, # no env variable needed
|
||||
)
|
||||
)
|
||||
|
||||
if self.symm_mem_comm_two_shot is not None:
|
||||
comm = self.symm_mem_comm_two_shot
|
||||
communicators.append(
|
||||
(
|
||||
"symm_mem_two_shot",
|
||||
lambda t, c=comm: c.all_reduce(t),
|
||||
lambda t, c=comm: c.should_use_symm_mem(t),
|
||||
nullcontext(),
|
||||
None, # no env variable needed
|
||||
)
|
||||
)
|
||||
|
||||
# Benchmark each communicator
|
||||
for name, allreduce_fn, should_use_fn, context, env_var in communicators:
|
||||
# Set environment variable if needed
|
||||
if env_var is not None:
|
||||
os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
|
||||
else:
|
||||
# Clear the environment variable to avoid interference
|
||||
os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)
|
||||
|
||||
latency = self.benchmark_allreduce_single(
|
||||
sequence_length,
|
||||
allreduce_fn,
|
||||
should_use_fn,
|
||||
context,
|
||||
num_warmup,
|
||||
num_trials,
|
||||
)
|
||||
if latency is not None:
|
||||
results[name] = latency
|
||||
|
||||
return results
|
||||
|
||||
def benchmark_allreduce_single(
|
||||
self,
|
||||
sequence_length: int,
|
||||
allreduce_fn: Callable[[torch.Tensor], torch.Tensor | None],
|
||||
should_use_fn: Callable[[torch.Tensor], bool],
|
||||
context,
|
||||
num_warmup: int,
|
||||
num_trials: int,
|
||||
) -> float | None:
|
||||
"""Benchmark method with CUDA graph optimization."""
|
||||
try:
|
||||
# Create test tensor (2D: sequence_length x hidden_size)
|
||||
tensor = torch.randn(
|
||||
sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device
|
||||
)
|
||||
if not should_use_fn(tensor):
|
||||
return None
|
||||
|
||||
torch.cuda.synchronize()
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
graph_input = tensor.clone()
|
||||
|
||||
# Warmup before capture
|
||||
for _ in range(3):
|
||||
allreduce_fn(graph_input)
|
||||
|
||||
# Capture the graph using context manager
|
||||
with context:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
graph_pool = torch.cuda.graph_pool_handle()
|
||||
set_graph_pool_id(graph_pool)
|
||||
with torch.cuda.graph(graph, pool=graph_pool):
|
||||
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
|
||||
allreduce_fn(graph_input)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
for _ in range(num_warmup):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
for _ in range(num_trials):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
|
||||
# Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES
|
||||
return (
|
||||
(end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("CUDA graph benchmark failed: %s", e)
|
||||
raise RuntimeError(
|
||||
f"CUDA graph benchmark failed for communicator: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def _calculate_speedup_info(comm_results: dict[str, float]) -> str:
|
||||
"""Calculate speedup information for a single tensor size."""
|
||||
if not comm_results:
|
||||
return "N/A"
|
||||
|
||||
# Find the fastest communicator
|
||||
fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k])
|
||||
fastest_time = comm_results[fastest_comm]
|
||||
|
||||
# Calculate speedup vs PyNccl if available
|
||||
if "pynccl" in comm_results:
|
||||
pynccl_time = comm_results["pynccl"]
|
||||
speedup = pynccl_time / fastest_time
|
||||
return f"{fastest_comm} ({speedup:.2f}x)"
|
||||
else:
|
||||
return f"{fastest_comm} (N/A)"
|
||||
|
||||
|
||||
def print_results(
|
||||
results: dict[str, dict[str, float]], sequence_lengths: list[int], world_size: int
|
||||
):
|
||||
"""Print benchmark results in a formatted table."""
|
||||
|
||||
print(f"\n{'=' * 130}")
|
||||
print("Device Communicator Benchmark Results")
|
||||
print(
|
||||
f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, "
|
||||
f"Hidden Size: {HIDDEN_SIZE}"
|
||||
)
|
||||
print(f"{'=' * 130}")
|
||||
|
||||
# Get all communicator names
|
||||
all_comms = set()
|
||||
for size_results in results.values():
|
||||
all_comms.update(size_results.keys())
|
||||
|
||||
all_comms = sorted(list(all_comms))
|
||||
|
||||
# Print header
|
||||
header = f"{'Tensor Shape':<20}{'Tensor Size':<15}"
|
||||
for comm in all_comms:
|
||||
header += f"{comm:<20}"
|
||||
header += f"{'Best (Speedup vs PyNccl)':<30}"
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
|
||||
# Print results for each sequence length
|
||||
for seq_len in sequence_lengths:
|
||||
if seq_len in results:
|
||||
# Calculate tensor size in elements and bytes
|
||||
tensor_elements = seq_len * HIDDEN_SIZE
|
||||
tensor_bytes = tensor_elements * BENCHMARK_DTYPE.itemsize
|
||||
|
||||
# Format tensor size (MB)
|
||||
tensor_size_mb = tensor_bytes / (1024 * 1024)
|
||||
tensor_size_str = f"{tensor_size_mb:.2f} MB"
|
||||
|
||||
# Format tensor shape
|
||||
tensor_shape = f"({seq_len}, {HIDDEN_SIZE})"
|
||||
|
||||
row = f"{tensor_shape:<20}{tensor_size_str:<15}"
|
||||
for comm in all_comms:
|
||||
if comm in results[seq_len]:
|
||||
row += f"{results[seq_len][comm]:<20.3f}"
|
||||
else:
|
||||
row += f"{'N/A':<20}"
|
||||
|
||||
# Calculate speedup information
|
||||
speedup_info = _calculate_speedup_info(results[seq_len])
|
||||
row += f"{speedup_info:<30}"
|
||||
|
||||
print(row)
|
||||
|
||||
print(f"{'=' * 130}")
|
||||
print("All times are in milliseconds (ms) per allreduce operation")
|
||||
print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)")
|
||||
|
||||
|
||||
def main():
|
||||
parser = FlexibleArgumentParser(description="Benchmark device communicators")
|
||||
|
||||
parser.add_argument(
|
||||
"--sequence-lengths",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=DEFAULT_SEQUENCE_LENGTHS,
|
||||
help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-warmup", type=int, default=5, help="Number of warmup iterations"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-trials", type=int, default=50, help="Number of benchmark trials"
|
||||
)
|
||||
|
||||
parser.add_argument("--output-json", type=str, help="Output results to JSON file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize distributed
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group(backend="gloo")
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# Set device
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Get CPU process group
|
||||
cpu_group = dist.new_group(backend="gloo")
|
||||
|
||||
# Disable USE_SYMM_MEM to avoid affecting the max_sizes
|
||||
# in symm_mem and custom_all_reduce for benchmark
|
||||
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
|
||||
|
||||
# Initialize benchmark
|
||||
benchmark = CommunicatorBenchmark(
|
||||
rank, world_size, device, cpu_group, args.sequence_lengths
|
||||
)
|
||||
|
||||
# Run benchmarks
|
||||
all_results = {}
|
||||
|
||||
for seq_len in args.sequence_lengths:
|
||||
if rank == 0:
|
||||
logger.info(
|
||||
"Benchmarking sequence length: %s (tensor shape: %s x %s)",
|
||||
seq_len,
|
||||
seq_len,
|
||||
HIDDEN_SIZE,
|
||||
)
|
||||
|
||||
results = benchmark.benchmark_allreduce(
|
||||
sequence_length=seq_len,
|
||||
num_warmup=args.num_warmup,
|
||||
num_trials=args.num_trials,
|
||||
)
|
||||
|
||||
all_results[seq_len] = results
|
||||
|
||||
# Synchronize between ranks
|
||||
dist.barrier()
|
||||
|
||||
# Print results (only rank 0)
|
||||
if rank == 0:
|
||||
print_results(all_results, args.sequence_lengths, world_size)
|
||||
|
||||
# Save to JSON if requested
|
||||
if args.output_json:
|
||||
# Add speedup information to results
|
||||
enhanced_results = {}
|
||||
for seq_len, comm_results in all_results.items():
|
||||
enhanced_results[seq_len] = {
|
||||
"timings": comm_results,
|
||||
"speedup_info": _calculate_speedup_info(comm_results),
|
||||
}
|
||||
|
||||
output_data = {
|
||||
"world_size": world_size,
|
||||
"dtype": str(BENCHMARK_DTYPE),
|
||||
"hidden_size": HIDDEN_SIZE,
|
||||
"sequence_lengths": args.sequence_lengths,
|
||||
"num_warmup": args.num_warmup,
|
||||
"num_trials": args.num_trials,
|
||||
"cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES,
|
||||
"results": enhanced_results,
|
||||
}
|
||||
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(output_data, f, indent=2)
|
||||
|
||||
logger.info("Results saved to %s", args.output_json)
|
||||
|
||||
# Cleanup
|
||||
if cpu_group != dist.group.WORLD:
|
||||
dist.destroy_process_group(cpu_group)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user