First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,201 @@
"""isort:skip_file"""
# Import order is significant here.
from . import math
from . import extra
from .standard import (
cdiv,
sigmoid,
softmax,
ravel,
swizzle2d,
zeros,
zeros_like,
)
from .core import (
abs,
advance,
arange,
argmin,
argmax,
atomic_add,
atomic_and,
atomic_cas,
atomic_max,
atomic_min,
atomic_or,
atomic_xchg,
atomic_xor,
bfloat16,
block_type,
broadcast,
broadcast_to,
cat,
constexpr,
cos,
debug_barrier,
device_assert,
device_print,
dot,
dtype,
exp,
expand_dims,
full,
fdiv,
float16,
float32,
float64,
float8e4,
float8e5,
function_type,
int1,
int16,
int32,
int64,
int8,
load,
log,
make_block_ptr,
max,
max_contiguous,
maximum,
min,
minimum,
multiple_of,
num_programs,
pi32_t,
pointer_type,
program_id,
reduce,
reshape,
sin,
sqrt,
static_assert,
static_print,
store,
sum,
static_range,
tensor,
trans,
triton,
uint16,
uint32,
uint64,
uint8,
umulhi,
view,
void,
where,
xor_sum,
)
from .random import (
pair_uniform_to_normal,
philox,
philox_impl,
rand,
rand4x,
randint,
randint4x,
randn,
randn4x,
uint32_to_uniform_float,
)
# Public names exported by `from <this package> import *`.
# Every entry must be bound in this module by the imports above; "builtin" and
# "ir" were listed without a matching import, which makes a star-import raise
# AttributeError, so they are not listed here.
__all__ = [
    "abs",
    "advance",
    "arange",
    "argmin",
    "argmax",
    "atomic_add",
    "atomic_and",
    "atomic_cas",
    "atomic_max",
    "atomic_min",
    "atomic_or",
    "atomic_xchg",
    "atomic_xor",
    "bfloat16",
    "block_type",
    "broadcast",
    "broadcast_to",
    "cat",
    "cdiv",
    "constexpr",
    "cos",
    "debug_barrier",
    "device_assert",
    "device_print",
    "dot",
    "dtype",
    "exp",
    "expand_dims",
    "extra",
    "fdiv",
    "float16",
    "float32",
    "float64",
    "float8e4",
    "float8e5",
    "full",
    "function_type",
    "int1",
    "int16",
    "int32",
    "int64",
    "int8",
    "math",
    "load",
    "log",
    "make_block_ptr",
    "max",
    "max_contiguous",
    "maximum",
    "min",
    "minimum",
    "multiple_of",
    "num_programs",
    "pair_uniform_to_normal",
    "philox",
    "philox_impl",
    "pi32_t",
    "pointer_type",
    "program_id",
    "rand",
    "rand4x",
    "randint",
    "randint4x",
    "randn",
    "randn4x",
    "ravel",
    "reduce",
    "reshape",
    "sigmoid",
    "sin",
    "softmax",
    "sqrt",
    "static_range",
    "static_assert",
    "static_print",
    "store",
    "sum",
    "swizzle2d",
    "tensor",
    "trans",
    "triton",
    "uint16",
    "uint32",
    "uint32_to_uniform_float",
    "uint64",
    "uint8",
    "umulhi",
    "view",
    "void",
    "where",
    "xor_sum",
    "zeros",
    "zeros_like",
]

Binary file not shown.

Binary file not shown.

1729
pkgs/triton/language/core.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
from . import cuda
# Public API of this package: only the `cuda` submodule imported above.
__all__ = ['cuda']

Binary file not shown.

View File

@@ -0,0 +1,19 @@
import os
from .. import core
# Directory containing this module; the functions below join it with "cuda.bc".
# NOTE(review): `__path__` is the attribute Python packages use for their
# submodule search path; a private name (e.g. `_THIS_DIR`) would be less
# surprising -- confirm nothing else reads it before renaming.
__path__ = os.path.dirname(os.path.abspath(__file__))
@core.extern
def globaltimer(_builder=None):
    """
    Emit an extern call to the "globaltimer" symbol of the "cuda" library.

    The call takes no arguments, is declared to return int64, and is
    registered as impure (`is_pure=False`).

    :param _builder: builder object forwarded unchanged to `core.extern_elementwise`.
    """
    library_file = os.path.join(__path__, "cuda.bc")
    # No arguments, so the only signature key is the empty tuple.
    symbol_table = {tuple(): ("globaltimer", core.dtype("int64"))}
    return core.extern_elementwise("cuda", library_file, [], symbol_table,
                                   is_pure=False, _builder=_builder)
@core.extern
def smid(_builder=None):
    """
    Emit an extern call to the "smid" symbol of the "cuda" library.

    The call takes no arguments, is declared to return int32, and is
    registered as pure (`is_pure=True`).

    :param _builder: builder object forwarded unchanged to `core.extern_elementwise`.
    """
    library_file = os.path.join(__path__, "cuda.bc")
    # No arguments, so the only signature key is the empty tuple.
    symbol_table = {tuple(): ("smid", core.dtype("int32"))}
    return core.extern_elementwise("cuda", library_file, [], symbol_table,
                                   is_pure=True, _builder=_builder)

1537
pkgs/triton/language/math.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,178 @@
import triton
from . import core as tl
# Key increments: `philox_impl` adds these to k0 and k1 at the end of every round.
# NOTE(review): the four values look like the reference Philox-4x32 constants --
# confirm against the Random123 reference before changing any of them.
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
# Round multipliers: `philox_impl` multiplies the counter words c0 / c2 by these.
PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
# -------------------
# randint
# -------------------
@triton.jit
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Apply `n_rounds` Philox rounds to the counter (c0, c1, c2, c3) using the key (k0, k1).

    The key words are advanced locally after each round; only the four
    updated counter words are returned.
    """
    for _ in tl.static_range(n_rounds):
        mult_a = PHILOX_ROUND_A
        mult_b = PHILOX_ROUND_B
        # Keep the pre-round c0 / c2: both are overwritten before their last use.
        old_c0 = c0
        old_c2 = c2
        # Even words: umulhi of multiplier and old word, xor-ed with the
        # neighbouring odd word and one key word.
        c0 = tl.umulhi(mult_b, old_c2) ^ c1 ^ k0
        c2 = tl.umulhi(mult_a, old_c0) ^ c3 ^ k1
        # Odd words: the plain product of the same operands.
        c1 = mult_b * old_c2
        c3 = mult_a * old_c0
        # Advance the key for the next round.
        k0 = k0 + PHILOX_KEY_A
        k1 = k1 + PHILOX_KEY_B
    return c0, c1, c2, c3
@triton.jit
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Run Philox on the counter (c0, c1, c2, c3) with a key derived from `seed`.

    The seed is converted to uint64 and split into its low and high 32-bit
    halves, which become the key words (k0, k1) passed to `philox_impl`.
    The counter words are bit-cast to uint32 first.
    """
    seed = seed.to(tl.uint64)
    key_lo = (seed & 0xffffffff).to(tl.uint32)
    key_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
    # Bit-cast (not value-convert) every counter word to uint32.
    c0 = c0.to(tl.uint32, bitcast=True)
    c1 = c1.to(tl.uint32, bitcast=True)
    c2 = c2.to(tl.uint32, bitcast=True)
    c3 = c3.to(tl.uint32, bitcast=True)
    return philox_impl(c0, c1, c2, c3, key_lo, key_hi, n_rounds)
@triton.jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns a single
    block of random 32-bit integers: the first of the four blocks produced
    by `randint4x`.

    If you need multiple streams of random numbers,
    using `randint4x` is likely to be faster than calling `randint` 4 times.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds to run.
    """
    first_block, _, _, _ = randint4x(seed, offset, n_rounds)
    return first_block
@triton.jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns four
    blocks of random 32-bit integers.

    This is the maximally efficient entry point
    to Triton's Philox pseudo-random number generator.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds to run.
    """
    # `offset` is the first counter word; the other three are zero blocks
    # derived from `offset` itself.
    zero_block = offset * 0
    return philox(seed, offset, zero_block, zero_block, zero_block, n_rounds)
# -------------------
# rand
# -------------------
# @triton.jit
# def uint32_to_uniform_float(x):
# """
# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
# """
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
# return x * two_to_the_minus_32
@triton.jit
def uint32_to_uniform_float(x):
    """
    Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
    """
    # Bit-cast to int32 so the top bit becomes a sign.
    signed = x.to(tl.int32, bitcast=True)
    # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
    scale = 4.6566127342e-10
    # Map negative values onto the non-negative range (-1 -> 0, -2 -> 1, ...).
    folded = tl.where(signed < 0, -signed - 1, signed)
    return folded * scale
@triton.jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`U(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds to run.
    """
    # Bit-cast the offsets to uint32 before drawing the random integers.
    offset = offset.to(tl.uint32, bitcast=True)
    random_bits = randint(seed, offset, n_rounds)
    return uint32_to_uniform_float(random_bits)
@triton.jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offsets` block,
    returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds to run.
    """
    offsets = offsets.to(tl.uint32, bitcast=True)
    bits_a, bits_b, bits_c, bits_d = randint4x(seed, offsets, n_rounds)
    # Convert each integer block independently, in the order produced.
    unif_a = uint32_to_uniform_float(bits_a)
    unif_b = uint32_to_uniform_float(bits_b)
    unif_c = uint32_to_uniform_float(bits_c)
    unif_d = uint32_to_uniform_float(bits_d)
    return unif_a, unif_b, unif_c, unif_d
# -------------------
# randn
# -------------------
@triton.jit
def pair_uniform_to_normal(u1, u2):
    """Box-Muller transform: map two uniform samples to two normal samples."""
    # Clamp u1 from below so the logarithm never sees zero.
    u1 = tl.maximum(1.0e-7, u1)
    radius = tl.sqrt(-2.0 * tl.log(u1))
    # 6.283185307179586 is 2 * pi.
    angle = 6.283185307179586 * u2
    return radius * tl.cos(angle), radius * tl.sin(angle)
@triton.jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds to run.
    """
    # Only the first two of the four integer blocks are needed.
    bits_a, bits_b, _, _ = randint4x(seed, offset, n_rounds)
    unif_a = uint32_to_uniform_float(bits_a)
    unif_b = uint32_to_uniform_float(bits_b)
    # The transform yields a pair; the second sample is discarded.
    normal_a, _ = pair_uniform_to_normal(unif_a, unif_b)
    return normal_a
@triton.jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds to run.
    """
    unif_a, unif_b, unif_c, unif_d = rand4x(seed, offset, n_rounds)
    # Each pair of uniform blocks yields one pair of normal blocks.
    normal_a, normal_b = pair_uniform_to_normal(unif_a, unif_b)
    normal_c, normal_d = pair_uniform_to_normal(unif_c, unif_d)
    return normal_a, normal_b, normal_c, normal_d

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,98 @@
from __future__ import annotations
from ..runtime.jit import jit
from . import core
# -----------------------
# Standard library
# -----------------------
@jit
def cdiv(x, div):
    """
    Computes the ceiling division of :code:`x` by :code:`div`

    :param x: the input number
    :type x: Block
    :param div: the divisor
    :type div: Block
    """
    # Adding `div - 1` before the floor division rounds the quotient up.
    return (x + div - 1) // div
@jit
@core._add_math_1arg_docstr("sigmoid")
def sigmoid(x):
    # Logistic function 1 / (1 + exp(-x)).
    # The docstring is presumably attached by the decorator above.
    denominator = 1 + core.exp(-x)
    return 1 / denominator
@jit
@core._add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
    # Shift by the axis-0 maximum before exponentiating, then divide by the
    # axis-0 sum; `ieee_rounding` is forwarded to `core.fdiv`.
    shifted = x - core.max(x, 0)
    exponentials = core.exp(shifted)
    normalizer = core.sum(exponentials, 0)
    return core.fdiv(exponentials, normalizer, ieee_rounding)
@jit
def ravel(x):
    """
    Returns a contiguous flattened view of :code:`x`.

    :param x: the input tensor
    :type x: Block
    """
    # The target shape is a single dimension holding all `x.numel` elements.
    return core.view(x, [x.numel])
@jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    Transforms indices of a row-major size_i*size_j matrix into those
    of one where indices are column-major within each group of size_g rows.

    For example, for size_i = size_j = 4 and size_g = 2, it will transform
    [[0 , 1 , 2 , 3 ],
     [4 , 5 , 6 , 7 ],
     [8 , 9 , 10, 11],
     [12, 13, 14, 15]]
    into
    [[0, 2,  4 , 6 ],
     [1, 3,  5 , 7 ],
     [8, 10, 12, 14],
     [9, 11, 13, 15]]

    :return: the pair (new_i, new_j) of remapped row and column indices.
    """
    # Flattened row-major position of (i, j).
    flat = i * size_j + j
    # Number of elements in one full group of `size_g` rows.
    group_elems = size_g * size_j
    # Group containing (i, j) and the row index at which that group starts.
    group = flat // group_elems
    first_row = group * size_g
    # The last group may have fewer than `size_g` rows.
    rows_here = core.minimum(size_i - first_row, size_g)
    # Within the group the row index varies fastest.
    new_i = first_row + (flat % rows_here)
    new_j = (flat % group_elems) // rows_here
    return new_i, new_j
@jit
def zeros(shape, dtype):
    """
    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    """
    # Thin wrapper: delegates to `core.full` with fill value 0.
    return core.full(shape, 0, dtype)
@jit
def zeros_like(input):
    """
    Returns a zero-filled tensor built by `zeros` from the :code:`shape` and
    :code:`dtype` of :code:`input`.

    :param input: tensor whose shape and dtype are reused
    """
    return zeros(input.shape, input.dtype)