[deterministic inference] Move batch invariant pkg to sglang (#10695)
This commit is contained in:
27
python/sglang/srt/batch_invariant_ops/__init__.py
Normal file
27
python/sglang/srt/batch_invariant_ops/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/__init__.py
|
||||||
|
|
||||||
|
from .batch_invariant_ops import (
|
||||||
|
AttentionBlockSize,
|
||||||
|
disable_batch_invariant_mode,
|
||||||
|
enable_batch_invariant_mode,
|
||||||
|
get_batch_invariant_attention_block_size,
|
||||||
|
is_batch_invariant_mode_enabled,
|
||||||
|
log_softmax,
|
||||||
|
matmul_persistent,
|
||||||
|
mean_dim,
|
||||||
|
set_batch_invariant_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"set_batch_invariant_mode",
|
||||||
|
"is_batch_invariant_mode_enabled",
|
||||||
|
"disable_batch_invariant_mode",
|
||||||
|
"enable_batch_invariant_mode",
|
||||||
|
"matmul_persistent",
|
||||||
|
"log_softmax",
|
||||||
|
"mean_dim",
|
||||||
|
"get_batch_invariant_attention_block_size",
|
||||||
|
"AttentionBlockSize",
|
||||||
|
]
|
||||||
549
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
Normal file
549
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
Normal file
@@ -0,0 +1,549 @@
|
|||||||
|
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/batch_invariant_ops.py
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
from collections import namedtuple
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"set_batch_invariant_mode",
|
||||||
|
"is_batch_invariant_mode_enabled",
|
||||||
|
"disable_batch_invariant_mode",
|
||||||
|
"enable_batch_invariant_mode",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _matmul_launch_metadata(
|
||||||
|
grid: Callable[..., Any], kernel: Any, args: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
ret = {}
|
||||||
|
m, n, k = args["M"], args["N"], args["K"]
|
||||||
|
ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
|
||||||
|
if "tiles_per_update" in args:
|
||||||
|
ret["name"] = (
|
||||||
|
f"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]"
|
||||||
|
)
|
||||||
|
if "c_ptr" in args:
|
||||||
|
bytes_per_elem = args["c_ptr"].element_size()
|
||||||
|
else:
|
||||||
|
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
|
||||||
|
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
|
||||||
|
ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
|
||||||
|
group_id = tile_id // num_pid_in_group
|
||||||
|
first_pid_m = group_id * GROUP_SIZE_M
|
||||||
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
|
pid_m = first_pid_m + (tile_id % group_size_m)
|
||||||
|
pid_n = (tile_id % num_pid_in_group) // group_size_m
|
||||||
|
return pid_m, pid_n
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit(launch_metadata=_matmul_launch_metadata)
|
||||||
|
def matmul_kernel_persistent(
|
||||||
|
a_ptr,
|
||||||
|
b_ptr,
|
||||||
|
c_ptr, #
|
||||||
|
bias_ptr,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K, #
|
||||||
|
stride_am,
|
||||||
|
stride_ak,
|
||||||
|
stride_bk,
|
||||||
|
stride_bn,
|
||||||
|
stride_cm,
|
||||||
|
stride_cn,
|
||||||
|
BLOCK_SIZE_M: tl.constexpr, #
|
||||||
|
BLOCK_SIZE_N: tl.constexpr, #
|
||||||
|
BLOCK_SIZE_K: tl.constexpr, #
|
||||||
|
GROUP_SIZE_M: tl.constexpr, #
|
||||||
|
NUM_SMS: tl.constexpr, #
|
||||||
|
A_LARGE: tl.constexpr,
|
||||||
|
B_LARGE: tl.constexpr,
|
||||||
|
C_LARGE: tl.constexpr,
|
||||||
|
HAS_BIAS: tl.constexpr,
|
||||||
|
):
|
||||||
|
start_pid = tl.program_id(axis=0)
|
||||||
|
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||||
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
|
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||||
|
num_tiles = num_pid_m * num_pid_n
|
||||||
|
|
||||||
|
tile_id_c = start_pid - NUM_SMS
|
||||||
|
|
||||||
|
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
|
||||||
|
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||||
|
|
||||||
|
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
|
||||||
|
pid_m, pid_n = _compute_pid(
|
||||||
|
tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
|
||||||
|
)
|
||||||
|
start_m = pid_m * BLOCK_SIZE_M
|
||||||
|
start_n = pid_n * BLOCK_SIZE_N
|
||||||
|
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
|
||||||
|
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
|
||||||
|
if A_LARGE:
|
||||||
|
offs_am = offs_am.to(tl.int64)
|
||||||
|
if B_LARGE:
|
||||||
|
offs_bn = offs_bn.to(tl.int64)
|
||||||
|
offs_am = tl.where(offs_am < M, offs_am, 0)
|
||||||
|
offs_bn = tl.where(offs_bn < N, offs_bn, 0)
|
||||||
|
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
|
||||||
|
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
|
||||||
|
|
||||||
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
|
for ki in range(k_tiles):
|
||||||
|
if A_LARGE or B_LARGE:
|
||||||
|
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
|
||||||
|
else:
|
||||||
|
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||||
|
a_ptrs = a_ptr + (
|
||||||
|
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
|
||||||
|
)
|
||||||
|
b_ptrs = b_ptr + (
|
||||||
|
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
|
||||||
|
)
|
||||||
|
|
||||||
|
a = tl.load(
|
||||||
|
a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
|
||||||
|
)
|
||||||
|
b = tl.load(
|
||||||
|
b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
|
||||||
|
)
|
||||||
|
accumulator = tl.dot(a, b, accumulator)
|
||||||
|
|
||||||
|
tile_id_c += NUM_SMS
|
||||||
|
pid_m, pid_n = _compute_pid(
|
||||||
|
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
|
||||||
|
)
|
||||||
|
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
|
if C_LARGE:
|
||||||
|
offs_cm = offs_cm.to(tl.int64)
|
||||||
|
offs_cn = offs_cn.to(tl.int64)
|
||||||
|
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||||
|
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||||
|
if HAS_BIAS:
|
||||||
|
bias_ptrs = bias_ptr + offs_cn
|
||||||
|
bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
|
||||||
|
accumulator += bias
|
||||||
|
if c_ptr.dtype.element_ty == tl.float8e4nv:
|
||||||
|
c = accumulator.to(tl.float8e4nv)
|
||||||
|
else:
|
||||||
|
c = accumulator.to(tl.float16)
|
||||||
|
tl.store(c_ptrs, c, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_persistent(
|
||||||
|
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
|
||||||
|
):
|
||||||
|
# Check constraints.
|
||||||
|
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||||
|
assert a.dtype == b.dtype, "Incompatible dtypes"
|
||||||
|
assert (
|
||||||
|
bias is None or bias.dim() == 1
|
||||||
|
), "Currently assuming bias is 1D, let Horace know if you run into this"
|
||||||
|
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||||
|
M, K = a.shape
|
||||||
|
K, N = b.shape
|
||||||
|
dtype = a.dtype
|
||||||
|
# Allocates output.
|
||||||
|
c = torch.empty((M, N), device=a.device, dtype=dtype)
|
||||||
|
|
||||||
|
# 1D launch kernel where each block gets its own program.
|
||||||
|
def grid(META):
|
||||||
|
return (
|
||||||
|
min(
|
||||||
|
NUM_SMS,
|
||||||
|
triton.cdiv(M, META["BLOCK_SIZE_M"])
|
||||||
|
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
configs = {
|
||||||
|
torch.bfloat16: {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 8,
|
||||||
|
"num_stages": 3,
|
||||||
|
"num_warps": 8,
|
||||||
|
},
|
||||||
|
torch.float16: {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 8,
|
||||||
|
"num_stages": 3,
|
||||||
|
"num_warps": 8,
|
||||||
|
},
|
||||||
|
torch.float32: {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 32,
|
||||||
|
"GROUP_SIZE_M": 8,
|
||||||
|
"num_stages": 3,
|
||||||
|
"num_warps": 8,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
# print(a.device, b.device, c.device)
|
||||||
|
matmul_kernel_persistent[grid](
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
c, #
|
||||||
|
bias,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K, #
|
||||||
|
a.stride(0),
|
||||||
|
a.stride(1), #
|
||||||
|
b.stride(0),
|
||||||
|
b.stride(1), #
|
||||||
|
c.stride(0),
|
||||||
|
c.stride(1), #
|
||||||
|
NUM_SMS=NUM_SMS, #
|
||||||
|
A_LARGE=a.numel() > 2**31,
|
||||||
|
B_LARGE=b.numel() > 2**31,
|
||||||
|
C_LARGE=c.numel() > 2**31,
|
||||||
|
HAS_BIAS=bias is not None,
|
||||||
|
**configs[dtype],
|
||||||
|
)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _log_softmax_kernel(
|
||||||
|
input_ptr,
|
||||||
|
output_ptr,
|
||||||
|
input_row_stride,
|
||||||
|
output_row_stride,
|
||||||
|
n_cols,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Compute log_softmax along the last dimension of a 2D tensor.
|
||||||
|
Each block handles one row of the input tensor.
|
||||||
|
"""
|
||||||
|
# Get the row index for this block
|
||||||
|
row_idx = tl.program_id(0).to(tl.int64)
|
||||||
|
|
||||||
|
# Compute base pointers for input and output rows
|
||||||
|
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||||
|
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||||
|
|
||||||
|
# Step 1: Find maximum value in the row for numerical stability
|
||||||
|
max_val = -float("inf")
|
||||||
|
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||||
|
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_idx < n_cols
|
||||||
|
|
||||||
|
# Load values
|
||||||
|
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))
|
||||||
|
|
||||||
|
# Update maximum
|
||||||
|
max_val = tl.max(tl.maximum(vals, max_val))
|
||||||
|
|
||||||
|
# Step 2: Compute sum of exp(x - max_val)
|
||||||
|
sum_exp = 0.0
|
||||||
|
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||||
|
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_idx < n_cols
|
||||||
|
|
||||||
|
# Load values
|
||||||
|
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
|
||||||
|
|
||||||
|
# Compute exp(x - max_val) and accumulate
|
||||||
|
exp_vals = tl.exp(vals - max_val)
|
||||||
|
sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))
|
||||||
|
|
||||||
|
# Compute log(sum_exp)
|
||||||
|
log_sum_exp = tl.log(sum_exp)
|
||||||
|
|
||||||
|
# Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
|
||||||
|
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||||
|
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_idx < n_cols
|
||||||
|
|
||||||
|
# Load values
|
||||||
|
vals = tl.load(row_start_ptr + col_idx, mask=mask)
|
||||||
|
|
||||||
|
# Compute log_softmax
|
||||||
|
output = vals - max_val - log_sum_exp
|
||||||
|
|
||||||
|
# Store results
|
||||||
|
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute log_softmax using Triton kernel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Input tensor
|
||||||
|
dim: Dimension along which to compute log_softmax (only -1 or last dim supported)
|
||||||
|
>> Stashed changes
|
||||||
|
Returns:
|
||||||
|
Tensor with log_softmax applied along the specified dimension
|
||||||
|
"""
|
||||||
|
if dim != -1 and dim != input.ndim - 1:
|
||||||
|
raise ValueError(
|
||||||
|
"This implementation only supports log_softmax along the last dimension"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flatten all dimensions except the last one
|
||||||
|
original_shape = input.shape
|
||||||
|
input_2d = input.reshape(-1, input.shape[-1])
|
||||||
|
input_2d = input_2d.contiguous()
|
||||||
|
|
||||||
|
n_rows, n_cols = input_2d.shape
|
||||||
|
|
||||||
|
# Allocate output tensor
|
||||||
|
output = torch.empty_like(input_2d)
|
||||||
|
|
||||||
|
# Choose block size based on the number of columns
|
||||||
|
BLOCK_SIZE = 1024
|
||||||
|
|
||||||
|
# Launch kernel with one block per row
|
||||||
|
grid = (n_rows,)
|
||||||
|
_log_softmax_kernel[grid](
|
||||||
|
input_2d,
|
||||||
|
output,
|
||||||
|
input_2d.stride(0),
|
||||||
|
output.stride(0),
|
||||||
|
n_cols,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
# Reshape output back to original shape
|
||||||
|
return output.reshape(original_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def mean_kernel(
|
||||||
|
input_ptr,
|
||||||
|
output_ptr,
|
||||||
|
input_stride0,
|
||||||
|
input_stride1,
|
||||||
|
input_stride2,
|
||||||
|
output_stride0,
|
||||||
|
output_stride1,
|
||||||
|
M, # size before reduction dim
|
||||||
|
N, # size of reduction dim
|
||||||
|
K, # size after reduction dim
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Kernel for computing mean along a single dimension.
|
||||||
|
Input is viewed as (M, N, K) where N is the dimension being reduced.
|
||||||
|
"""
|
||||||
|
# Program ID gives us which output element we're computing
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
|
||||||
|
# Compute output indices
|
||||||
|
m_idx = pid // K
|
||||||
|
k_idx = pid % K
|
||||||
|
|
||||||
|
# Bounds check
|
||||||
|
if m_idx >= M or k_idx >= K:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Accumulate sum across reduction dimension
|
||||||
|
acc = 0.0
|
||||||
|
for n_start in range(0, N, BLOCK_SIZE):
|
||||||
|
n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = n_offsets < N
|
||||||
|
|
||||||
|
# Calculate input indices
|
||||||
|
input_idx = (
|
||||||
|
m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load and accumulate
|
||||||
|
vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
|
||||||
|
acc += tl.sum(vals)
|
||||||
|
|
||||||
|
# Compute mean and store
|
||||||
|
mean_val = acc / N
|
||||||
|
output_idx = m_idx * output_stride0 + k_idx * output_stride1
|
||||||
|
tl.store(output_ptr + output_idx, mean_val)
|
||||||
|
|
||||||
|
|
||||||
|
def mean_dim(
|
||||||
|
input: torch.Tensor,
|
||||||
|
dim: int,
|
||||||
|
keepdim: bool = False,
|
||||||
|
dtype: torch.dtype | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Triton implementation of torch.mean with single dimension reduction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Input tensor
|
||||||
|
dim: Single dimension along which to compute mean
|
||||||
|
keepdim: Whether to keep the reduced dimension
|
||||||
|
dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor with mean values along specified dimension
|
||||||
|
"""
|
||||||
|
# Validate inputs
|
||||||
|
assert input.is_cuda, "Input must be a CUDA tensor"
|
||||||
|
assert (
|
||||||
|
-input.ndim <= dim < input.ndim
|
||||||
|
), f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
|
||||||
|
|
||||||
|
# Handle negative dim
|
||||||
|
if dim < 0:
|
||||||
|
dim = dim + input.ndim
|
||||||
|
|
||||||
|
# Handle dtype
|
||||||
|
if dtype is None:
|
||||||
|
if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||||
|
dtype = torch.float32
|
||||||
|
else:
|
||||||
|
dtype = input.dtype
|
||||||
|
|
||||||
|
# Convert input to appropriate dtype if needed
|
||||||
|
if input.dtype != dtype:
|
||||||
|
input = input.to(dtype)
|
||||||
|
|
||||||
|
# Get input shape and strides
|
||||||
|
shape = list(input.shape)
|
||||||
|
|
||||||
|
# Calculate dimensions for kernel
|
||||||
|
M = 1
|
||||||
|
for i in range(dim):
|
||||||
|
M *= shape[i]
|
||||||
|
|
||||||
|
N = shape[dim]
|
||||||
|
|
||||||
|
K = 1
|
||||||
|
for i in range(dim + 1, len(shape)):
|
||||||
|
K *= shape[i]
|
||||||
|
|
||||||
|
# Reshape input to 3D view (M, N, K)
|
||||||
|
input_3d = input.reshape(M, N, K)
|
||||||
|
|
||||||
|
# Create output shape
|
||||||
|
if keepdim:
|
||||||
|
output_shape = shape.copy()
|
||||||
|
output_shape[dim] = 1
|
||||||
|
else:
|
||||||
|
output_shape = shape[:dim] + shape[dim + 1 :]
|
||||||
|
|
||||||
|
# Create output tensor
|
||||||
|
output = torch.empty(output_shape, dtype=dtype, device=input.device)
|
||||||
|
|
||||||
|
# Reshape output for kernel
|
||||||
|
if keepdim:
|
||||||
|
output_2d = output.reshape(M, 1, K).squeeze(1)
|
||||||
|
else:
|
||||||
|
output_2d = output.reshape(M, K)
|
||||||
|
|
||||||
|
# Launch kernel
|
||||||
|
grid = (M * K,)
|
||||||
|
BLOCK_SIZE = 1024
|
||||||
|
|
||||||
|
mean_kernel[grid](
|
||||||
|
input_3d,
|
||||||
|
output_2d,
|
||||||
|
input_3d.stride(0),
|
||||||
|
input_3d.stride(1),
|
||||||
|
input_3d.stride(2),
|
||||||
|
output_2d.stride(0),
|
||||||
|
output_2d.stride(1) if output_2d.ndim > 1 else 0,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def mm_batch_invariant(a, b):
|
||||||
|
return matmul_persistent(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def addmm_batch_invariant(bias, a, b):
|
||||||
|
return matmul_persistent(a, b, bias=bias)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_softmax_batch_invariant(input, dim, _half_to_float):
|
||||||
|
assert not _half_to_float, "not implemented"
|
||||||
|
return log_softmax(input, dim=dim)
|
||||||
|
|
||||||
|
|
||||||
|
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
|
||||||
|
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
|
||||||
|
if len(dim) == 1:
|
||||||
|
return mean_dim(input, dim[0], keepdim=keepdim)
|
||||||
|
else:
|
||||||
|
assert input.dtype in {
|
||||||
|
torch.float16,
|
||||||
|
torch.bfloat16,
|
||||||
|
torch.float32,
|
||||||
|
}, "only float types supported for now"
|
||||||
|
n_elems = 1
|
||||||
|
for d in dim:
|
||||||
|
n_elems *= input.shape[d]
|
||||||
|
return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
|
||||||
|
|
||||||
|
|
||||||
|
_batch_invariant_MODE = False
|
||||||
|
_batch_invariant_LIB = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_batch_invariant_mode_enabled():
|
||||||
|
return _batch_invariant_MODE
|
||||||
|
|
||||||
|
|
||||||
|
def enable_batch_invariant_mode():
|
||||||
|
global _batch_invariant_MODE, _batch_invariant_LIB
|
||||||
|
if _batch_invariant_MODE:
|
||||||
|
return
|
||||||
|
|
||||||
|
_batch_invariant_MODE = True
|
||||||
|
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
|
||||||
|
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||||
|
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
||||||
|
_batch_invariant_LIB.impl(
|
||||||
|
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
|
||||||
|
)
|
||||||
|
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
|
||||||
|
|
||||||
|
|
||||||
|
def disable_batch_invariant_mode():
|
||||||
|
global _batch_invariant_MODE, _batch_invariant_LIB
|
||||||
|
if _batch_invariant_LIB is not None:
|
||||||
|
_batch_invariant_LIB._destroy()
|
||||||
|
_batch_invariant_MODE = False
|
||||||
|
_batch_invariant_LIB = None
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def set_batch_invariant_mode(enabled: bool = True):
|
||||||
|
global _batch_invariant_MODE, _batch_invariant_LIB
|
||||||
|
old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
|
||||||
|
if enabled:
|
||||||
|
enable_batch_invariant_mode()
|
||||||
|
else:
|
||||||
|
disable_batch_invariant_mode()
|
||||||
|
yield
|
||||||
|
if _batch_invariant_LIB is not None:
|
||||||
|
_batch_invariant_LIB._destroy()
|
||||||
|
_batch_invariant_MODE, _batch_invariant_LIB = old_data
|
||||||
|
|
||||||
|
|
||||||
|
AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
|
||||||
|
return AttentionBlockSize(block_m=16, block_n=16)
|
||||||
@@ -408,7 +408,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
# Enable batch invariant mode
|
# Enable batch invariant mode
|
||||||
if server_args.enable_deterministic_inference:
|
if server_args.enable_deterministic_inference:
|
||||||
from batch_invariant_ops import enable_batch_invariant_mode
|
from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
|
||||||
|
|
||||||
enable_batch_invariant_mode()
|
enable_batch_invariant_mode()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user