Files
2025-08-05 19:02:46 +08:00

99 lines
2.3 KiB
Python

from __future__ import annotations
from ..runtime.jit import jit
from . import core
# -----------------------
# Standard library
# -----------------------
@jit
def cdiv(x, div):
    """
    Computes the ceiling division of :code:`x` by :code:`div`

    :param x: the input number
    :type x: Block
    :param div: the divisor
    :type div: Block
    """
    # Adding `div - 1` before the floor division rounds the quotient up
    # (this identity presumes a positive divisor).
    return (x + div - 1) // div
@jit
@core._add_math_1arg_docstr("sigmoid")
def sigmoid(x):
    # Logistic function 1 / (1 + exp(-x)).
    # NOTE(review): the docstring is presumably attached by the decorator
    # above, so none is written here.
    denominator = 1 + core.exp(-x)
    return 1 / denominator
@jit
@core._add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
    # Subtract the maximum before exponentiating so the largest exponent
    # is exp(0); both reductions pass 0 as their second argument
    # (presumably the reduction axis -- confirm against `core`).
    exps = core.exp(x - core.max(x, 0))
    # `ieee_rounding` is forwarded unchanged to the division helper.
    return core.fdiv(exps, core.sum(exps, 0), ieee_rounding)
@jit
def ravel(x):
    """
    Returns a contiguous flattened view of :code:`x`.

    :param x: the input tensor
    :type x: Block
    """
    # View `x` with the single-entry shape `[x.numel]`, i.e. one dimension.
    return core.view(x, [x.numel])
@jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    Transforms the indices of a row-major :code:`size_i * size_j` matrix into
    those of a matrix that is column-major within each group of
    :code:`size_g` consecutive rows.

    For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it
    will transform ::

        [[0 , 1 , 2 , 3 ],
         [4 , 5 , 6 , 7 ],
         [8 , 9 , 10, 11],
         [12, 13, 14, 15]]

    into ::

        [[0, 2, 4 , 6 ],
         [1, 3, 5 , 7 ],
         [8, 10, 12, 14],
         [9, 11, 13, 15]]

    :param i: row index in the original matrix
    :param j: column index in the original matrix
    :param size_i: number of rows of the matrix
    :param size_j: number of columns of the matrix
    :param size_g: number of rows per group
    :return: the pair :code:`(new_i, new_j)` of transformed row and column
        indices
    """
    # "unrolled" (linear) index of (i, j) in the row-major matrix
    ij = i * size_j + j
    # number of elements in one full group of `size_g` rows
    # of `size_j` columns each
    size_gj = size_g * size_j
    # index of the group in which (i, j) falls
    group_id = ij // size_gj
    # row index of the first row of this group
    off_i = group_id * size_g
    # the last group may have fewer than `size_g` rows; a separate name is
    # used so the `size_g` parameter is not rebound
    rows_in_group = core.minimum(size_i - off_i, size_g)
    # new row and column indices: walk down the rows of the group first,
    # then advance to the next column
    new_i = off_i + (ij % rows_in_group)
    new_j = (ij % size_gj) // rows_in_group
    return new_i, new_j
@jit
def zeros(shape, dtype):
    """
    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    """
    # Delegate to `core.full` with the constant fill value 0.
    return core.full(shape, 0, dtype)
@jit
def zeros_like(input):
    """
    Returns the result of :code:`zeros` called with the shape and dtype of :code:`input`.

    :param input: the tensor whose :code:`shape` and :code:`dtype` are reused
    :type input: Block
    """
    return zeros(input.shape, input.dtype)