Files
2026-01-09 13:34:11 +08:00

158 lines
5.0 KiB
Python

from typing import Optional, Union
import torch
import triton
import triton.language as tl
def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Fill a tensor with uniform random numbers, using one seed per row.

    Similar to torch.rand, but allows for seeds to be set per row.

    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
    A 1d output is treated as a single row; for 2d and 3d outputs the
    first dimension is the row dimension. If the output is 3d, the
    additional seeds needed are derived automatically in a deterministic
    fashion, by XOR-ing the row seed with the index along the second
    dimension:
    [
        row 0: [columns_with_seed_0], [columns_with_seed_0^1], ...
    ]

    Args:
        *size: Shape of the output, given either as separate ints or as a
            single tuple/list. Must describe 1 to 3 dimensions.
        seeds: 1d tensor holding one seed per output row.
        out: Optional tensor to fill. Its shape must equal ``size`` and its
            last dimension must have stride 1.
        dtype: Passed to torch.empty when ``out`` is not given.
        device: Passed to torch.empty when ``out`` is not given.
        pin_memory: Passed to torch.empty when ``out`` is not given.

    Returns:
        The filled tensor (``out`` itself when it was provided).

    Raises:
        ValueError: If ``size`` has 0 or more than 3 dimensions, if ``out``
            does not match ``size`` or is not contiguous along its last
            dimension, if ``seeds`` is not 1d, or if the number of seeds
            differs from the number of rows.
    """
    # Accept seeded_uniform((2, 3), ...) as well as seeded_uniform(2, 3, ...),
    # mirroring torch.rand. torch.Size is a tuple subclass, so it is covered.
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = tuple(size[0])

    n_dims = len(size)
    if n_dims == 0:
        raise ValueError("seeded_uniform requires at least one dimension")
    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")

    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif tuple(out.shape) != tuple(size):
        raise ValueError("shape of out and size must be the same")

    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row = out.stride(0)
        stride_3d = out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row = out.stride(0)
        stride_3d = 1
    else:
        n_cols = out.shape[0]
        n_rows = 1
        n_3d = 1
        stride_row = 1
        stride_3d = 1

    # The kernel is only given row and 3d strides and addresses the columns
    # of a row as consecutive elements, so a strided last dimension would
    # be written to the wrong locations.
    if n_cols > 1 and out.stride(-1) != 1:
        raise ValueError("the last dimension of out must be contiguous")

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")

    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # Skip the launch entirely for empty outputs: there is nothing to
    # write, and the block-size derivation below assumes n_cols >= 1.
    if out.numel() == 0:
        return out

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8

    # One program instance per (row, 3d slice) pair.
    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out
@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Generate a random float32 number in [0, 1) for each element in the output
    tensor. The random numbers in a row are generated using the seed for
    that row.

    Each program instance handles one (row, 3D slice) pair, selected by
    program ids 0 and 1. It makes a single tl.rand4x call and stores up to
    four of its outputs into consecutive column ranges of block_size
    elements each, masked to n_cols. Columns are addressed with unit stride.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D.
        n_cols: The number of columns in the output tensor.
        n_slices: The number of philox outputs to use.
        block_size: The number of elements taken from each philox output,
            i.e. the width of one column slice.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Get the row index.
    row_idx = tl.program_id(axis=0)
    three_d_idx = tl.program_id(axis=1)

    philox_offsets = tl.arange(0, block_size)
    # Get the seed for the current element.
    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
    # Derive a distinct seed per 3D slice; slice 0 keeps the row seed as-is.
    if three_d_idx > 0:
        seed ^= three_d_idx
    # Generate random numbers in [0, 1).
    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)

    # Promote the indices to 64 bits for the address computation only, so
    # that index * stride cannot wrap around for very large outputs. The
    # seed derivation above intentionally keeps the original index.
    output_row_start_ptr = (out_ptr +
                            row_idx.to(tl.int64) * out_row_stride +
                            three_d_idx.to(tl.int64) * out_3d_stride)

    out1_offsets = philox_offsets
    tl.store(output_row_start_ptr + out1_offsets,
             out1,
             mask=out1_offsets < n_cols)
    # n_slices is a constexpr, so the unused stores below are dropped at
    # compile time.
    if n_slices > 1:
        out2_offsets = tl.arange(block_size, block_size * 2)
        tl.store(output_row_start_ptr + out2_offsets,
                 out2,
                 mask=out2_offsets < n_cols)
    if n_slices > 2:
        out3_offsets = tl.arange(block_size * 2, block_size * 3)
        tl.store(output_row_start_ptr + out3_offsets,
                 out3,
                 mask=out3_offsets < n_cols)
    if n_slices > 3:
        out4_offsets = tl.arange(block_size * 3, block_size * 4)
        tl.store(output_row_start_ptr + out4_offsets,
                 out4,
                 mask=out4_offsets < n_cols)