158 lines
5.0 KiB
Python
158 lines
5.0 KiB
Python
from typing import Optional, Union
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Similar to torch.rand, but allows for seeds to be set per row.

    ``size`` may be given either as separate integers
    (``seeded_uniform(2, 3, seeds=...)``) or as a single sequence
    (``seeded_uniform((2, 3), seeds=...)``), mirroring ``torch.rand``.

    seeds must be a 1d tensor with one element per row. The output tensor may
    be 1d, 2d, or 3d; a 1d output is treated as a single row, so ``seeds``
    must then hold exactly one element. If the output is 3d, the additional
    seeds needed for the middle dimension are derived deterministically by
    XOR-ing the row seed with the middle-dimension index:
    [
        row 0: [columns_with_seed_0], [columns_with_seed_0 ^ 1], ...
    ]

    Args:
        *size: Shape of the output tensor (1 to 3 dimensions).
        seeds: 1D tensor holding one seed per output row.
        out: Optional preallocated output. Its shape must equal ``size`` and
            its last dimension must have unit stride.
        dtype: dtype of the allocated output when ``out`` is None.
        device: device of the allocated output when ``out`` is None.
        pin_memory: forwarded to ``torch.empty`` when ``out`` is None.

    Returns:
        The filled output tensor (``out`` itself when it was provided).

    Raises:
        ValueError: on an unsupported number of dimensions, a shape mismatch
            with ``out``, a non-unit column stride in ``out``, or ``seeds``
            that are not 1D or whose length differs from the number of rows.
    """
    # Accept torch.rand-style calls that pass the shape as one sequence
    # (tuple, list, or torch.Size, which subclasses tuple).
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = tuple(size[0])
    n_dims = len(size)

    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")
    if n_dims == 0:
        raise ValueError("seeded_uniform requires at least a 1D size")

    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif tuple(out.shape) != size:
        raise ValueError("shape of out and size must be the same")

    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row, stride_3d = out.stride(0), out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row, stride_3d = out.stride(0), 1
    else:
        # A 1D output is handled as a single row.
        (n_cols,) = out.shape
        n_rows = n_3d = 1
        stride_row = stride_3d = 1

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")

    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # Nothing to generate. Returning early also avoids launching an empty
    # grid, and avoids n_slices == 0 (next_power_of_2(0) is 0), which the
    # kernel rejects.
    if out.numel() == 0:
        return out

    # The kernel addresses columns as `row_start + col`, i.e. it assumes the
    # last dimension has unit stride; reject layouts it would write wrongly.
    if n_cols > 1 and out.stride(-1) != 1:
        raise ValueError("the last dimension of out must be contiguous")

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size

    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8
    else:
        num_warps = 4

    # One program per (row, middle-dimension index).
    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out
|
|
|
|
|
|
@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Generate a random float32 number in [0, 1) for each element in the output
    tensor. The random numbers in a row are generated using the seed for that
    row.

    Each program handles one (row, 3D-slice) pair: axis 0 of the launch grid
    selects the row and axis 1 selects the index into the second dimension of
    a 3D output.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D.
        n_cols: The number of columns in the output tensor.
        n_slices: The number of philox outputs to use (1 to 4). Output k
            (0-based) is stored to columns [k * block_size, (k + 1) * block_size),
            masked to columns < n_cols.
        block_size: The number of philox offsets per program; each offset
            produces one value for each of the four philox outputs.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Get the row index and the index into the second dimension (3D case).
    row_idx = tl.program_id(axis=0)
    three_d_idx = tl.program_id(axis=1)

    philox_offsets = tl.arange(0, block_size)
    # Get the seed for the current row.
    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
    # Derive a distinct seed for each slice along the second dimension.
    if three_d_idx > 0:
        seed ^= three_d_idx
    # Generate random numbers in [0, 1).
    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)

    # Program ids are int32; promote before multiplying by the strides so the
    # element offset of this row/slice cannot overflow int32 on very large
    # outputs.
    output_row_start_ptr = (out_ptr +
                            row_idx.to(tl.int64) * out_row_stride +
                            three_d_idx.to(tl.int64) * out_3d_stride)

    # n_slices is a compile-time constant, so unused slices are pruned.
    out1_offsets = philox_offsets
    tl.store(output_row_start_ptr + out1_offsets,
             out1,
             mask=out1_offsets < n_cols)
    if n_slices > 1:
        out2_offsets = tl.arange(block_size, block_size * 2)
        tl.store(output_row_start_ptr + out2_offsets,
                 out2,
                 mask=out2_offsets < n_cols)
    if n_slices > 2:
        out3_offsets = tl.arange(block_size * 2, block_size * 3)
        tl.store(output_row_start_ptr + out3_offsets,
                 out3,
                 mask=out3_offsets < n_cols)
    if n_slices > 3:
        out4_offsets = tl.arange(block_size * 3, block_size * 4)
        tl.store(output_row_start_ptr + out4_offsets,
                 out4,
                 mask=out4_offsets < n_cols)