# File-viewer metadata (not part of the source): 99 lines, 2.3 KiB, Python
from __future__ import annotations
|
|
|
|
from ..runtime.jit import jit
|
|
from . import core
|
|
|
|
# -----------------------
|
|
# Standard library
|
|
# -----------------------
|
|
|
|
|
|
@jit
def cdiv(x, div):
    """
    Computes the ceiling division of :code:`x` by :code:`div`

    :param x: the input number
    :type x: Block
    :param div: the divisor
    :type div: Block
    """
    # Adding `div - 1` before `//` bumps every inexact quotient up by one.
    # NOTE(review): this identity presumably assumes non-negative x and
    # positive div -- confirm against callers.
    return (x + div - 1) // div
|
|
|
|
|
|
@jit
@core._add_math_1arg_docstr("sigmoid")
def sigmoid(x):
    # Logistic function 1 / (1 + e^(-x)). The user-facing docstring is
    # presumably attached by the `_add_math_1arg_docstr` decorator, so only
    # plain comments are used here.
    denominator = 1 + core.exp(-x)
    return 1 / denominator
|
|
|
|
|
|
@jit
@core._add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
    # Shift by the maximum before exponentiating (the `0` passed to
    # `core.max` is presumably the reduction axis); softmax is invariant to
    # this shift and it keeps the exponent arguments from growing large.
    shifted = x - core.max(x, 0)
    exponentials = core.exp(shifted)
    # Normalise by the sum taken with the same second argument (0).
    # `ieee_rounding` is forwarded unchanged to `core.fdiv`.
    normalizer = core.sum(exponentials, 0)
    return core.fdiv(exponentials, normalizer, ieee_rounding)
|
|
|
|
|
|
@jit
def ravel(x):
    """
    Returns a contiguous flattened view of :code:`x`.

    :param x: the input tensor
    :type x: Block
    """
    # The requested shape is one-dimensional with `x.numel` entries, i.e. all
    # elements of `x` laid out along a single axis.
    return core.view(x, [x.numel])
|
|
|
|
|
|
@jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    Transforms indices of a row-major size_i*size_j matrix into those
    of one where indices are column-major within each group of size_g rows.
    For example, for size_i = size_j = 4 and size_g = 2, it will transform
    [[0 , 1 , 2 , 3 ],
     [4 , 5 , 6 , 7 ],
     [8 , 9 , 10, 11],
     [12, 13, 14, 15]]
    into
    [[0, 2,  4 , 6 ],
     [1, 3,  5 , 7 ],
     [8, 10, 12, 14],
     [9, 11, 13, 15]]

    :param i: row index in the original row-major layout
    :param j: column index in the original row-major layout
    :param size_i: number of rows of the matrix
    :param size_j: number of columns of the matrix
    :param size_g: number of rows in each group
    :return: the pair :code:`(new_i, new_j)` of transformed row and column indices
    """
    # "unrolled index in array"
    ij = i * size_j + j
    # number of elements in one full group, i.e. `size_g` rows
    # of `size_j` columns each
    size_gj = size_g * size_j
    # index of the group in which (i,j) is
    group_id = ij // size_gj
    # row-index of the first element of this group
    off_i = group_id * size_g
    # last group may have fewer rows; note that `size_g` is rebound here,
    # after `size_gj` and `off_i` were computed from the nominal group size,
    # so the statement order matters
    size_g = core.minimum(size_i - off_i, size_g)
    # new row and column indices (both use the clamped `size_g`)
    new_i = off_i + (ij % size_g)
    new_j = (ij % size_gj) // size_g
    return new_i, new_j
|
|
|
|
|
|
@jit
def zeros(shape, dtype):
    """
    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    :return: the result of :code:`core.full(shape, 0, dtype)`
    """
    # Thin wrapper: delegates to `core.full` with the fill value fixed to 0.
    return core.full(shape, 0, dtype)
|
|
|
|
|
|
@jit
def zeros_like(input):
    """
    Returns the result of :code:`zeros` called with the shape and dtype of :code:`input`.

    :param input: the object whose :code:`shape` and :code:`dtype` are reused
        (presumably a Block -- confirm against callers)
    :return: the value produced by :code:`zeros(input.shape, input.dtype)`
    """
    # NOTE: the parameter name `input` shadows the Python builtin; it is left
    # as-is because it is part of the function's signature.
    return zeros(input.shape, input.dtype)
|