# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Callable
from functools import cache
from typing import Any

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import is_torch_equal_or_newer

logger = init_logger(__name__)


def _matmul_launch_metadata(
    grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
) -> dict[str, Any]:
    ret = {}
    m, n, k = args["M"], args["N"], args["K"]
    ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
    if "tiles_per_update" in args:
        ret["name"] = (
            f"{kernel.name} [M={m}, N={n}, K={k}, "
            f"tiles_per_update={args['tiles_per_update']:02}]"
        )
    if "c_ptr" in args:
        bytes_per_elem = args["c_ptr"].element_size()
    else:
        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
    ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
    ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
    return ret


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n

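# Illustrative note on the grouped tile schedule above (comments only, nothing is
# executed here). `_compute_pid` maps a flat `tile_id` to a (pid_m, pid_n) pair so
# that consecutive tiles stay within a group of GROUP_SIZE_M row-blocks, which
# improves L2 reuse of the A and B tiles. A small worked example, assuming
# num_pid_m=4, num_pid_n=6 and GROUP_SIZE_M=2 (so num_pid_in_group=12):
#
#   tile_id  0 -> (pid_m=0, pid_n=0)
#   tile_id  1 -> (pid_m=1, pid_n=0)
#   tile_id  2 -> (pid_m=0, pid_n=1)
#   tile_id  3 -> (pid_m=1, pid_n=1)
#   ...
#   tile_id 12 -> (pid_m=2, pid_n=0)   # next group of GROUP_SIZE_M row-blocks
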
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,  #
    bias_ptr,
    M,
    N,
    K,  #
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    NUM_SMS: tl.constexpr,  #
    A_LARGE: tl.constexpr,
    B_LARGE: tl.constexpr,
    C_LARGE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    tile_id_c = start_pid - NUM_SMS
    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        if A_LARGE:
            offs_am = offs_am.to(tl.int64)
        if B_LARGE:
            offs_bn = offs_bn.to(tl.int64)
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(
            tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M
        )
        offs_bn = tl.max_contiguous(
            tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N
        )

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            if A_LARGE or B_LARGE:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
            else:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
            )
            b_ptrs = b_ptr + (
                offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
            )

            a = tl.load(
                a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            b = tl.load(
                b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            accumulator = tl.dot(a, b, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
            offs_cm = offs_cm.to(tl.int64)
            offs_cn = offs_cn.to(tl.int64)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        if HAS_BIAS:
            bias_ptrs = bias_ptr + offs_cn
            bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
            accumulator += bias
        c = accumulator.to(c_ptr.dtype.element_ty)
        tl.store(c_ptrs, c, mask=c_mask)


def matmul_persistent(
    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    assert bias is None or bias.dim() == 1, (
        "Currently assuming bias is 1D, let Horace know if you run into this"
    )
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=dtype)

    # 1D launch kernel where each block gets its own program.
    def grid(META):
        return (
            min(
                NUM_SMS,
                triton.cdiv(M, META["BLOCK_SIZE_M"])
                * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            ),
        )

    configs = {
        torch.bfloat16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        torch.float16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        torch.float32: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
    }
    # print(a.device, b.device, c.device)
    matmul_kernel_persistent[grid](
        a,
        b,
        c,  #
        bias,
        M,
        N,
        K,  #
        a.stride(0),
        a.stride(1),  #
        b.stride(0),
        b.stride(1),  #
        c.stride(0),
        c.stride(1),  #
        NUM_SMS=NUM_SMS,  #
        A_LARGE=a.numel() > 2**31,
        B_LARGE=b.numel() > 2**31,
        C_LARGE=c.numel() > 2**31,
        HAS_BIAS=bias is not None,
        **configs[dtype],
    )
    return c

@triton.jit
def _log_softmax_kernel(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute log_softmax along the last dimension of a 2D tensor.
    Each block handles one row of the input tensor.
    """
    # Get the row index for this block
    row_idx = tl.program_id(0).to(tl.int64)

    # Compute base pointers for input and output rows
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Find maximum value in the row for numerical stability
    max_val = -float("inf")
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))

        # Update maximum
        max_val = tl.max(tl.maximum(vals, max_val))

    # Step 2: Compute sum of exp(x - max_val)
    sum_exp = 0.0
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)

        # Compute exp(x - max_val) and accumulate
        exp_vals = tl.exp(vals - max_val)
        sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))

    # Compute log(sum_exp)
    log_sum_exp = tl.log(sum_exp)

    # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask)

        # Compute log_softmax
        output = vals - max_val - log_sum_exp

        # Store results
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)


def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Compute log_softmax using Triton kernel.

    Args:
        input: Input tensor
        dim: Dimension along which to compute log_softmax
            (only -1 or the last dim is supported)

    Returns:
        Tensor with log_softmax applied along the specified dimension
    """
    if dim != -1 and dim != input.ndim - 1:
        raise ValueError(
            "This implementation only supports log_softmax along the last dimension"
        )

    # Flatten all dimensions except the last one
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()

    n_rows, n_cols = input_2d.shape

    # Allocate output tensor
    output = torch.empty_like(input_2d)

    # Choose block size based on the number of columns
    BLOCK_SIZE = 1024

    # Launch kernel with one block per row
    grid = (n_rows,)
    _log_softmax_kernel[grid](
        input_2d,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Reshape output back to original shape
    return output.reshape(original_shape)


@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for computing mean along a single dimension.
    Input is viewed as (M, N, K) where N is the dimension being reduced.
    """
    # Program ID gives us which output element we're computing
    pid = tl.program_id(0)

    # Compute output indices
    m_idx = pid // K
    k_idx = pid % K

    # Bounds check
    if m_idx >= M or k_idx >= K:
        return

    # Accumulate sum across reduction dimension
    acc = 0.0
    for n_start in range(0, N, BLOCK_SIZE):
        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
        mask = n_offsets < N

        # Calculate input indices
        input_idx = (
            m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
        )

        # Load and accumulate
        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
        acc += tl.sum(vals)

    # Compute mean and store
    mean_val = acc / N
    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    tl.store(output_ptr + output_idx, mean_val)


def mean_dim(
    input: torch.Tensor,
    dim: int,
    keepdim: bool = False,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """
    Triton implementation of torch.mean with single dimension reduction.

    Args:
        input: Input tensor
        dim: Single dimension along which to compute mean
        keepdim: Whether to keep the reduced dimension
        dtype: Output dtype. If None, uses input dtype
            (or float32 for integer inputs)

    Returns:
        Tensor with mean values along specified dimension
    """
    # Validate inputs
    assert -input.ndim <= dim < input.ndim, (
        f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
    )

    # Handle negative dim
    if dim < 0:
        dim = dim + input.ndim

    # Handle dtype
    if dtype is None:
        if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            dtype = torch.float32
        else:
            dtype = input.dtype

    # Convert input to appropriate dtype if needed
    if input.dtype != dtype:
        input = input.to(dtype)

    # Get input shape and strides
    shape = list(input.shape)

    # Calculate dimensions for kernel
    M = 1
    for i in range(dim):
        M *= shape[i]
    N = shape[dim]
    K = 1
    for i in range(dim + 1, len(shape)):
        K *= shape[i]

    # Reshape input to 3D view (M, N, K)
    input_3d = input.reshape(M, N, K)

    # Create output shape
    if keepdim:
        output_shape = shape.copy()
        output_shape[dim] = 1
    else:
        output_shape = shape[:dim] + shape[dim + 1 :]

    # Create output tensor
    output = torch.empty(output_shape, dtype=dtype, device=input.device)

    # Reshape output for kernel
    output_2d = output.reshape(M, 1, K).squeeze(1) if keepdim else output.reshape(M, K)

    # Launch kernel
    grid = (M * K,)
    BLOCK_SIZE = 1024
    mean_kernel[grid](
        input_3d,
        output_2d,
        input_3d.stride(0),
        input_3d.stride(1),
        input_3d.stride(2),
        output_2d.stride(0),
        output_2d.stride(1) if output_2d.ndim > 1 else 0,
        M,
        N,
        K,
        BLOCK_SIZE,
    )

    return output

def mm_batch_invariant(a, b):
    return matmul_persistent(a, b)


def matmul_batch_invariant(a, b, *, out=None):
    # torch.matmul can handle various dimensions
    # For 2D x 2D, it's the same as mm
    if a.ndim == 2 and b.ndim == 2:
        result = matmul_persistent(a, b)
        if out is not None:
            out.copy_(result)
            return out
        return result
    elif a.ndim == 3 and b.ndim == 3:
        # Handle batched case like bmm
        return bmm_batch_invariant(a, b, out=out)
    elif a.ndim == 3 and b.ndim == 2:
        # Handle 3D x 2D: common for linear layers
        # (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
        # Reshape to 2D, do mm, reshape back
        batch, seq, hidden = a.shape
        a_2d = a.reshape(-1, hidden)
        result_2d = matmul_persistent(a_2d, b)
        result = result_2d.reshape(batch, seq, -1)
        if out is not None:
            out.copy_(result)
            return out
        return result
    elif a.ndim == 2 and b.ndim == 3:
        # Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
        # By broadcasting `a` to 3D, we can reuse the batched matrix
        # multiplication logic.
        a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
        return bmm_batch_invariant(a_expanded, b, out=out)
    elif a.ndim == 4 and b.ndim == 4:
        # Handle 4D attention tensors: [batch, heads, seq, dim]
        # Reshape to 3D, process, reshape back
        batch, heads, seq_a, dim_a = a.shape
        _, _, dim_b, seq_b = b.shape

        # Reshape to [batch*heads, seq_a, dim_a]
        a_3d = a.reshape(batch * heads, seq_a, dim_a)
        b_3d = b.reshape(batch * heads, dim_b, seq_b)

        # Do batched matmul
        result_3d = bmm_batch_invariant(a_3d, b_3d)

        # Reshape back to [batch, heads, seq_a, seq_b]
        result = result_3d.reshape(batch, heads, seq_a, seq_b)

        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(
            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
            f"3D x 2D, 2D x 3D, and 4D x 4D, "
            f"got shapes {a.shape} and {b.shape}"
        )


def bmm_batch_invariant(a, b, *, out=None):
    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
    # Process each batch separately with our persistent kernel
    if a.ndim == 3 and b.ndim == 3:
        results = []
        for i in range(a.shape[0]):
            results.append(matmul_persistent(a[i], b[i]))
        result = torch.stack(results, dim=0)

        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(
            f"bmm_batch_invariant expects 3D tensors, "
            f"got shapes {a.shape} and {b.shape}"
        )


def addmm_batch_invariant(bias, a, b):
    return matmul_persistent(a, b, bias=bias)


def _log_softmax_batch_invariant(input, dim, _half_to_float):
    assert not _half_to_float, "not implemented"
    return log_softmax(input, dim=dim)


def softmax_batch_invariant(input, dim, dtype=None):
    # Compute softmax in a deterministic way
    # First subtract max for numerical stability (standard practice)
    input_max = torch.amax(input, dim=dim, keepdim=True)
    input = input - input_max
    exp_x = torch.exp(input)
    sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
    return exp_x / sum_exp_x


def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
    assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
    result = input.to(torch.float32)

    if len(dim) == 0:
        dim = [i for i in range(len(input.shape))]

    # Sort dimensions to reduce from largest to smallest to handle shifting dims
    # during iterative reduction.
    sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)

    # Iteratively apply a deterministic mean.
    for d in sorted_dims:
        result = mean_dim(result, dim=d, keepdim=True)

    if not keepdim:
        # Squeeze the reduced dimensions.
        for d in sorted_dims:
            result = result.squeeze(d)

    return result


@triton.jit
def _rms_norm_kernel(
    input_ptr,
    weight_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute RMS normalization along the last dimension of a 2D tensor.
    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
    Each block handles one row of the input tensor.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Compute sum of squares in float32 to avoid overflow
    sum_sq = tl.zeros([1], dtype=tl.float32)
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        # Convert to float32 for accumulation to prevent overflow
        vals_f32 = vals.to(tl.float32)
        sq_vals = vals_f32 * vals_f32
        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))

    # Step 2: Compute RMS (root mean square) in float32
    mean_sq = sum_sq / n_cols
    rms = tl.sqrt(mean_sq + eps)
    inv_rms = 1.0 / rms

    # Step 3: Normalize and apply weight
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)

        # Compute in float32 then convert back to input dtype
        vals_f32 = vals.to(tl.float32)
        weight_f32 = weight.to(tl.float32)
        output_f32 = vals_f32 * inv_rms * weight_f32

        output = output_f32.to(vals.dtype)
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)


def rms_norm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Compute RMS normalization using Triton kernel.

    RMS Norm normalizes the input by the root mean square and scales by weight:
        output = input / sqrt(mean(input^2) + eps) * weight

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        Tensor with RMS normalization applied along the last dimension
    """
    assert weight.dim() == 1, "Weight must be 1-dimensional"
    assert input.shape[-1] == weight.shape[0], (
        f"Input last dimension ({input.shape[-1]}) must match "
        f"weight dimension ({weight.shape[0]})"
    )

    # Flatten all dimensions except the last one
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()
    weight = weight.contiguous()

    n_rows, n_cols = input_2d.shape

    output = torch.empty_like(input_2d)

    BLOCK_SIZE = 1024
    grid = (n_rows,)
    _rms_norm_kernel[grid](
        input_2d,
        weight,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output.reshape(original_shape)

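# For reference, an eager PyTorch expression of the same formula (a sketch for
# comparison and testing only, not executed here; the Triton kernel above is what
# batch-invariant mode actually uses, and its reduction order may differ from
# torch.mean in the low-order bits):
#
#     def _rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
#         x32 = x.float()
#         inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
#         return (x32 * inv_rms * w.float()).to(x.dtype)
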
def rms_norm_batch_invariant(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Batch-invariant wrapper for RMS normalization.

    This function provides a deterministic, batch-invariant implementation
    of RMS normalization for use with the batch_invariant mode.

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        RMS normalized tensor
    """
    return rms_norm(input, weight, eps=eps)


def linear_batch_invariant(input, weight, bias=None):
    output = matmul_batch_invariant(input, weight.t())
    if bias is not None:
        output = output + bias
    return output


_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None
_original_fp16_reduction_precision = None
_original_bf16_reduction_precision = None
_original_cublas_workspace_cfg = None
_original_cublaslt_workspace_size = None


def enable_batch_invariant_mode():
    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size

    if _batch_invariant_MODE:
        return

    _batch_invariant_MODE = True
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")

    # Batch invariant matmuls are no longer needed after cublas overrides
    if not is_torch_equal_or_newer("2.10.0.dev"):
        if current_platform.is_device_capability(100):
            # For PyTorch 2.9, B200 uses GEMV for bs=1
            # Requires https://github.com/pytorch/pytorch/pull/166735
            _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
            _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
            _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
            _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
        else:
            # Only source of batch invariance for Hopper is split-k,
            # can disable through cuBLAS workspace config
            _original_cublas_workspace_cfg = os.environ.get(
                "CUBLAS_WORKSPACE_CONFIG", None
            )
            _original_cublaslt_workspace_size = os.environ.get(
                "CUBLASLT_WORKSPACE_SIZE", None
            )
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
            os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"

    _batch_invariant_LIB.impl(
        "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
    )

    _batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")

    # Also monkeypatch torch.bmm directly as a fallback
    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
    _original_torch_bmm = torch.bmm
    torch.bmm = bmm_batch_invariant

    _original_bf16_reduction_precision = (
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
    )
    _original_fp16_reduction_precision = (
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
    )

    reduced_precision_val = (
        (False, False) if is_torch_equal_or_newer("2.10.0.dev") else False
    )
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
        reduced_precision_val
    )
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
        reduced_precision_val
    )

    torch.backends.cuda.preferred_blas_library(backend="cublaslt")


@cache
def vllm_is_batch_invariant():
    env_key = "VLLM_BATCH_INVARIANT"
    is_overridden = False
    val = os.getenv(env_key, "0")
    try:
        is_overridden = int(val) != 0
    except ValueError:
        is_overridden = False
    return is_overridden


def override_envs_for_invariance():
    curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
    supported_backends = [
        "FLASH_ATTN",  # best supported backend
        "FLASHINFER",
        "FLASH_ATTN_MLA",
        "FLASHINFER_MLA",
        "TRITON_MLA",
        # Not yet supported MLA backends
        # "FLASHMLA",
        # "FLEX_ATTENTION",  # IMA issue even if we disable batch invariance
    ]
    if curr_attn_backend not in supported_backends:
        warning = (
            "Forcibly updating attention backend to"
            f" {supported_backends[0]} for batch_invariant. "
            f" Supported backends: {supported_backends}."
        )
        logger.warning_once(warning)
        os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]

    if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
        warning = (
            "You are using a decode-invariant form of batch invariance. "
            "This will not be invariant between prefill and decode."
        )
        logger.warning_once(warning)

    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # NCCL determinism settings
    os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
    os.environ["NCCL_COLLNET_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["NCCL_P2P_NET_DISABLE"] = "1"
    os.environ["NCCL_MIN_NCHANNELS"] = "1"
    os.environ["NCCL_MAX_NCHANNELS"] = "1"
    os.environ["NCCL_PROTO"] = "Simple"
    os.environ["NCCL_ALGO"] = "allreduce:tree"
    os.environ["NCCL_NTHREADS"] = "1"
    os.environ["NCCL_SOCKET_NTHREADS"] = "1"

    # torch.compile settings
    os.environ["VLLM_USE_AOT_COMPILE"] = "0"


def init_batch_invariance():
    # this will hit all the csrc overrides as well
    if vllm_is_batch_invariant():
        override_envs_for_invariance()
        enable_batch_invariant_mode()
        # Disable TF32 for batch invariance - it causes non-deterministic rounding
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
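
# How this module is typically activated (illustrative sketch, comments only).
# Batch-invariant mode is driven by the VLLM_BATCH_INVARIANT environment variable
# read by vllm_is_batch_invariant(); since that helper is cached, the variable must
# be set before it is first called. For example, from the CLI:
#
#     VLLM_BATCH_INVARIANT=1 vllm serve $MODEL
#
# or from Python, before any CUDA matmuls are issued:
#
#     os.environ["VLLM_BATCH_INVARIANT"] = "1"
#     init_batch_invariance()  # overrides envs, registers the aten CUDA overrides,
#                              # and disables TF32 matmuls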