v1.0
This commit is contained in:
854
model_executor/layers/batch_invariant.py
Normal file
854
model_executor/layers/batch_invariant.py
Normal file
@@ -0,0 +1,854 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import cache
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _matmul_launch_metadata(
|
||||
grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
ret = {}
|
||||
m, n, k = args["M"], args["N"], args["K"]
|
||||
ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
|
||||
if "tiles_per_update" in args:
|
||||
ret["name"] = (
|
||||
f"{kernel.name} [M={m}, N={n}, K={k}, "
|
||||
f"tiles_per_update={args['tiles_per_update']:02}]"
|
||||
)
|
||||
if "c_ptr" in args:
|
||||
bytes_per_elem = args["c_ptr"].element_size()
|
||||
else:
|
||||
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
|
||||
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
|
||||
ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
|
||||
return ret
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
|
||||
group_id = tile_id // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (tile_id % group_size_m)
|
||||
pid_n = (tile_id % num_pid_in_group) // group_size_m
|
||||
return pid_m, pid_n
|
||||
|
||||
|
||||
@triton.jit(launch_metadata=_matmul_launch_metadata)
|
||||
def matmul_kernel_persistent(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr, #
|
||||
bias_ptr,
|
||||
M,
|
||||
N,
|
||||
K, #
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, #
|
||||
BLOCK_SIZE_N: tl.constexpr, #
|
||||
BLOCK_SIZE_K: tl.constexpr, #
|
||||
GROUP_SIZE_M: tl.constexpr, #
|
||||
NUM_SMS: tl.constexpr, #
|
||||
A_LARGE: tl.constexpr,
|
||||
B_LARGE: tl.constexpr,
|
||||
C_LARGE: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
):
|
||||
start_pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_tiles = num_pid_m * num_pid_n
|
||||
|
||||
tile_id_c = start_pid - NUM_SMS
|
||||
|
||||
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
|
||||
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
|
||||
pid_m, pid_n = _compute_pid(
|
||||
tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
|
||||
)
|
||||
start_m = pid_m * BLOCK_SIZE_M
|
||||
start_n = pid_n * BLOCK_SIZE_N
|
||||
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
|
||||
if A_LARGE:
|
||||
offs_am = offs_am.to(tl.int64)
|
||||
if B_LARGE:
|
||||
offs_bn = offs_bn.to(tl.int64)
|
||||
offs_am = tl.where(offs_am < M, offs_am, 0)
|
||||
offs_bn = tl.where(offs_bn < N, offs_bn, 0)
|
||||
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
|
||||
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for ki in range(k_tiles):
|
||||
if A_LARGE or B_LARGE:
|
||||
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
|
||||
else:
|
||||
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (
|
||||
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
|
||||
)
|
||||
b_ptrs = b_ptr + (
|
||||
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
|
||||
)
|
||||
|
||||
a = tl.load(
|
||||
a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
|
||||
)
|
||||
b = tl.load(
|
||||
b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
|
||||
)
|
||||
accumulator = tl.dot(a, b, accumulator)
|
||||
|
||||
tile_id_c += NUM_SMS
|
||||
pid_m, pid_n = _compute_pid(
|
||||
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
|
||||
)
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
if C_LARGE:
|
||||
offs_cm = offs_cm.to(tl.int64)
|
||||
offs_cn = offs_cn.to(tl.int64)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
if HAS_BIAS:
|
||||
bias_ptrs = bias_ptr + offs_cn
|
||||
bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
|
||||
accumulator += bias
|
||||
c = accumulator.to(c_ptr.dtype.element_ty)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
def matmul_persistent(
|
||||
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
|
||||
):
|
||||
# Check constraints.
|
||||
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||
assert a.dtype == b.dtype, "Incompatible dtypes"
|
||||
assert bias is None or bias.dim() == 1, (
|
||||
"Currently assuming bias is 1D, let Horace know if you run into this"
|
||||
)
|
||||
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
dtype = a.dtype
|
||||
# Allocates output.
|
||||
c = torch.empty((M, N), device=a.device, dtype=dtype)
|
||||
|
||||
# 1D launch kernel where each block gets its own program.
|
||||
def grid(META):
|
||||
return (
|
||||
min(
|
||||
NUM_SMS,
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
),
|
||||
)
|
||||
|
||||
configs = {
|
||||
torch.bfloat16: {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_stages": 3,
|
||||
"num_warps": 8,
|
||||
},
|
||||
torch.float16: {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_stages": 3,
|
||||
"num_warps": 8,
|
||||
},
|
||||
torch.float32: {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_stages": 3,
|
||||
"num_warps": 8,
|
||||
},
|
||||
}
|
||||
# print(a.device, b.device, c.device)
|
||||
matmul_kernel_persistent[grid](
|
||||
a,
|
||||
b,
|
||||
c, #
|
||||
bias,
|
||||
M,
|
||||
N,
|
||||
K, #
|
||||
a.stride(0),
|
||||
a.stride(1), #
|
||||
b.stride(0),
|
||||
b.stride(1), #
|
||||
c.stride(0),
|
||||
c.stride(1), #
|
||||
NUM_SMS=NUM_SMS, #
|
||||
A_LARGE=a.numel() > 2**31,
|
||||
B_LARGE=b.numel() > 2**31,
|
||||
C_LARGE=c.numel() > 2**31,
|
||||
HAS_BIAS=bias is not None,
|
||||
**configs[dtype],
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _log_softmax_kernel(
|
||||
input_ptr,
|
||||
output_ptr,
|
||||
input_row_stride,
|
||||
output_row_stride,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute log_softmax along the last dimension of a 2D tensor.
|
||||
Each block handles one row of the input tensor.
|
||||
"""
|
||||
# Get the row index for this block
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
|
||||
# Compute base pointers for input and output rows
|
||||
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||
|
||||
# Step 1: Find maximum value in the row for numerical stability
|
||||
max_val = -float("inf")
|
||||
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_idx < n_cols
|
||||
|
||||
# Load values
|
||||
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))
|
||||
|
||||
# Update maximum
|
||||
max_val = tl.max(tl.maximum(vals, max_val))
|
||||
|
||||
# Step 2: Compute sum of exp(x - max_val)
|
||||
sum_exp = 0.0
|
||||
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_idx < n_cols
|
||||
|
||||
# Load values
|
||||
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
|
||||
|
||||
# Compute exp(x - max_val) and accumulate
|
||||
exp_vals = tl.exp(vals - max_val)
|
||||
sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))
|
||||
|
||||
# Compute log(sum_exp)
|
||||
log_sum_exp = tl.log(sum_exp)
|
||||
|
||||
# Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
|
||||
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_idx < n_cols
|
||||
|
||||
# Load values
|
||||
vals = tl.load(row_start_ptr + col_idx, mask=mask)
|
||||
|
||||
# Compute log_softmax
|
||||
output = vals - max_val - log_sum_exp
|
||||
|
||||
# Store results
|
||||
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
|
||||
|
||||
|
||||
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
"""
|
||||
Compute log_softmax using Triton kernel.
|
||||
|
||||
Args:
|
||||
input: Input tensor
|
||||
dim: Dimension along which to compute log_softmax
|
||||
(only -1 or last dim supported)
|
||||
>> Stashed changes
|
||||
Returns:
|
||||
Tensor with log_softmax applied along the specified dimension
|
||||
"""
|
||||
if dim != -1 and dim != input.ndim - 1:
|
||||
raise ValueError(
|
||||
"This implementation only supports log_softmax along the last dimension"
|
||||
)
|
||||
|
||||
# Flatten all dimensions except the last one
|
||||
original_shape = input.shape
|
||||
input_2d = input.reshape(-1, input.shape[-1])
|
||||
input_2d = input_2d.contiguous()
|
||||
|
||||
n_rows, n_cols = input_2d.shape
|
||||
|
||||
# Allocate output tensor
|
||||
output = torch.empty_like(input_2d)
|
||||
|
||||
# Choose block size based on the number of columns
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
# Launch kernel with one block per row
|
||||
grid = (n_rows,)
|
||||
_log_softmax_kernel[grid](
|
||||
input_2d,
|
||||
output,
|
||||
input_2d.stride(0),
|
||||
output.stride(0),
|
||||
n_cols,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
# Reshape output back to original shape
|
||||
return output.reshape(original_shape)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def mean_kernel(
|
||||
input_ptr,
|
||||
output_ptr,
|
||||
input_stride0,
|
||||
input_stride1,
|
||||
input_stride2,
|
||||
output_stride0,
|
||||
output_stride1,
|
||||
M, # size before reduction dim
|
||||
N, # size of reduction dim
|
||||
K, # size after reduction dim
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Kernel for computing mean along a single dimension.
|
||||
Input is viewed as (M, N, K) where N is the dimension being reduced.
|
||||
"""
|
||||
# Program ID gives us which output element we're computing
|
||||
pid = tl.program_id(0)
|
||||
|
||||
# Compute output indices
|
||||
m_idx = pid // K
|
||||
k_idx = pid % K
|
||||
|
||||
# Bounds check
|
||||
if m_idx >= M or k_idx >= K:
|
||||
return
|
||||
|
||||
# Accumulate sum across reduction dimension
|
||||
acc = 0.0
|
||||
for n_start in range(0, N, BLOCK_SIZE):
|
||||
n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = n_offsets < N
|
||||
|
||||
# Calculate input indices
|
||||
input_idx = (
|
||||
m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
|
||||
)
|
||||
|
||||
# Load and accumulate
|
||||
vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
|
||||
acc += tl.sum(vals)
|
||||
|
||||
# Compute mean and store
|
||||
mean_val = acc / N
|
||||
output_idx = m_idx * output_stride0 + k_idx * output_stride1
|
||||
tl.store(output_ptr + output_idx, mean_val)
|
||||
|
||||
|
||||
def mean_dim(
|
||||
input: torch.Tensor,
|
||||
dim: int,
|
||||
keepdim: bool = False,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Triton implementation of torch.mean with single dimension reduction.
|
||||
|
||||
Args:
|
||||
input: Input tensor
|
||||
dim: Single dimension along which to compute mean
|
||||
keepdim: Whether to keep the reduced dimension
|
||||
dtype: Output dtype. If None, uses input dtype
|
||||
(or float32 for integer inputs)
|
||||
|
||||
Returns:
|
||||
Tensor with mean values along specified dimension
|
||||
"""
|
||||
# Validate inputs
|
||||
assert -input.ndim <= dim < input.ndim, (
|
||||
f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
|
||||
)
|
||||
|
||||
# Handle negative dim
|
||||
if dim < 0:
|
||||
dim = dim + input.ndim
|
||||
|
||||
# Handle dtype
|
||||
if dtype is None:
|
||||
if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = input.dtype
|
||||
|
||||
# Convert input to appropriate dtype if needed
|
||||
if input.dtype != dtype:
|
||||
input = input.to(dtype)
|
||||
|
||||
# Get input shape and strides
|
||||
shape = list(input.shape)
|
||||
|
||||
# Calculate dimensions for kernel
|
||||
M = 1
|
||||
for i in range(dim):
|
||||
M *= shape[i]
|
||||
|
||||
N = shape[dim]
|
||||
|
||||
K = 1
|
||||
for i in range(dim + 1, len(shape)):
|
||||
K *= shape[i]
|
||||
|
||||
# Reshape input to 3D view (M, N, K)
|
||||
input_3d = input.reshape(M, N, K)
|
||||
|
||||
# Create output shape
|
||||
if keepdim:
|
||||
output_shape = shape.copy()
|
||||
output_shape[dim] = 1
|
||||
else:
|
||||
output_shape = shape[:dim] + shape[dim + 1 :]
|
||||
|
||||
# Create output tensor
|
||||
output = torch.empty(output_shape, dtype=dtype, device=input.device)
|
||||
|
||||
# Reshape output for kernel
|
||||
output_2d = output.reshape(M, 1, K).squeeze(1) if keepdim else output.reshape(M, K)
|
||||
|
||||
# Launch kernel
|
||||
grid = (M * K,)
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
mean_kernel[grid](
|
||||
input_3d,
|
||||
output_2d,
|
||||
input_3d.stride(0),
|
||||
input_3d.stride(1),
|
||||
input_3d.stride(2),
|
||||
output_2d.stride(0),
|
||||
output_2d.stride(1) if output_2d.ndim > 1 else 0,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def mm_batch_invariant(a, b):
|
||||
return matmul_persistent(a, b)
|
||||
|
||||
|
||||
def matmul_batch_invariant(a, b, *, out=None):
|
||||
# torch.matmul can handle various dimensions
|
||||
# For 2D x 2D, it's the same as mm
|
||||
if a.ndim == 2 and b.ndim == 2:
|
||||
result = matmul_persistent(a, b)
|
||||
if out is not None:
|
||||
out.copy_(result)
|
||||
return out
|
||||
return result
|
||||
elif a.ndim == 3 and b.ndim == 3:
|
||||
# Handle batched case like bmm
|
||||
return bmm_batch_invariant(a, b, out=out)
|
||||
elif a.ndim == 3 and b.ndim == 2:
|
||||
# Handle 3D x 2D: common for linear layers
|
||||
# (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
|
||||
# Reshape to 2D, do mm, reshape back
|
||||
batch, seq, hidden = a.shape
|
||||
a_2d = a.reshape(-1, hidden)
|
||||
result_2d = matmul_persistent(a_2d, b)
|
||||
result = result_2d.reshape(batch, seq, -1)
|
||||
if out is not None:
|
||||
out.copy_(result)
|
||||
return out
|
||||
return result
|
||||
elif a.ndim == 2 and b.ndim == 3:
|
||||
# Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
|
||||
# By broadcasting `a` to 3D, we can reuse the batched matrix
|
||||
# multiplication logic.
|
||||
a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
|
||||
return bmm_batch_invariant(a_expanded, b, out=out)
|
||||
elif a.ndim == 4 and b.ndim == 4:
|
||||
# Handle 4D attention tensors: [batch, heads, seq, dim]
|
||||
# Reshape to 3D, process, reshape back
|
||||
batch, heads, seq_a, dim_a = a.shape
|
||||
_, _, dim_b, seq_b = b.shape
|
||||
|
||||
# Reshape to [batch*heads, seq_a, dim_a]
|
||||
a_3d = a.reshape(batch * heads, seq_a, dim_a)
|
||||
b_3d = b.reshape(batch * heads, dim_b, seq_b)
|
||||
|
||||
# Do batched matmul
|
||||
result_3d = bmm_batch_invariant(a_3d, b_3d)
|
||||
|
||||
# Reshape back to [batch, heads, seq_a, seq_b]
|
||||
result = result_3d.reshape(batch, heads, seq_a, seq_b)
|
||||
|
||||
if out is not None:
|
||||
out.copy_(result)
|
||||
return out
|
||||
return result
|
||||
else:
|
||||
raise ValueError(
|
||||
f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
|
||||
f"3D x 2D, 2D x 3D, and 4D x 4D, "
|
||||
f"got shapes {a.shape} and {b.shape}"
|
||||
)
|
||||
|
||||
|
||||
def bmm_batch_invariant(a, b, *, out=None):
|
||||
# Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
|
||||
# Process each batch separately with our persistent kernel
|
||||
if a.ndim == 3 and b.ndim == 3:
|
||||
results = []
|
||||
for i in range(a.shape[0]):
|
||||
results.append(matmul_persistent(a[i], b[i]))
|
||||
result = torch.stack(results, dim=0)
|
||||
|
||||
if out is not None:
|
||||
out.copy_(result)
|
||||
return out
|
||||
return result
|
||||
else:
|
||||
raise ValueError(
|
||||
f"bmm_batch_invariant expects 3D tensors, "
|
||||
f"got shapes {a.shape} and {b.shape}"
|
||||
)
|
||||
|
||||
|
||||
def addmm_batch_invariant(bias, a, b):
|
||||
return matmul_persistent(a, b, bias=bias)
|
||||
|
||||
|
||||
def _log_softmax_batch_invariant(input, dim, _half_to_float):
|
||||
assert not _half_to_float, "not implemented"
|
||||
return log_softmax(input, dim=dim)
|
||||
|
||||
|
||||
def softmax_batch_invariant(input, dim, dtype=None):
|
||||
# Compute softmax in a deterministic way
|
||||
# First subtract max for numerical stability (standard practice)
|
||||
input_max = torch.amax(input, dim=dim, keepdim=True)
|
||||
input = input - input_max
|
||||
exp_x = torch.exp(input)
|
||||
sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
|
||||
return exp_x / sum_exp_x
|
||||
|
||||
|
||||
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
|
||||
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
|
||||
|
||||
result = input.to(torch.float32)
|
||||
|
||||
if len(dim) == 0:
|
||||
dim = [i for i in range(len(input.shape))]
|
||||
|
||||
# Sort dimensions to reduce from largest to smallest to handle shifting dims
|
||||
# during iterative reduction.
|
||||
sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)
|
||||
|
||||
# Iteratively apply a deterministic mean.
|
||||
for d in sorted_dims:
|
||||
result = mean_dim(result, dim=d, keepdim=True)
|
||||
|
||||
if not keepdim:
|
||||
# Squeeze the reduced dimensions.
|
||||
for d in sorted_dims:
|
||||
result = result.squeeze(d)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_kernel(
|
||||
input_ptr,
|
||||
weight_ptr,
|
||||
output_ptr,
|
||||
input_row_stride,
|
||||
output_row_stride,
|
||||
n_cols,
|
||||
eps,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute RMS normalization along the last dimension of a 2D tensor.
|
||||
RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
|
||||
Each block handles one row of the input tensor.
|
||||
"""
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||
|
||||
# Step 1: Compute sum of squares in float32 to avoid overflow
|
||||
sum_sq = tl.zeros([1], dtype=tl.float32)
|
||||
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_idx < n_cols
|
||||
|
||||
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
|
||||
# Convert to float32 for accumulation to prevent overflow
|
||||
vals_f32 = vals.to(tl.float32)
|
||||
sq_vals = vals_f32 * vals_f32
|
||||
sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
|
||||
|
||||
# Step 2: Compute RMS (root mean square) in float32
|
||||
mean_sq = sum_sq / n_cols
|
||||
rms = tl.sqrt(mean_sq + eps)
|
||||
inv_rms = 1.0 / rms
|
||||
|
||||
# Step 3: Normalize and apply weight
|
||||
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_idx < n_cols
|
||||
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
|
||||
weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
|
||||
# Compute in float32 then convert back to input dtype
|
||||
vals_f32 = vals.to(tl.float32)
|
||||
weight_f32 = weight.to(tl.float32)
|
||||
output_f32 = vals_f32 * inv_rms * weight_f32
|
||||
output = output_f32.to(vals.dtype)
|
||||
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
|
||||
|
||||
|
||||
def rms_norm(
|
||||
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute RMS normalization using Triton kernel.
|
||||
|
||||
RMS Norm normalizes the input by the root mean square and scales by weight:
|
||||
output = input / sqrt(mean(input^2) + eps) * weight
|
||||
|
||||
Args:
|
||||
input: Input tensor of shape (..., hidden_size)
|
||||
weight: Weight tensor of shape (hidden_size,)
|
||||
eps: Small constant for numerical stability
|
||||
|
||||
Returns:
|
||||
Tensor with RMS normalization applied along the last dimension
|
||||
"""
|
||||
assert weight.dim() == 1, "Weight must be 1-dimensional"
|
||||
assert input.shape[-1] == weight.shape[0], (
|
||||
f"Input last dimension ({input.shape[-1]}) must match "
|
||||
f"weight dimension ({weight.shape[0]})"
|
||||
)
|
||||
|
||||
# Flatten all dimensions except the last one
|
||||
original_shape = input.shape
|
||||
input_2d = input.reshape(-1, input.shape[-1])
|
||||
input_2d = input_2d.contiguous()
|
||||
weight = weight.contiguous()
|
||||
|
||||
n_rows, n_cols = input_2d.shape
|
||||
|
||||
output = torch.empty_like(input_2d)
|
||||
BLOCK_SIZE = 1024
|
||||
grid = (n_rows,)
|
||||
_rms_norm_kernel[grid](
|
||||
input_2d,
|
||||
weight,
|
||||
output,
|
||||
input_2d.stride(0),
|
||||
output.stride(0),
|
||||
n_cols,
|
||||
eps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return output.reshape(original_shape)
|
||||
|
||||
|
||||
def rms_norm_batch_invariant(
|
||||
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Batch-invariant wrapper for RMS normalization.
|
||||
|
||||
This function provides a deterministic, batch-invariant implementation
|
||||
of RMS normalization for use with the batch_invariant mode.
|
||||
|
||||
Args:
|
||||
input: Input tensor of shape (..., hidden_size)
|
||||
weight: Weight tensor of shape (hidden_size,)
|
||||
eps: Small constant for numerical stability
|
||||
|
||||
Returns:
|
||||
RMS normalized tensor
|
||||
"""
|
||||
return rms_norm(input, weight, eps=eps)
|
||||
|
||||
|
||||
def linear_batch_invariant(input, weight, bias=None):
|
||||
output = matmul_batch_invariant(input, weight.t())
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output
|
||||
|
||||
|
||||
_batch_invariant_MODE = False
|
||||
_batch_invariant_LIB = None
|
||||
_original_torch_bmm = None
|
||||
_original_fp16_reduction_precision = None
|
||||
_original_bf16_reduction_precision = None
|
||||
_original_cublas_workspace_cfg = None
|
||||
_original_cublaslt_workspace_size = None
|
||||
|
||||
|
||||
def enable_batch_invariant_mode():
|
||||
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
|
||||
global _original_fp16_reduction_precision, _original_bf16_reduction_precision
|
||||
global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
|
||||
if _batch_invariant_MODE:
|
||||
return
|
||||
|
||||
_batch_invariant_MODE = True
|
||||
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
|
||||
|
||||
# Batch invariant matmuls are no longer needed after cublas overrides
|
||||
if not is_torch_equal_or_newer("2.10.0.dev"):
|
||||
if current_platform.is_device_capability(100):
|
||||
# For PyTorch 2.9, B200 uses GEMV for bs=1
|
||||
# Requires https://github.com/pytorch/pytorch/pull/166735
|
||||
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
|
||||
else:
|
||||
# Only source of batch invariance for Hopper is split-k, can disable through
|
||||
# cuBLAS workspace config
|
||||
_original_cublas_workspace_cfg = os.environ.get(
|
||||
"CUBLAS_WORKSPACE_CONFIG", None
|
||||
)
|
||||
_original_cublaslt_workspace_size = os.environ.get(
|
||||
"CUBLASLT_WORKSPACE_SIZE", None
|
||||
)
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
|
||||
|
||||
_batch_invariant_LIB.impl(
|
||||
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
|
||||
)
|
||||
_batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
|
||||
|
||||
# Also monkeypatch torch.bmm directly as a fallback
|
||||
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
|
||||
_original_torch_bmm = torch.bmm
|
||||
torch.bmm = bmm_batch_invariant
|
||||
|
||||
_original_bf16_reduction_precision = (
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
|
||||
)
|
||||
_original_fp16_reduction_precision = (
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
|
||||
)
|
||||
|
||||
reduced_precision_val = (
|
||||
(False, False) if is_torch_equal_or_newer("2.10.0.dev") else False
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
|
||||
reduced_precision_val
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
|
||||
reduced_precision_val
|
||||
)
|
||||
torch.backends.cuda.preferred_blas_library(backend="cublaslt")
|
||||
|
||||
|
||||
@cache
|
||||
def vllm_is_batch_invariant():
|
||||
env_key = "VLLM_BATCH_INVARIANT"
|
||||
is_overridden = False
|
||||
val = os.getenv(env_key, "0")
|
||||
try:
|
||||
is_overridden = int(val) != 0
|
||||
except ValueError:
|
||||
is_overridden = False
|
||||
return is_overridden
|
||||
|
||||
|
||||
def override_envs_for_invariance():
|
||||
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
|
||||
supported_backends = [
|
||||
"FLASH_ATTN", # best supported backend
|
||||
"FLASHINFER",
|
||||
"FLASH_ATTN_MLA",
|
||||
"FLASHINFER_MLA",
|
||||
"TRITON_MLA",
|
||||
# Not yet supported MLA backends
|
||||
# "FLASHMLA",
|
||||
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
|
||||
]
|
||||
if curr_attn_backend not in supported_backends:
|
||||
warning = (
|
||||
"Forcibly updating attention backend to"
|
||||
f" {supported_backends[0]} for batch_invariant. "
|
||||
f" Supported backends: {supported_backends}."
|
||||
)
|
||||
logger.warning_once(warning)
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
|
||||
if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
|
||||
warning = (
|
||||
"You are using a decode-invariant form of batch invariance. "
|
||||
"This will not be invariant between prefill and decode."
|
||||
)
|
||||
logger.warning_once(warning)
|
||||
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
|
||||
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
|
||||
# NCCL determinism settings
|
||||
os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
|
||||
os.environ["NCCL_COLLNET_ENABLE"] = "0"
|
||||
os.environ["NCCL_NVLS_ENABLE"] = "0"
|
||||
os.environ["NCCL_P2P_NET_DISABLE"] = "1"
|
||||
os.environ["NCCL_MIN_NCHANNELS"] = "1"
|
||||
os.environ["NCCL_MAX_NCHANNELS"] = "1"
|
||||
os.environ["NCCL_PROTO"] = "Simple"
|
||||
os.environ["NCCL_ALGO"] = "allreduce:tree"
|
||||
os.environ["NCCL_NTHREADS"] = "1"
|
||||
os.environ["NCCL_SOCKET_NTHREADS"] = "1"
|
||||
|
||||
# torch.compile settings
|
||||
os.environ["VLLM_USE_AOT_COMPILE"] = "0"
|
||||
|
||||
|
||||
def init_batch_invariance():
|
||||
# this will hit all the csrc overrides as well
|
||||
if vllm_is_batch_invariant():
|
||||
override_envs_for_invariance()
|
||||
enable_batch_invariant_mode()
|
||||
|
||||
# Disable TF32 for batch invariance - it causes non-deterministic rounding
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
Reference in New Issue
Block a user