diff --git a/benchmark/fbgemm/README.md b/benchmark/kernels/fbgemm/README.md similarity index 100% rename from benchmark/fbgemm/README.md rename to benchmark/kernels/fbgemm/README.md diff --git a/benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py b/benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py similarity index 100% rename from benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py rename to benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py diff --git a/benchmark/kernels/flashinfer_allreduce_fusion/README.md b/benchmark/kernels/flashinfer_allreduce_fusion/README.md new file mode 100644 index 000000000..e651604c7 --- /dev/null +++ b/benchmark/kernels/flashinfer_allreduce_fusion/README.md @@ -0,0 +1,102 @@ +# FlashInfer Fused AllReduce + RMSNorm Benchmark + +This benchmark script is modified from the [original implementation](https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py) by the vLLM community. It aims to compare the performance differences between FlashInfer fused operators in SGLang (trtllm_allreduce_fusion: AllReduce + Residual Add + RMSNorm + optional quantization) and conventional implementations (standard `tensor_model_parallel_all_reduce` + separate RMSNorm/quantization). Specifically, this script tests the timing performance of two implementation paths: 1) Standard AllReduce and RMSNorm executed separately; 2) FlashInfer's fused operator combining AllReduce, Residual Add, RMSNorm, and optional quantization operations. + +This benchmark script helps us tune the ipc workspace size of the `flashinfer_allreduce_residual_rmsnorm` operator in SGLang and prepare for applications with FP8/FP4 quantized fused operators. + +Script path: `benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py` + +## Feature Overview + +- Compare average execution time (ms) and calculate speedup ratios for the following paths: + - standard_allreduce_rmsnorm (Standard AllReduce + RMSNorm) + - flashinfer_fused_allreduce_rmsnorm (Fused AllReduce + RMSNorm), including oneshot and twoshot modes + - Optionally compare FP8/FP4 quantized fused paths with standard paths +- Use CUDA Graph capture and batch replay to reduce measurement noise +- Automatically select the faster "standard baseline" (native/compiled version) as the denominator for speedup calculation +- Optionally export results in Markdown format + +## Runtime Environment and Prerequisites + +- At least 2 GPUs, and launch multi-process distributed training using `torchrun` (NCCL backend) +- Properly install/compile sglang along with sgl-kernel and custom operators + +## Quick Start (Command Examples) + +The following examples use world_size=2. You can modify `--nproc_per_node` and parameters according to your machine: + +- Regular paths only (no quantization): +``` +torchrun --nproc_per_node=2 \ +benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \ +--no-quant --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100 +``` + +- FP8 quantization paths only: +``` +torchrun --nproc_per_node=2 \ +benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \ +--quant-fp8 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100 +``` + +- FP4 quantization paths only: +``` +torchrun --nproc_per_node=2 \ +benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \ +--quant-fp4 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100 +``` + +- Larger hidden dimensions: +``` +torchrun --nproc_per_node=2 \ +benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \ +--no-quant --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100 +``` + +## Parameter Description +- `--seq-lens`: List of sequence lengths to test (default: 128 512 1024 2048) +- `--hidden-dim`: Hidden dimension (default: 8192) +- `--dtypes`: Data type list, `float16|bfloat16|float32` (default: bfloat16) +- `--no-residual`: Only test "no residual" scenarios (default tests both "with/without residual") +- Mutually exclusive quantization options: + - `--no-quant`: No quantization testing + - `--quant-fp8`: Only FP8 quantization testing + - `--quant-fp4`: Only FP4 quantization testing + - `--quant-all`: Test all (default) +- FlashInfer related: + - `--disable-oneshot`: Disable oneshot mode (default enables oneshot and tests twoshot simultaneously) +- Runtime configuration: + - `--warmup`: Warmup count before graph capture and before graph replay (default 5) + - `--trials`: Benchmark iteration count (default 20; internally each `graph.replay()` will batch replay multiple times) + - `--output-file`: Save results as Markdown file (only rank0 takes effect) + +## Output Example + +Each configuration group prints a table showing average execution time and relative speedup ratios (baseline is the faster standard implementation). For example: +``` +================================================================================ +Results: seq_len=1024, hidden_dim=1024 +dtype=torch.bfloat16, residual=yes, quant_mode=none +================================================================================ +Operation Time (ms) Speedup +-------------------------------------------------------------------------------- +standard_allreduce_rmsnorm 0.024 0.98x +standard_allreduce_rmsnorm_native_compiled 0.023 baseline +flashinfer_fused_allreduce_rmsnorm_oneshot 0.011 2.19x +flashinfer_fused_allreduce_rmsnorm_twoshot 0.041 0.57x +``` + +If `--output-file` is specified, all configurations will be summarized in Markdown tables in that file. + +## Important Notes and Recommendations + +- Distributed: The script uses `torchrun` environment variables to initialize distributed training and binds tensors/communication groups to the current rank's corresponding device. +- World size: Requires `WORLD_SIZE > 1` to perform communication operator benchmarks. Otherwise, the script will error and prompt. +- FlashInfer: + - If not installed or interfaces are missing, the script will only run standard paths and provide prompts in the logs. + - The fused operator internally uses "oneshot"/"twoshot" two trigger methods; oneshot is enabled by default and twoshot is tested simultaneously. +- FP8/FP4: + - FP8 uses sglang's FP8 tools and dtype, with underlying platform selection of `e4m3`/`e4m3fnuz` etc. + - FP4 uses sgl-kernel's `scaled_fp4_quant`, requiring corresponding platform support. +- CUDA Graph: + - Uses sglang's `graph_capture()` to prepare capture-ready state for communication, then uses `torch.cuda.graph` to capture kernels, reducing measurement jitter. diff --git a/benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py b/benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py new file mode 100644 index 000000000..4aebf62b9 --- /dev/null +++ b/benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py @@ -0,0 +1,1304 @@ +# Modified from https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py + +""" +Benchmark for FlashInfer fused collective operations vs standard operations. + +This benchmark compares: +1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) +2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations + +Usage with torchrun: + torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100 + torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100 + torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100 + + torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100 + torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100 + torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100 +""" + +import argparse +import contextlib +import itertools +import logging +import os +import time +from typing import Optional + +import torch # type: ignore +import torch.distributed as dist # type: ignore + +from sglang.srt.distributed import get_tp_group, tensor_model_parallel_all_reduce +from sglang.srt.distributed.parallel_state import ( + cleanup_dist_env_and_memory, + graph_capture, + init_distributed_environment, + initialize_model_parallel, +) +from sglang.srt.layers.layernorm import RMSNorm # noqa +from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype as SGLANG_FP8_DTYPE +from sglang.srt.layers.quantization.fp8_kernel import static_quant_fp8 + +try: + from sgl_kernel import fused_add_rmsnorm as SGL_FUSED_ADD_RMS_NORM + from sgl_kernel import rmsnorm as SGL_RMS_NORM + from sgl_kernel import scaled_fp4_quant as SGL_SCALED_FP4_QUANT +except Exception: # pragma: no cover - fallback on non-supported platforms + SGL_FUSED_ADD_RMS_NORM = None + SGL_RMS_NORM = None + SGL_SCALED_FP4_QUANT = None + +FP8_DTYPE = SGLANG_FP8_DTYPE + +logger = logging.getLogger(__name__) + +# Try to import FlashInfer +try: + import flashinfer.comm as flashinfer_comm # type: ignore + + if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): + flashinfer_comm = None + logger.warning( + "FlashInfer comm module found but missing trtllm_allreduce_fusion" + ) +except ImportError: + flashinfer_comm = None + logger.warning("FlashInfer not found, only benchmarking standard operations") + +# Constants +MiB = 1024 * 1024 + +# FlashInfer max sizes per world size +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes +# use --disable-oneshot to disable oneshot mode for very large input sizes +_FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: 64 * MiB, # 64MB + 8: 64 * MiB, # 64MB +} + +# Global workspace tensor for FlashInfer +_FI_WORKSPACE_TENSOR = None + + +def setup_flashinfer_workspace( + world_size: int, + rank: int, + hidden_dim: int, + max_token_num: int, + use_fp32_lamport: bool = False, +): + """Setup FlashInfer workspace for fused allreduce operations.""" + global _FI_WORKSPACE_TENSOR + + if flashinfer_comm is None: + return None, None + + if world_size not in _FI_MAX_SIZES: + logger.warning("FlashInfer not supported for world size %s", world_size) + return None, None + + try: + # Create IPC workspace + ipc_handles, workspace_tensor = ( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=rank, + tp_size=world_size, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + group=get_tp_group().device_group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + _FI_WORKSPACE_TENSOR = workspace_tensor + return ipc_handles, workspace_tensor + except Exception as e: + logger.error("Failed to setup FlashInfer workspace: %s", e) + return None, None + + +def cleanup_flashinfer_workspace(ipc_handles): + """Cleanup FlashInfer workspace.""" + if flashinfer_comm is None or ipc_handles is None: + return + + try: + group = get_tp_group().device_group + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group) + except Exception as e: + logger.error("Failed to cleanup FlashInfer workspace: %s", e) + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def __init__( + self, + rank: int, + world_size: int, + use_fp32_lamport: bool = False, + max_token_num: int = 1024, + ): + self.rank = rank + self.world_size = world_size + self.use_fp32_lamport = use_fp32_lamport + self.trigger_completion_at_end = True + self.launch_with_pdl = True + self.fp32_acc = True + self.max_token_num = max_token_num + + def get_trtllm_fused_allreduce_kwargs(self): + return { + "world_rank": self.rank, + "world_size": self.world_size, + "launch_with_pdl": self.launch_with_pdl, + "trigger_completion_at_end": self.trigger_completion_at_end, + "fp32_acc": self.fp32_acc, + } + + +def flashinfer_fused_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + allreduce_params: "FlashInferFusedAllReduceParams", + use_oneshot: bool, + norm_out: Optional[torch.Tensor] = None, +): + """FlashInfer fused allreduce + rmsnorm operation.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + allreduce_out=None, + quant_out=None, + scale_out=None, + layout_code=None, + scale_factor=None, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + use_oneshot: bool = True, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=None, + layout_code=None, + scale_factor=scale_factor, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + quant_out: torch.Tensor, + use_oneshot: bool, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=output_scale, + layout_code=None, + scale_factor=input_global_scale, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def standard_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm operations.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + # Then RMS norm + if residual is not None: + # Fused add + RMS norm (in-place on allreduce_out) + if SGL_FUSED_ADD_RMS_NORM is not None: + SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps) + else: + rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) + rms.weight.data = rms_gamma + rms.forward_native(allreduce_out, residual) + else: + # Just RMS norm + if SGL_RMS_NORM is not None: + _ = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps) + else: + rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) + rms.weight.data = rms_gamma + _ = rms.forward_native(allreduce_out) + + +def standard_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP8 quantization.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Then RMS norm + static FP8 quantization + if residual is not None: + if SGL_FUSED_ADD_RMS_NORM is not None: + SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps) + quant_out, _ = static_quant_fp8( + allreduce_out, scale_factor, repeat_scale=False + ) + else: + rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) + rms.weight.data = rms_gamma + normed, _ = rms.forward_native(allreduce_out, residual) + quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False) + return quant_out, residual + else: + if SGL_RMS_NORM is not None: + normed = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps) + else: + rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) + rms.weight.data = rms_gamma + normed = rms.forward_native(allreduce_out) + quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False) + return quant_out + + +def standard_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP4 quantization.""" + + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Then RMS norm + if residual is not None: + if SGL_FUSED_ADD_RMS_NORM is not None: + SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps) + quant_input = allreduce_out + else: + rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) + rms.weight.data = rms_gamma + quant_input, _ = rms.forward_native(allreduce_out, residual) + residual_out = residual + else: + if SGL_RMS_NORM is not None: + quant_input = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps) + else: + rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) + rms.weight.data = rms_gamma + quant_input = rms.forward_native(allreduce_out) + residual_out = allreduce_out + + # Finally FP4 quantization + if SGL_SCALED_FP4_QUANT is None: + raise RuntimeError("scaled_fp4_quant is not available on this platform") + quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale) + if residual is not None: + return quant_res, residual_out, output_scale_res + else: + return quant_res, quant_input + + +def standard_allreduce_rmsnorm_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm operations using native RMSNorm forward.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + # Apply native RMSNorm + if residual is not None: + result = rmsnorm_layer.forward_native(allreduce_out, residual) + return result # Returns (norm_out, residual_out) + else: + result = rmsnorm_layer.forward_native(allreduce_out) + return result # Returns norm_out + + +def standard_allreduce_rmsnorm_fp8_quant_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP8 quantization using native implementations.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Apply native RMSNorm + if residual is not None: + norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) + else: + norm_out = rmsnorm_layer.forward_native(allreduce_out) + residual_out = allreduce_out + + # Apply native FP8 quantization + quant_out, _ = static_quant_fp8(norm_out, scale_factor, repeat_scale=False) + + if residual is not None: + return quant_out, residual_out + else: + return quant_out + + +def standard_allreduce_rmsnorm_fp4_quant_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Apply native RMSNorm + if residual is not None: + norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) + quant_input = norm_out + else: + norm_out = rmsnorm_layer.forward_native(allreduce_out) + quant_input = norm_out + residual_out = allreduce_out + + # Apply FP4 quantization (still using fused CUDA op as there's no native FP4) + if SGL_SCALED_FP4_QUANT is None: + raise RuntimeError("scaled_fp4_quant is not available on this platform") + quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale) + + if residual is not None: + return quant_res, residual_out, output_scale_res + else: + return quant_res, norm_out + + +# Compiled versions of native functions +@torch.compile +def standard_allreduce_rmsnorm_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + norm_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm.""" + return standard_allreduce_rmsnorm_native( + input_tensor, residual, rmsnorm_layer, norm_out + ) + + +@torch.compile +def standard_allreduce_rmsnorm_fp8_quant_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm + FP8 quantization.""" + return standard_allreduce_rmsnorm_fp8_quant_native( + input_tensor, + residual, + rmsnorm_layer, + scale_factor, + norm_out, + quant_out, + ) + + +@torch.compile +def standard_allreduce_rmsnorm_fp4_quant_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm + FP4 quantization.""" + return standard_allreduce_rmsnorm_fp4_quant_native( + input_tensor, + residual, + rmsnorm_layer, + input_global_scale, + quant_out, + output_scale, + norm_out, + ) + + +def create_test_tensors( + seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True +): + """Create test tensors for benchmarking.""" + input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype) + residual = ( + torch.randn_like(input_tensor) + if use_residual + else torch.zeros_like(input_tensor) + ) + rms_gamma = torch.ones(hidden_dim, dtype=dtype) + norm_out = None if use_residual else torch.empty_like(input_tensor) + + # Quantization scales + scale_fp8 = torch.tensor(1.0, dtype=torch.float32) + scale_fp4 = torch.tensor(1.0, dtype=torch.float32) + quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) + fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8) + fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) + + return ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) + + +def benchmark_operation( + operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs +): + """Benchmark a single operation using CUDA graphs.""" + # Warmup before graph capture + for _ in range(warmup): + operation_func(*args, **kwargs) + torch.cuda.synchronize() + + # Create CUDA graph + graph = torch.cuda.CUDAGraph() + num_op_per_cudagraph = 10 + + # Use sglang's graph_capture to make tensor_model_parallel_all_reduce graph-safe + with graph_capture() as graph_capture_context: + with torch.cuda.graph(graph, stream=graph_capture_context.stream): + for _ in range(num_op_per_cudagraph): + operation_func(*args, **kwargs) + + # Graph warmup + torch.cuda.synchronize() + for _ in range(warmup): + graph.replay() + + # Benchmark with CUDA graph + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(trials // num_op_per_cudagraph): + # operation_func(*args, **kwargs) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter() + + avg_time_ms = ((end_time - start_time) / trials) * 1000 + return avg_time_ms + + +def run_benchmarks( + seq_len: int, + hidden_dim: int, + dtype: torch.dtype, + use_residual: bool, + allreduce_params: Optional[FlashInferFusedAllReduceParams], + quant_mode: str = "all", + disable_oneshot: bool = False, +): + """Run all benchmarks for given configuration. + + Args: + quant_mode: "none", "fp8_only", "fp4_only", or "all" + """ + ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual) + + rms_eps = 1e-6 + results = {} + + # Create RMSNorm once for native benchmarks + rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps) + rmsnorm_layer.weight.data = rms_gamma + + if quant_mode in ["all", "none"]: + # Standard AllReduce + RMSNorm + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + ) + results["standard_allreduce_rmsnorm"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm failed: %s", e) + results["standard_allreduce_rmsnorm"] = float("inf") + + # Standard AllReduce + RMSNorm Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + norm_out=norm_out, + ) + results["standard_allreduce_rmsnorm_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms + except Exception as e: + logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e) + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Two-shot + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = time_ms + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s", e + ) + results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = float("inf") + + if quant_mode in ["all", "fp8_only"]: + # Standard AllReduce + RMSNorm + FP8 Quant + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) + results["standard_allreduce_rmsnorm_fp8_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + # quant_fp8_layer removed in sglang version; static_quant_fp8 is used within the function + scale_factor=scale_fp8, + norm_out=norm_out, + quant_out=quant_out_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = float( + "inf" + ) + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = float( + "inf" + ) + + if quant_mode in ["all", "fp4_only"]: + # Standard AllReduce + RMSNorm + FP4 Quant + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + ) + results["standard_allreduce_rmsnorm_fp4_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) + results["standard_allreduce_rmsnorm_fp4_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + norm_out=norm_out, + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot + if flashinfer_comm is not None and allreduce_params is not None: + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( + "inf" + ) + + return results + + +def prepare_results_with_speedups(results_dict): + """Prepare results with speedup calculations based on dynamic baseline selection.""" + prepared_results = [] + + # Determine the fastest baseline for each operation type + def get_fastest_baseline(op_name, results_dict): + """Get the fastest baseline between standard and native_compiled versions.""" + if "fp8_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp8_quant", + "standard_allreduce_rmsnorm_fp8_quant_native_compiled", + ] + elif "fp4_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp4_quant", + "standard_allreduce_rmsnorm_fp4_quant_native_compiled", + ] + else: + candidates = [ + "standard_allreduce_rmsnorm", + "standard_allreduce_rmsnorm_native_compiled", + ] + + # Find the fastest among available candidates + fastest_time = float("inf") + fastest_baseline = None + + for candidate in candidates: + if ( + candidate in results_dict + and results_dict[candidate] != float("inf") + and results_dict[candidate] < fastest_time + ): + fastest_time = results_dict[candidate] + fastest_baseline = candidate + + return fastest_baseline + + # Create dynamic baseline mapping + dynamic_baseline_mapping = {} + for op_name in results_dict: + if ( + op_name.startswith("flashinfer_") + or op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + dynamic_baseline_mapping[op_name] = get_fastest_baseline( + op_name, results_dict + ) + + for op_name, time_ms in results_dict.items(): + if time_ms == float("inf"): + speedup_str = "FAILED" + time_str = "FAILED" + else: + time_str = f"{time_ms:.3f}" + # Find the appropriate baseline for this operation + baseline_op = dynamic_baseline_mapping.get(op_name) + if baseline_op and baseline_op in results_dict: + baseline_time = results_dict[baseline_op] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + # For baseline operations, determine if this is the fastest baseline + if op_name.endswith("_native_compiled") or ( + op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + fastest_baseline = get_fastest_baseline(op_name, results_dict) + if fastest_baseline == op_name: + speedup_str = "baseline" + else: + if fastest_baseline and fastest_baseline in results_dict: + baseline_time = results_dict[fastest_baseline] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + + prepared_results.append( + { + "operation": op_name, + "time_ms": time_ms, + "time_str": time_str, + "speedup_str": speedup_str, + } + ) + + return prepared_results + + +def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode): + """Print benchmark results in a formatted table.""" + print(f"\n{'=' * 80}") + print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}") + print( + f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " + f"quant_mode={quant_mode}" + ) + print(f"{'=' * 80}") + print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") + print(f"{'-' * 80}") + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + if result["time_ms"] == float("inf"): + time_display = result["time_str"] + else: + time_display = f"{result['time_ms']:.3f}" + + print( + f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}" + ) + + +def format_results_markdown( + all_results: list[dict], world_size: int, args: argparse.Namespace +) -> str: + """Format all benchmark results as markdown.""" + markdown = f"""# FlashInfer Fused Collective Operations Benchmark Results + +**World Size:** {world_size} +**Hidden Dimension:** {args.hidden_dim} +**Warmup Iterations:** {args.warmup} +**Benchmark Trials:** {args.trials} +**Quantization Mode:** {all_results[0]["quant_mode"] if all_results else "N/A"} + +--- + +""" + + for result in all_results: + seq_len = result["seq_len"] + dtype = result["dtype"] + use_residual = result["use_residual"] + results_dict = result["results"] + + residual_str = "with residual" if use_residual else "no residual" + + markdown += f""" +## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str} + +| Operation | Time (ms) | Speedup | +|-----------|-----------|---------| +""" + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + # Format operation name for better readability + formatted_op_name = result["operation"].replace("_", " ").title() + markdown += f"| {formatted_op_name} | {result['time_str']} |" + markdown += f"{result['speedup_str']} |\n" + + markdown += "\n" + + return markdown + + +def save_results_to_file( + all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int +): + """Save benchmark results to markdown file (only on rank 0).""" + if rank != 0: + return + + if not all_results: + logger.warning("No results to save") + return + + output_path = args.output_file + + try: + markdown_content = format_results_markdown(all_results, world_size, args) + + with open(output_path, "w") as f: + f.write(markdown_content) + + except Exception as e: + logger.error("Failed to save results to file: %s", e) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark fused collective operations" + ) + parser.add_argument( + "--seq-lens", + type=int, + nargs="+", + default=[128, 512, 1024, 2048], + help="Sequence lengths to test", + ) + parser.add_argument( + "--hidden-dim", type=int, default=8192, help="Hidden dimension size" + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["bfloat16"], + choices=["float16", "bfloat16", "float32"], + help="Data types to test", + ) + parser.add_argument( + "--no-residual", + action="store_true", + help="Skip residual connection tests", + ) + + # Quantization mode options (mutually exclusive with --no-quant) + quant_group = parser.add_mutually_exclusive_group() + quant_group.add_argument( + "--no-quant", action="store_true", help="Skip all quantization tests" + ) + quant_group.add_argument( + "--quant-fp8", action="store_true", help="Only run FP8 quantization tests" + ) + quant_group.add_argument( + "--quant-fp4", action="store_true", help="Only run FP4 quantization tests" + ) + quant_group.add_argument( + "--quant-all", + action="store_true", + help="Run all quantization tests (default)", + ) + + parser.add_argument( + "--disable-oneshot", + action="store_true", + help="Disable oneshot mode for FlashInfer operations", + ) + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--trials", type=int, default=20, help="Number of benchmark trials" + ) + parser.add_argument( + "--output-file", + type=str, + help="""Output file path for markdown results + (default: benchmark_results_.md) + """, + ) + + args = parser.parse_args() + + # Check if running with torchrun (required for collective operations) + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + raise RuntimeError( + "Must run with torchrun for distributed benchmarking. " + "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py" + ) + + # Initialize distributed environment + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment( + world_size=world_size, + rank=rank, + local_rank=rank, + backend="nccl", + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Validate world size (must be > 1 for collective operations) + if world_size <= 1: + raise ValueError( + "World size must be > 1 for collective operations benchmarking. " + f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." + ) + + # Determine quantization mode + if args.no_quant: + quant_mode = "none" + elif args.quant_fp8: + quant_mode = "fp8_only" + elif args.quant_fp4: + quant_mode = "fp4_only" + else: # args.quant_all or default + quant_mode = "all" + + if rank == 0: + logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) + logger.info("Quantization mode: %s", quant_mode) + if flashinfer_comm is not None: + oneshot_status = "enabled" if not args.disable_oneshot else "disabled" + logger.info( + "FlashInfer available - will benchmark fused operations (oneshot: %s)", + oneshot_status, + ) + else: + logger.info( + "FlashInfer not available - only benchmarking standard operations" + ) + + # Convert dtype strings to torch dtypes + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + dtypes = [dtype_map[dt] for dt in args.dtypes] + + # Test configurations + residual_options = [True] if not args.no_residual else [False] + if not args.no_residual: + residual_options.append(False) + + configs = list(itertools.product(args.seq_lens, dtypes, residual_options)) + + # Setup FlashInfer workspace if available + ipc_handles = None + allreduce_params = None + + if flashinfer_comm is not None: + # Use the largest hidden dimension for workspace setup + max_num_token = _FI_MAX_SIZES.get(world_size) // ( + args.hidden_dim * world_size * 2 + ) + + ipc_handles, workspace_tensor = setup_flashinfer_workspace( + world_size, rank, args.hidden_dim, max_num_token + ) + + if workspace_tensor is not None: + allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=world_size, + max_token_num=max_num_token, + ) + + # Collect all results for markdown export + all_results = [] + + try: + # Run benchmarks + for seq_len, dtype, use_residual in configs: + if rank == 0: + logger.info( + "\nTesting: seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s", + seq_len, + args.hidden_dim, + dtype, + use_residual, + ) + + results = run_benchmarks( + seq_len, + args.hidden_dim, + dtype, + use_residual, + allreduce_params, + quant_mode=quant_mode, + disable_oneshot=args.disable_oneshot, + ) + + # Store results for markdown export + if rank == 0: + all_results.append( + { + "seq_len": seq_len, + "hidden_dim": args.hidden_dim, + "dtype": str(dtype).replace("torch.", ""), + "use_residual": use_residual, + "quant_mode": quant_mode, + "results": results, + } + ) + + print_results( + results, + seq_len, + args.hidden_dim, + dtype, + use_residual, + quant_mode, + ) + + # Save results to markdown file + if args.output_file and rank == 0: + save_results_to_file(all_results, world_size, args, rank) + + finally: + # Cleanup + if ipc_handles is not None: + cleanup_flashinfer_workspace(ipc_handles) + + with contextlib.suppress(Exception): + dist.barrier() + cleanup_dist_env_and_memory(shutdown_ray=False) + + +if __name__ == "__main__": + main()