First commit
This commit is contained in:
201
pkgs/triton/language/__init__.py
Normal file
201
pkgs/triton/language/__init__.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""isort:skip_file"""
|
||||
# Import order is significant here.
|
||||
|
||||
from . import math
|
||||
from . import extra
|
||||
from .standard import (
|
||||
cdiv,
|
||||
sigmoid,
|
||||
softmax,
|
||||
ravel,
|
||||
swizzle2d,
|
||||
zeros,
|
||||
zeros_like,
|
||||
)
|
||||
from .core import (
|
||||
abs,
|
||||
advance,
|
||||
arange,
|
||||
argmin,
|
||||
argmax,
|
||||
atomic_add,
|
||||
atomic_and,
|
||||
atomic_cas,
|
||||
atomic_max,
|
||||
atomic_min,
|
||||
atomic_or,
|
||||
atomic_xchg,
|
||||
atomic_xor,
|
||||
bfloat16,
|
||||
block_type,
|
||||
broadcast,
|
||||
broadcast_to,
|
||||
cat,
|
||||
constexpr,
|
||||
cos,
|
||||
debug_barrier,
|
||||
device_assert,
|
||||
device_print,
|
||||
dot,
|
||||
dtype,
|
||||
exp,
|
||||
expand_dims,
|
||||
full,
|
||||
fdiv,
|
||||
float16,
|
||||
float32,
|
||||
float64,
|
||||
float8e4,
|
||||
float8e5,
|
||||
function_type,
|
||||
int1,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
int8,
|
||||
load,
|
||||
log,
|
||||
make_block_ptr,
|
||||
max,
|
||||
max_contiguous,
|
||||
maximum,
|
||||
min,
|
||||
minimum,
|
||||
multiple_of,
|
||||
num_programs,
|
||||
pi32_t,
|
||||
pointer_type,
|
||||
program_id,
|
||||
reduce,
|
||||
reshape,
|
||||
sin,
|
||||
sqrt,
|
||||
static_assert,
|
||||
static_print,
|
||||
store,
|
||||
sum,
|
||||
static_range,
|
||||
tensor,
|
||||
trans,
|
||||
triton,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
uint8,
|
||||
umulhi,
|
||||
view,
|
||||
void,
|
||||
where,
|
||||
xor_sum,
|
||||
)
|
||||
from .random import (
|
||||
pair_uniform_to_normal,
|
||||
philox,
|
||||
philox_impl,
|
||||
rand,
|
||||
rand4x,
|
||||
randint,
|
||||
randint4x,
|
||||
randn,
|
||||
randn4x,
|
||||
uint32_to_uniform_float,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"abs",
|
||||
"advance",
|
||||
"arange",
|
||||
"argmin",
|
||||
"argmax",
|
||||
"atomic_add",
|
||||
"atomic_and",
|
||||
"atomic_cas",
|
||||
"atomic_max",
|
||||
"atomic_min",
|
||||
"atomic_or",
|
||||
"atomic_xchg",
|
||||
"atomic_xor",
|
||||
"bfloat16",
|
||||
"block_type",
|
||||
"broadcast",
|
||||
"broadcast_to",
|
||||
"builtin",
|
||||
"cat",
|
||||
"cdiv",
|
||||
"constexpr",
|
||||
"cos",
|
||||
"debug_barrier",
|
||||
"device_assert",
|
||||
"device_print",
|
||||
"dot",
|
||||
"dtype",
|
||||
"exp",
|
||||
"expand_dims",
|
||||
"extra",
|
||||
"fdiv",
|
||||
"float16",
|
||||
"float32",
|
||||
"float64",
|
||||
"float8e4",
|
||||
"float8e5",
|
||||
"full",
|
||||
"function_type",
|
||||
"int1",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"int8",
|
||||
"ir",
|
||||
"math",
|
||||
"load",
|
||||
"log",
|
||||
"make_block_ptr",
|
||||
"max",
|
||||
"max_contiguous",
|
||||
"maximum",
|
||||
"min",
|
||||
"minimum",
|
||||
"multiple_of",
|
||||
"num_programs",
|
||||
"pair_uniform_to_normal",
|
||||
"philox",
|
||||
"philox_impl",
|
||||
"pi32_t",
|
||||
"pointer_type",
|
||||
"program_id",
|
||||
"rand",
|
||||
"rand4x",
|
||||
"randint",
|
||||
"randint4x",
|
||||
"randn",
|
||||
"randn4x",
|
||||
"ravel",
|
||||
"reduce",
|
||||
"reshape",
|
||||
"sigmoid",
|
||||
"sin",
|
||||
"softmax",
|
||||
"sqrt",
|
||||
"static_range",
|
||||
"static_assert",
|
||||
"static_print",
|
||||
"store",
|
||||
"sum",
|
||||
"swizzle2d",
|
||||
"tensor",
|
||||
"trans",
|
||||
"triton",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint32_to_uniform_float",
|
||||
"uint64",
|
||||
"uint8",
|
||||
"umulhi",
|
||||
"view",
|
||||
"void",
|
||||
"where",
|
||||
"xor_sum",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
]
|
||||
BIN
pkgs/triton/language/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/__pycache__/core.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/core.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/__pycache__/math.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/math.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/__pycache__/random.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/random.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/__pycache__/semantic.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/semantic.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/__pycache__/standard.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/standard.cpython-310.pyc
Normal file
Binary file not shown.
1729
pkgs/triton/language/core.py
Normal file
1729
pkgs/triton/language/core.py
Normal file
File diff suppressed because it is too large
Load Diff
3
pkgs/triton/language/extra/__init__.py
Normal file
3
pkgs/triton/language/extra/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from . import cuda
|
||||
|
||||
__all__ = ['cuda']
|
||||
BIN
pkgs/triton/language/extra/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/extra/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/extra/__pycache__/cuda.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/extra/__pycache__/cuda.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/extra/cuda.bc
Normal file
BIN
pkgs/triton/language/extra/cuda.bc
Normal file
Binary file not shown.
19
pkgs/triton/language/extra/cuda.py
Normal file
19
pkgs/triton/language/extra/cuda.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
|
||||
from .. import core
|
||||
|
||||
__path__ = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
@core.extern
|
||||
def globaltimer(_builder=None):
|
||||
return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [],
|
||||
{tuple(): ("globaltimer", core.dtype("int64")),
|
||||
}, is_pure=False, _builder=_builder)
|
||||
|
||||
|
||||
@core.extern
|
||||
def smid(_builder=None):
|
||||
return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [],
|
||||
{tuple(): ("smid", core.dtype("int32")),
|
||||
}, is_pure=True, _builder=_builder)
|
||||
1537
pkgs/triton/language/math.py
Normal file
1537
pkgs/triton/language/math.py
Normal file
File diff suppressed because it is too large
Load Diff
178
pkgs/triton/language/random.py
Normal file
178
pkgs/triton/language/random.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import triton
|
||||
from . import core as tl
|
||||
|
||||
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
|
||||
PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
|
||||
PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
|
||||
PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
|
||||
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
|
||||
|
||||
# -------------------
|
||||
# randint
|
||||
# -------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
|
||||
"""
|
||||
for _ in tl.static_range(n_rounds):
|
||||
# for _ in range(n_rounds):
|
||||
# update random state
|
||||
A = PHILOX_ROUND_A
|
||||
B = PHILOX_ROUND_B
|
||||
_c0, _c2 = c0, c2
|
||||
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
|
||||
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
|
||||
c1 = B * _c2
|
||||
c3 = A * _c0
|
||||
# raise key
|
||||
k0 = k0 + PHILOX_KEY_A
|
||||
k1 = k1 + PHILOX_KEY_B
|
||||
return c0, c1, c2, c3
|
||||
|
||||
|
||||
@triton.jit
|
||||
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
seed = seed.to(tl.uint64)
|
||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
|
||||
seed_lo = (seed & 0xffffffff).to(tl.uint32)
|
||||
c0 = c0.to(tl.uint32, bitcast=True)
|
||||
c1 = c1.to(tl.uint32, bitcast=True)
|
||||
c2 = c2.to(tl.uint32, bitcast=True)
|
||||
c3 = c3.to(tl.uint32, bitcast=True)
|
||||
return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
|
||||
block of random :code:`int32`.
|
||||
|
||||
If you need multiple streams of random numbers,
|
||||
using `randint4x` is likely to be faster than calling `randint` 4 times.
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
ret, _, _, _ = randint4x(seed, offset, n_rounds)
|
||||
return ret
|
||||
|
||||
|
||||
@triton.jit
|
||||
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block, returns four
|
||||
blocks of random :code:`int32`.
|
||||
|
||||
This is the maximally efficient entry point
|
||||
to Triton's Philox pseudo-random number generator.
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
# _0 = tl.zeros(offset.shape, offset.dtype)
|
||||
_0 = offset * 0
|
||||
return philox(seed, offset, _0, _0, _0, n_rounds)
|
||||
|
||||
|
||||
# -------------------
|
||||
# rand
|
||||
# -------------------
|
||||
|
||||
# @triton.jit
|
||||
# def uint32_to_uniform_float(x):
|
||||
# """
|
||||
# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||
# """
|
||||
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
||||
# return x * two_to_the_minus_32
|
||||
|
||||
@triton.jit
|
||||
def uint32_to_uniform_float(x):
|
||||
"""
|
||||
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||
"""
|
||||
x = x.to(tl.int32, bitcast=True)
|
||||
# maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
|
||||
scale = 4.6566127342e-10
|
||||
x = tl.where(x < 0, -x - 1, x)
|
||||
return x * scale
|
||||
|
||||
|
||||
@triton.jit
|
||||
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
returns a block of random :code:`float32` in :math:`U(0, 1)`.
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
offset = offset.to(tl.uint32, bitcast=True)
|
||||
source = randint(seed, offset, n_rounds)
|
||||
return uint32_to_uniform_float(source)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offsets` block,
|
||||
returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`.
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
offsets = offsets.to(tl.uint32, bitcast=True)
|
||||
i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
|
||||
u1 = uint32_to_uniform_float(i1)
|
||||
u2 = uint32_to_uniform_float(i2)
|
||||
u3 = uint32_to_uniform_float(i3)
|
||||
u4 = uint32_to_uniform_float(i4)
|
||||
return u1, u2, u3, u4
|
||||
|
||||
# -------------------
|
||||
# randn
|
||||
# -------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def pair_uniform_to_normal(u1, u2):
|
||||
"""Box-Muller transform"""
|
||||
u1 = tl.maximum(1.0e-7, u1)
|
||||
th = 6.283185307179586 * u2
|
||||
r = tl.sqrt(-2.0 * tl.log(u1))
|
||||
return r * tl.cos(th), r * tl.sin(th)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
i1, i2, _, _ = randint4x(seed, offset, n_rounds)
|
||||
u1 = uint32_to_uniform_float(i1)
|
||||
u2 = uint32_to_uniform_float(i2)
|
||||
n1, _ = pair_uniform_to_normal(u1, u2)
|
||||
return n1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
|
||||
n1, n2 = pair_uniform_to_normal(u1, u2)
|
||||
n3, n4 = pair_uniform_to_normal(u3, u4)
|
||||
return n1, n2, n3, n4
|
||||
1491
pkgs/triton/language/semantic.py
Normal file
1491
pkgs/triton/language/semantic.py
Normal file
File diff suppressed because it is too large
Load Diff
98
pkgs/triton/language/standard.py
Normal file
98
pkgs/triton/language/standard.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..runtime.jit import jit
|
||||
from . import core
|
||||
|
||||
# -----------------------
|
||||
# Standard library
|
||||
# -----------------------
|
||||
|
||||
|
||||
@jit
|
||||
def cdiv(x, div):
|
||||
"""
|
||||
Computes the ceiling division of :code:`x` by :code:`div`
|
||||
|
||||
:param x: the input number
|
||||
:type input: Block
|
||||
:param div: the divisor
|
||||
:param div: Block
|
||||
"""
|
||||
return (x + div - 1) // div
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_math_1arg_docstr("sigmoid")
|
||||
def sigmoid(x):
|
||||
return 1 / (1 + core.exp(-x))
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_math_1arg_docstr("softmax")
|
||||
def softmax(x, ieee_rounding=False):
|
||||
z = x - core.max(x, 0)
|
||||
num = core.exp(z)
|
||||
den = core.sum(num, 0)
|
||||
return core.fdiv(num, den, ieee_rounding)
|
||||
|
||||
|
||||
@jit
|
||||
def ravel(x):
|
||||
"""
|
||||
Returns a contiguous flattened view of :code:`x`.
|
||||
|
||||
:param x: the input tensor
|
||||
:type x: Block
|
||||
"""
|
||||
return core.view(x, [x.numel])
|
||||
|
||||
|
||||
@jit
|
||||
def swizzle2d(i, j, size_i, size_j, size_g):
|
||||
"""
|
||||
Transforms indices of a row-major size_i*size_j matrix into those
|
||||
of one where indices are row major for each group of size_j rows.
|
||||
For example, for size_i = size_j = 4 and size_g = 2, it will transform
|
||||
[[0 , 1 , 2 , 3 ],
|
||||
[4 , 5 , 6 , 7 ],
|
||||
[8 , 9 , 10, 11],
|
||||
[12, 13, 14, 15]]
|
||||
into
|
||||
[[0, 2, 4 , 6 ],
|
||||
[1, 3, 5 , 7 ],
|
||||
[8, 10, 12, 14],
|
||||
[9, 11, 13, 15]]
|
||||
"""
|
||||
# "unrolled index in array"
|
||||
ij = i * size_j + j
|
||||
# number of elements in `size_g` groups
|
||||
# of `size_j` columns
|
||||
size_gj = size_g * size_j
|
||||
# index of the group in which (i,j) is
|
||||
group_id = ij // size_gj
|
||||
# row-index of the first element of this group
|
||||
off_i = group_id * size_g
|
||||
# last group may have fewer rows
|
||||
size_g = core.minimum(size_i - off_i, size_g)
|
||||
# new row and column indices
|
||||
new_i = off_i + (ij % size_g)
|
||||
new_j = (ij % size_gj) // size_g
|
||||
return new_i, new_j
|
||||
|
||||
|
||||
@jit
|
||||
def zeros(shape, dtype):
|
||||
"""
|
||||
Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
|
||||
|
||||
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
|
||||
:type shape: tuple of ints
|
||||
:param dtype: Data-type of the new array, e.g., :code:`tl.float16`
|
||||
:type dtype: DType
|
||||
"""
|
||||
return core.full(shape, 0, dtype)
|
||||
|
||||
|
||||
@jit
|
||||
def zeros_like(input):
|
||||
return zeros(input.shape, input.dtype)
|
||||
Reference in New Issue
Block a user