First commit
This commit is contained in:
98
pkgs/triton/language/standard.py
Normal file
98
pkgs/triton/language/standard.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..runtime.jit import jit
|
||||
from . import core
|
||||
|
||||
# -----------------------
|
||||
# Standard library
|
||||
# -----------------------
|
||||
|
||||
|
||||
@jit
|
||||
def cdiv(x, div):
|
||||
"""
|
||||
Computes the ceiling division of :code:`x` by :code:`div`
|
||||
|
||||
:param x: the input number
|
||||
:type input: Block
|
||||
:param div: the divisor
|
||||
:param div: Block
|
||||
"""
|
||||
return (x + div - 1) // div
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_math_1arg_docstr("sigmoid")
|
||||
def sigmoid(x):
|
||||
return 1 / (1 + core.exp(-x))
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_math_1arg_docstr("softmax")
|
||||
def softmax(x, ieee_rounding=False):
|
||||
z = x - core.max(x, 0)
|
||||
num = core.exp(z)
|
||||
den = core.sum(num, 0)
|
||||
return core.fdiv(num, den, ieee_rounding)
|
||||
|
||||
|
||||
@jit
|
||||
def ravel(x):
|
||||
"""
|
||||
Returns a contiguous flattened view of :code:`x`.
|
||||
|
||||
:param x: the input tensor
|
||||
:type x: Block
|
||||
"""
|
||||
return core.view(x, [x.numel])
|
||||
|
||||
|
||||
@jit
|
||||
def swizzle2d(i, j, size_i, size_j, size_g):
|
||||
"""
|
||||
Transforms indices of a row-major size_i*size_j matrix into those
|
||||
of one where indices are row major for each group of size_j rows.
|
||||
For example, for size_i = size_j = 4 and size_g = 2, it will transform
|
||||
[[0 , 1 , 2 , 3 ],
|
||||
[4 , 5 , 6 , 7 ],
|
||||
[8 , 9 , 10, 11],
|
||||
[12, 13, 14, 15]]
|
||||
into
|
||||
[[0, 2, 4 , 6 ],
|
||||
[1, 3, 5 , 7 ],
|
||||
[8, 10, 12, 14],
|
||||
[9, 11, 13, 15]]
|
||||
"""
|
||||
# "unrolled index in array"
|
||||
ij = i * size_j + j
|
||||
# number of elements in `size_g` groups
|
||||
# of `size_j` columns
|
||||
size_gj = size_g * size_j
|
||||
# index of the group in which (i,j) is
|
||||
group_id = ij // size_gj
|
||||
# row-index of the first element of this group
|
||||
off_i = group_id * size_g
|
||||
# last group may have fewer rows
|
||||
size_g = core.minimum(size_i - off_i, size_g)
|
||||
# new row and column indices
|
||||
new_i = off_i + (ij % size_g)
|
||||
new_j = (ij % size_gj) // size_g
|
||||
return new_i, new_j
|
||||
|
||||
|
||||
@jit
|
||||
def zeros(shape, dtype):
|
||||
"""
|
||||
Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
|
||||
|
||||
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
|
||||
:type shape: tuple of ints
|
||||
:param dtype: Data-type of the new array, e.g., :code:`tl.float16`
|
||||
:type dtype: DType
|
||||
"""
|
||||
return core.full(shape, 0, dtype)
|
||||
|
||||
|
||||
@jit
|
||||
def zeros_like(input):
|
||||
return zeros(input.shape, input.dtype)
|
||||
Reference in New Issue
Block a user