First commit
This commit is contained in:
21
pkgs/triton/runtime/__init__.py
Normal file
21
pkgs/triton/runtime/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
|
||||
heuristics)
|
||||
from .driver import driver
|
||||
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
|
||||
version_key)
|
||||
|
||||
__all__ = [
|
||||
"driver",
|
||||
"Config",
|
||||
"Heuristics",
|
||||
"autotune",
|
||||
"heuristics",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"version_key",
|
||||
"reinterpret",
|
||||
"TensorWrapper",
|
||||
"OutOfResources",
|
||||
"MockTensor",
|
||||
"Autotuner",
|
||||
]
|
||||
BIN
pkgs/triton/runtime/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/runtime/__pycache__/autotuner.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/autotuner.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/runtime/__pycache__/cache.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/cache.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/runtime/__pycache__/driver.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/driver.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/runtime/__pycache__/errors.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/errors.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/runtime/__pycache__/jit.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/jit.cpython-310.pyc
Normal file
Binary file not shown.
305
pkgs/triton/runtime/autotuner.py
Normal file
305
pkgs/triton/runtime/autotuner.py
Normal file
@@ -0,0 +1,305 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from ..testing import do_bench
|
||||
from .jit import KernelInterface
|
||||
from .cache import default_cache_dir
|
||||
|
||||
|
||||
def build_best_config_hash(args_names, key):
    """Return the on-disk location of the cached best config for a tuning key.

    The location is derived from the SHA-256 of the underscore-joined
    argument names followed by ``str(key)`` and a newline.

    :param args_names: iterable of kernel argument names (strings).
    :param key: the autotuning key; only its ``str()`` form is hashed.
    :return: ``(directory, file)`` where ``directory`` is
        ``<cache root>/<digest>`` and ``file`` is
        ``<directory>/<digest>.best_config``. The cache root is
        ``$TRITON_CACHE_DIR`` when set, otherwise ``default_cache_dir()``.
    """
    # The default is evaluated eagerly, exactly as a plain `.get(k, default)` does.
    cache_root = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
    fingerprint = f"{'_'.join(args_names) + str(key)}\n"
    digest = hashlib.sha256(fingerprint.encode()).hexdigest()
    config_dir = os.path.join(cache_root, digest)
    # A hex digest has no extension, so splitext() leaves it unchanged.
    config_name = os.path.splitext(digest)[0] + ".best_config"
    return config_dir, os.path.join(config_dir, config_name)
|
||||
|
||||
|
||||
def load_best_config(args_names, key):
    """Load a previously saved best config for ``(args_names, key)``.

    :param args_names: kernel argument names, forwarded to
        ``build_best_config_hash``.
    :param key: autotuning key, forwarded to ``build_best_config_hash``.
    :return: ``(meta_kwargs, num_warps, num_stages)`` when a usable cache file
        exists, otherwise ``None``. ``num_warps`` defaults to 4 and
        ``num_stages`` to 1 when the file does not record them.
    """
    _, cfg_hash_file = build_best_config_hash(args_names, key)
    try:
        # EAFP: avoids the exists()/open() race with concurrent writers.
        with open(cfg_hash_file) as fd:
            best_config = json.load(fd)
    except (OSError, ValueError):
        # Missing, unreadable, empty or corrupt file (json.JSONDecodeError is
        # a ValueError): treat it as a cache miss so the caller re-tunes
        # instead of crashing.
        return None
    if not isinstance(best_config, dict):
        return None
    num_warps = best_config.pop('num_warps', 4)
    num_stages = best_config.pop('num_stages', 1)
    return best_config, num_warps, num_stages
|
||||
|
||||
|
||||
def save_best_config(cfg, args_names, key):
    """Persist the winning config so later processes can skip benchmarking.

    The file stores ``cfg.kwargs`` plus ``num_warps`` and ``num_stages`` as a
    JSON object. Nothing is written when the file already exists.

    :param cfg: object exposing ``kwargs`` (dict), ``num_warps`` and
        ``num_stages``.
    :param args_names: kernel argument names, forwarded to
        ``build_best_config_hash``.
    :param key: autotuning key, forwarded to ``build_best_config_hash``.
    :raises TypeError: if the config contains values JSON cannot encode.
    """
    cfg_hash_dir, cfg_hash_file = build_best_config_hash(args_names, key)
    # Test the file, not the directory: a directory left behind by an
    # interrupted or failed write must not block saving forever.
    if os.path.exists(cfg_hash_file):
        return
    # Serialise before touching the filesystem so an encoding error cannot
    # leave an empty file behind.
    payload = json.dumps(
        {
            **cfg.kwargs,
            "num_warps": cfg.num_warps,
            "num_stages": cfg.num_stages,
        }
    )
    os.makedirs(cfg_hash_dir, exist_ok=True)
    # Write to a per-process temp file and rename it into place so readers
    # never observe a partially written config.
    tmp_file = f"{cfg_hash_file}.tmp.pid_{os.getpid()}"
    with open(tmp_file, "w") as fd:
        fd.write(payload)
    os.replace(tmp_file, cfg_hash_file)
|
||||
|
||||
|
||||
class OutOfResources(Exception):
    """Raised when a kernel needs more of a hardware resource than is available.

    :ivar required: amount of the resource that was requested.
    :ivar limit: hardware limit for that resource.
    :ivar name: name of the resource.
    :ivar message: the formatted, human-readable error text.
    """

    def __init__(self, required, limit, name):
        self.required = required
        self.limit = limit
        self.name = name
        self.message = (
            f'out of resource: {name}, '
            f'Required: {required}, '
            f'Hardware limit: {limit}'
            '. Reducing block sizes or `num_stages` may help.'
        )
        super().__init__(self.message)

    def __reduce__(self):
        # Rebuild from the three constructor arguments so that pickling works;
        # the default Exception reduction would pass only the message.
        ctor_args = (self.required, self.limit, self.name)
        return (type(self), ctor_args)
|
||||
|
||||
|
||||
class Autotuner(KernelInterface):
    """Kernel wrapper that benchmarks candidate configs and launches the best one.

    When more than one config is supplied, :meth:`run` derives a tuning key
    from the call arguments, consults the in-memory cache and then the on-disk
    best-config file, and benchmarks the (optionally pruned) configs only when
    both miss. With a single config no tuning takes place.
    """

    def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
        '''
        :param fn: the wrapped kernel; must expose ``run`` and ``warmup``.
        :param arg_names: ordered names of the kernel's arguments.
        :param configs: candidate :class:`Config` objects; when empty a single
            default ``Config({}, num_warps=4, num_stages=2)`` is used.
        :param key: argument names whose values are part of the tuning key.
        :param reset_to_zero: argument names whose values are zeroed (via
            ``zero_()``) before each launch, or ``None``.
        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
            'perf_model': performance model used to predicate running time with different configs, returns running time
            'top_k': number of configs to bench
            'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It is called with the
            list of configs and the named arguments, and returns pruned configs.
        '''
        if not configs:
            self.configs = [Config({}, num_warps=4, num_stages=2)]
        else:
            self.configs = configs
        self.key_idx = [arg_names.index(k) for k in key]
        self.cache = {}
        # hook to reset all required tensor to zeros before relaunching a kernel
        self.hook = lambda args: 0
        if reset_to_zero is not None:
            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]

            def _hook(args):
                for i in self.reset_idx:
                    args[i].zero_()
            self.hook = _hook
        self.arg_names = arg_names
        # prune configs
        if prune_configs_by:
            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
            # 'early_config_prune' is optional: default to None rather than
            # leaving the local unbound (which raised UnboundLocalError below).
            early_config_prune = prune_configs_by.get('early_config_prune')
        else:
            perf_model, top_k, early_config_prune = None, None, None
        self.perf_model, self.configs_top_k = perf_model, top_k
        self.early_config_prune = early_config_prune
        self.fn = fn

    def _bench(self, *args, config, **meta):
        """Benchmark one config and return ``do_bench``'s timings.

        :raises ValueError: if a meta-parameter is supplied both by the caller
            and by ``config``.
        :return: the result of ``do_bench`` for quantiles (0.5, 0.2, 0.8), or
            three infinities when the launch raises :class:`OutOfResources`.
        """
        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(conflicts)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.kwargs)

        def kernel_call():
            if config.pre_hook:
                config.pre_hook(self.nargs)
            self.hook(args)
            self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
        try:
            return do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
        except OutOfResources:
            # A config that does not fit the hardware can never win.
            return [float('inf'), float('inf'), float('inf')]

    def run(self, *args, **kwargs):
        """Launch the kernel with the best known config, tuning first if needed.

        Sets ``self.best_config`` and, after a benchmarking pass, also
        ``self.bench_time`` and ``self.configs_timings``.

        :return: whatever ``self.fn.run`` returns.
        """
        self.nargs = dict(zip(self.arg_names, args))
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = [all_args[name] for name in self.arg_names if name in all_args]
            key = [_args[i] for i in self.key_idx]
            # Besides the user-selected key arguments, the key records for each
            # positional argument: dtype and 16-byte alignment of anything with
            # a `data_ptr`, and the value of every int.
            divisibility = 16
            for arg in args:
                if hasattr(arg, "data_ptr"):
                    key.append(arg.dtype)
                    key.append(arg.data_ptr() % divisibility == 0)
                elif isinstance(arg, int):
                    key.append(arg)
            key = tuple(key)
            if key not in self.cache:
                load_config = load_best_config(self.arg_names, key)
                if load_config:
                    # A config restored from disk carries no pre_hook.
                    best_config, num_warps, num_stages = load_config
                    config = Config(best_config, num_warps, num_stages)
                    self.cache[key] = config
                    self.hook(args)
                else:
                    # prune configs
                    pruned_configs = self.prune_configs(kwargs)
                    bench_start = time.time()
                    timings = {config: self._bench(*args, config=config, **kwargs)
                               for config in pruned_configs}
                    bench_end = time.time()
                    self.bench_time = bench_end - bench_start
                    # `builtins.min` is spelled out explicitly; the winner is
                    # the config whose timing list compares smallest.
                    self.cache[key] = builtins.min(timings, key=timings.get)
                    save_best_config(self.cache[key], self.arg_names, key)
                    self.hook(args)
                    self.configs_timings = timings
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.best_config = config
        if config.pre_hook is not None:
            config.pre_hook(self.nargs)
        self.nargs = None
        return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

    def prune_configs(self, kwargs):
        """Return the configs worth benchmarking for the current call.

        Applies ``early_config_prune`` (if any), then keeps the ``top_k``
        configs ranked by ``perf_model`` (if any). A float ``top_k`` <= 1.0 is
        interpreted as a fraction of ``len(self.configs)``. Relies on
        ``self.nargs`` having been set by the caller.
        """
        pruned_configs = self.configs
        if self.early_config_prune:
            pruned_configs = self.early_config_prune(self.configs, self.nargs)
        if self.perf_model:
            top_k = self.configs_top_k
            if isinstance(top_k, float) and top_k <= 1.0:
                top_k = int(len(self.configs) * top_k)
            if len(pruned_configs) > top_k:
                est_timing = {
                    config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
                                            num_warps=config.num_warps)
                    for config in pruned_configs
                }
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
        return pruned_configs

    def warmup(self, *args, **kwargs):
        """Call ``self.fn.warmup`` once for every config that survives pruning."""
        self.nargs = dict(zip(self.arg_names, args))
        for config in self.prune_configs(kwargs):
            self.fn.warmup(
                *args,
                num_warps=config.num_warps,
                num_stages=config.num_stages,
                **kwargs,
                **config.kwargs,
            )
        self.nargs = None
|
||||
|
||||
|
||||
class Config:
    """
    A candidate kernel configuration for the auto-tuner to evaluate.

    :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
    :type kwargs: dict[str, Any]
    :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
                      `num_warps=8`, then each kernel instance will be automatically parallelized to
                      cooperatively execute using `8 * 32 = 256` threads.
    :type num_warps: int
    :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
                       Mostly useful for matrix multiplication workloads on SM80+ GPUs.
    :type num_stages: int
    :ivar pre_hook: a function that will be called before the kernel is called. It receives the
                    kernel's named arguments.
    """

    def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
        self.kwargs, self.pre_hook = kwargs, pre_hook
        self.num_warps, self.num_stages = num_warps, num_stages

    def __str__(self):
        # Meta-parameters first (in dict order), launch parameters last.
        fields = [f'{name}: {value}' for name, value in self.kwargs.items()]
        fields += [f'num_warps: {self.num_warps}', f'num_stages: {self.num_stages}']
        return ', '.join(fields)
|
||||
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
    """
    Decorator for auto-tuning a :code:`triton.jit`'d function.

    .. highlight:: python
    .. code-block:: python

        @triton.autotune(configs=[
            triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
            triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
          ],
          key=['x_size'] # the two above configs will be evaluated anytime
                         # the value of x_size changes
        )
        @triton.jit
        def kernel(x_ptr, x_size, **META):
            BLOCK_SIZE = META['BLOCK_SIZE']

    :note: When all the configurations are evaluated, the kernel will run multiple times.
           This means that whatever value the kernel updates will be updated multiple times.
           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
           resets the value of the provided tensor to `zero` before running any configuration.
    :param configs: a list of :code:`triton.Config` objects
    :type configs: list[triton.Config]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predicate running time with different configs, returns running time
        'top_k': number of configs to bench
        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    """
    def decorator(fn):
        # The wrapped function must expose `arg_names`.
        return Autotuner(
            fn,
            fn.arg_names,
            configs=configs,
            key=key,
            reset_to_zero=reset_to_zero,
            prune_configs_by=prune_configs_by,
        )

    return decorator
|
||||
|
||||
|
||||
class Heuristics(KernelInterface):
    """Kernel wrapper that computes some meta-parameters right before launch."""

    def __init__(self, fn, arg_names, values) -> None:
        """
        :param fn: the wrapped kernel; must expose ``run``.
        :param arg_names: ordered names of the kernel's positional arguments.
        :param values: mapping from meta-parameter name to a callable that
            computes its value.
        """
        self.fn, self.arg_names, self.values = fn, arg_names, values

    def run(self, *args, **kwargs):
        """Evaluate every heuristic, then launch the wrapped kernel.

        Each heuristic receives one dict holding the positional arguments
        (keyed by name) merged with the keyword arguments, including the
        values already produced by earlier heuristics.
        """
        positional = dict(zip(self.arg_names, args))
        for meta_name, compute in self.values.items():
            kwargs[meta_name] = compute({**positional, **kwargs})
        return self.fn.run(*args, **kwargs)
|
||||
|
||||
|
||||
def heuristics(values):
    """
    Decorator for specifying how the values of certain meta-parameters may be computed.
    This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.

    .. highlight:: python
    .. code-block:: python

        @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args['x_size'])))})
        @triton.jit
        def kernel(x_ptr, x_size, **META):
            BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size

    :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
                   each such function receives a single dict mapping argument names to the values of the call.
    :type values: dict[str, Callable[[dict[str, Any]], Any]]
    """
    def decorator(fn):
        # The wrapped function must expose `arg_names`.
        return Heuristics(fn, fn.arg_names, values=values)

    return decorator
|
||||
131
pkgs/triton/runtime/backends/cuda.c
Normal file
131
pkgs/triton/runtime/backends/cuda.c
Normal file
@@ -0,0 +1,131 @@
|
||||
#include "cuda.h"
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
// Turn a failed CUDA driver call into a pending Python RuntimeError.
// Does nothing when `code` is CUDA_SUCCESS. `file` and `line` are accepted
// for the CUDA_CHECK macro but are not used in the message.
static inline void gpuAssert(CUresult code, const char *file, int line) {
  if (code == CUDA_SUCCESS)
    return;
  const char *prefix = "Triton Error [CUDA]: ";
  const char *str = NULL;
  // cuGetErrorString can fail for unrecognized codes and then yields no
  // string; fall back to a fixed text instead of dereferencing NULL.
  if (cuGetErrorString(code, &str) != CUDA_SUCCESS || str == NULL)
    str = "unknown error";
  char err[1024] = {0};
  // Bounded formatting: cannot overflow `err`, unlike chained strcat.
  snprintf(err, sizeof(err), "%s%s", prefix, str);
  PyErr_SetString(PyExc_RuntimeError, err);
}
|
||||
|
||||
// Evaluate a CUDA driver call; if it (or anything earlier) left a Python
// exception pending, return NULL from the *enclosing* function. Only usable
// inside functions that return a pointer. Wrapped in do/while(0) so the macro
// behaves as a single statement (safe in unbraced if/else).
#define CUDA_CHECK(ans)                                                        \
  do {                                                                         \
    gpuAssert((ans), __FILE__, __LINE__);                                      \
    if (PyErr_Occurred())                                                      \
      return NULL;                                                             \
  } while (0)
|
||||
|
||||
// Python entry point: get_device_properties(device_id) -> dict.
// Returns a dict with the int entries "max_shared_mem", "multiprocessor_count",
// "sm_clock_rate", "mem_clock_rate" and "mem_bus_width", read from the driver
// via cuDeviceGetAttribute. Returns NULL with a RuntimeError set if any driver
// call fails.
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  int device_id;
  if (!PyArg_ParseTuple(args, "i", &device_id))
    return NULL;
  // Get device handle; checked so an invalid ordinal is reported instead of
  // querying attributes through an uninitialized handle.
  CUdevice device;
  CUDA_CHECK(cuDeviceGet(&device, device_id));

  // Locals receiving the individual device attributes
  int max_shared_mem;
  int multiprocessor_count;
  int sm_clock_rate;
  int mem_clock_rate;
  int mem_bus_width;
  CUDA_CHECK(cuDeviceGetAttribute(
      &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));
  CUDA_CHECK(cuDeviceGetAttribute(
      &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
  CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
                                  CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
  CUDA_CHECK(cuDeviceGetAttribute(
      &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
  CUDA_CHECK(cuDeviceGetAttribute(
      &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));

  return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
                       max_shared_mem, "multiprocessor_count",
                       multiprocessor_count, "sm_clock_rate", sm_clock_rate,
                       "mem_clock_rate", mem_clock_rate, "mem_bus_width",
                       mem_bus_width);
}
|
||||
|
||||
// Python entry point: load_binary(name, data, shared, device)
//   -> (module_handle, function_handle, n_regs, n_spills).
// Loads the image in `data` as a CUDA module, looks up the function `name`
// in it and returns both handles as integers together with the function's
// register count and local-memory size divided by 4.
// Returns NULL with a RuntimeError set if any driver call fails.
// NOTE(review): on a failure after cuModuleLoadData succeeded, the module is
// not unloaded before returning — confirm whether that leak matters.
static PyObject *loadBinary(PyObject *self, PyObject *args) {
  const char *name;
  const char *data;
  Py_ssize_t data_size;
  int shared;
  int device;
  // "s#" yields the buffer and its length; data_size is parsed but not used
  // afterwards in this function.
  if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
                        &device)) {
    return NULL;
  }
  CUfunction fun;
  CUmodule mod;
  int32_t n_regs = 0;
  int32_t n_spills = 0;
  // create driver handles
  CUcontext pctx = 0;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {
    // No context is current on this thread: retain the device's primary
    // context and make it current.
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }

  CUDA_CHECK(cuModuleLoadData(&mod, data));
  CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
  // get allocated registers and spilled registers from the function
  CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
  CUDA_CHECK(
      cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
  // The attribute is a size in bytes; presumably the division converts it to
  // a count of 4-byte words — TODO confirm.
  n_spills /= 4;
  // set dynamic shared memory if necessary
  int shared_optin;
  CUDA_CHECK(cuDeviceGetAttribute(
      &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));
  // 49152 bytes = 48 KiB: raise the limit only when the kernel asks for more
  // than that and the device's opt-in maximum is also above it.
  if (shared > 49152 && shared_optin > 49152) {
    CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
    int shared_total, shared_static;
    // NOTE(review): shared_total is queried but never read afterwards.
    CUDA_CHECK(cuDeviceGetAttribute(
        &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
        device));
    CUDA_CHECK(cuFuncGetAttribute(&shared_static,
                                  CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
    // Allow as much dynamic shared memory as the opt-in maximum leaves after
    // the function's static shared memory.
    CUDA_CHECK(
        cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                           shared_optin - shared_static));
  }

  // Every CUDA_CHECK above already returns on a pending exception; this is a
  // final safety net.
  if (PyErr_Occurred()) {
    return NULL;
  }
  // Handles are exposed to Python as unsigned 64-bit integers.
  return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
                       n_spills);
}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {
|
||||
{"load_binary", loadBinary, METH_VARARGS,
|
||||
"Load provided cubin into CUDA driver"},
|
||||
{"get_device_properties", getDeviceProperties, METH_VARARGS,
|
||||
"Get the properties for a given device"},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
|
||||
NULL, // documentation
|
||||
-1, // size
|
||||
ModuleMethods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_cuda_utils(void) {
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if (m == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}
|
||||
120
pkgs/triton/runtime/backends/hip.c
Normal file
120
pkgs/triton/runtime/backends/hip.c
Normal file
@@ -0,0 +1,120 @@
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include <hip/hip_runtime.h>
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// Turn a failed HIP call into a pending Python RuntimeError carrying the
// numeric error code and its description. Does nothing when `code` is
// HIP_SUCCESS. `file` and `line` are accepted for the HIP_CHECK macro but are
// not used in the message.
static inline void gpuAssert(hipError_t code, const char *file, int line) {
  if (code == HIP_SUCCESS)
    return;
  const char *prefix = "Triton Error [HIP]: ";
  const char *str = hipGetErrorString(code);
  char err[1024] = {0};
  // sizeof keeps the bound tied to the buffer; the enum is cast explicitly
  // because it is passed through varargs to a %d conversion.
  snprintf(err, sizeof(err), "%s Code: %d, Message: %s", prefix, (int)code,
           str ? str : "unknown error");
  PyErr_SetString(PyExc_RuntimeError, err);
}
|
||||
|
||||
// Evaluate a HIP call; if it (or anything earlier) left a Python exception
// pending, return NULL from the *enclosing* function. Only usable inside
// functions that return a pointer. Wrapped in do/while(0) so the macro behaves
// as a single statement (safe in unbraced if/else).
#define HIP_CHECK(ans)                                                         \
  do {                                                                         \
    gpuAssert((ans), __FILE__, __LINE__);                                      \
    if (PyErr_Occurred())                                                      \
      return NULL;                                                             \
  } while (0)
|
||||
|
||||
// Python entry point: get_device_properties(device_id) -> dict.
// Returns a dict with the int entries "max_shared_mem", "multiprocessor_count",
// "sm_clock_rate", "mem_clock_rate" and "mem_bus_width", taken from
// hipGetDeviceProperties. Returns NULL with a RuntimeError set on failure.
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  int device_id;
  if (!PyArg_ParseTuple(args, "i", &device_id))
    return NULL;

  hipDeviceProp_t props;
  HIP_CHECK(hipGetDeviceProperties(&props, device_id));

  // Every value is consumed by an "i" (C int) format unit. sharedMemPerBlock
  // is a size_t in hipDeviceProp_t, so it must be narrowed explicitly before
  // going through varargs; the remaining casts are no-ops that document the
  // expected type.
  return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
                       (int)props.sharedMemPerBlock, "multiprocessor_count",
                       (int)props.multiProcessorCount, "sm_clock_rate",
                       (int)props.clockRate, "mem_clock_rate",
                       (int)props.memoryClockRate, "mem_bus_width",
                       (int)props.memoryBusWidth);
}
|
||||
|
||||
// Python entry point: load_binary(name, path, shared, device)
//   -> (module_handle, function_handle, n_regs, n_spills).
// The second argument is used as a *file path*: the HSACO file is read into
// memory, loaded as a HIP module, and the function `name` is looked up in it.
// n_regs and n_spills are always reported as 0. `shared` and `device` are
// parsed but not used.
// Returns NULL with an exception set (OSError for file problems, MemoryError,
// or RuntimeError for HIP failures).
static PyObject *loadBinary(PyObject *self, PyObject *args) {
  const char *name;
  const char *data;
  Py_ssize_t data_size;
  int shared;
  int device;
  if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
                        &device)) {
    return NULL;
  }

  // Open HSACO file; a NULL return to Python must always come with an
  // exception set.
  FILE *hsaco_file = fopen(data, "rb");
  if (hsaco_file == NULL) {
    return PyErr_SetFromErrnoWithFilename(PyExc_OSError, data);
  }

  // Determine the file size
  long hsaco_file_size = -1;
  if (fseek(hsaco_file, 0L, SEEK_END) == 0) {
    hsaco_file_size = ftell(hsaco_file);
  }
  if (hsaco_file_size < 0) {
    // Record errno before fclose() can overwrite it.
    PyErr_SetFromErrnoWithFilename(PyExc_OSError, data);
    fclose(hsaco_file);
    return NULL;
  }

  // Read HSACO file into buffer (allocate at least one byte so an empty file
  // does not look like an allocation failure)
  unsigned char *hsaco = (unsigned char *)malloc(
      hsaco_file_size > 0 ? (size_t)hsaco_file_size : 1);
  if (hsaco == NULL) {
    fclose(hsaco_file);
    return PyErr_NoMemory();
  }
  rewind(hsaco_file);
  size_t n_read = fread(hsaco, sizeof(unsigned char), (size_t)hsaco_file_size,
                        hsaco_file);
  fclose(hsaco_file);
  if (n_read != (size_t)hsaco_file_size) {
    free(hsaco);
    PyErr_Format(PyExc_OSError, "short read while loading HSACO file '%s'",
                 data);
    return NULL;
  }

  // set HIP options
  hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
                        hipJitOptionErrorLogBuffer,
                        hipJitOptionInfoLogBufferSizeBytes,
                        hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
  const unsigned int errbufsize = 8192;
  const unsigned int logbufsize = 8192;
  char _err[errbufsize];
  char _log[logbufsize];
  void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
                    (void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};

  // launch HIP Binary
  hipModule_t mod;
  hipFunction_t fun;
  hipError_t status = hipModuleLoadDataEx(&mod, hsaco, 5, opt, optval);
  if (status == HIP_SUCCESS) {
    status = hipModuleGetFunction(&fun, mod, name);
  }
  // As before, the image is released only after both calls; it is freed
  // before the status check so the error path does not leak it.
  free(hsaco);
  HIP_CHECK(status);

  // allocated registers and spilled registers are not queried on HIP
  int n_regs = 0;
  int n_spills = 0;
  // Handles are exposed to Python as unsigned 64-bit integers.
  return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
                       n_spills);
}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {
|
||||
{"load_binary", loadBinary, METH_VARARGS,
|
||||
"Load provided hsaco into HIP driver"},
|
||||
{"get_device_properties", getDeviceProperties, METH_VARARGS,
|
||||
"Get the properties for a given device"},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils",
|
||||
NULL, // documentation
|
||||
-1, // size
|
||||
ModuleMethods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_hip_utils(void) {
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if (m == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}
|
||||
131
pkgs/triton/runtime/cache.py
Normal file
131
pkgs/triton/runtime/cache.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
def default_cache_dir():
    """Return ``<home>/.triton/cache`` for the current user as a string."""
    return str(Path.home().joinpath(".triton", "cache"))
|
||||
|
||||
|
||||
class CacheManager(ABC):
    """Abstract interface of a cache that stores files under a key."""

    def __init__(self, key):
        """Accept the cache ``key``; the base class stores nothing."""

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        """Return the path of a cached file, or ``None`` when it is absent."""

    @abstractmethod
    def has_file(self, filename) -> bool:
        """Tell whether ``filename`` is present in the cache."""

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        """Store ``data`` as ``filename`` and return the resulting path."""

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return a mapping of group member name to path, or ``None``."""

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record ``group`` as the set of files belonging to ``filename``."""
|
||||
|
||||
|
||||
class FileCacheManager(CacheManager):
    """Cache manager that keeps its files in ``<cache root>/<key>`` on disk.

    The cache root is ``$TRITON_CACHE_DIR`` when set, otherwise
    ``default_cache_dir()``. When the root is empty, the cache is disabled:
    lookups miss and writes are no-ops.
    """

    def __init__(self, key):
        self.key = key
        self.lock_path = None
        # create cache directory if it doesn't exist
        self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
        if self.cache_dir:
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)

    def _make_path(self, filename) -> str:
        """Return the path ``filename`` has (or would have) inside the cache."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename):
        """Return True when the cache is enabled and ``filename`` exists in it."""
        if not self.cache_dir:
            return False
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        """Return the path of ``filename`` in the cache, or ``None`` on a miss."""
        if self.has_file(filename):
            return self._make_path(filename)
        else:
            return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return ``{child name: path}`` for the group recorded under ``filename``.

        :return: ``None`` when no group file exists or it lacks ``child_paths``.
        :raises Exception: when a listed child file is missing from the cache.
        """
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        grp_filepath = self._make_path(grp_filename)
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        # Invalid group data.
        if child_paths is None:
            return None
        result = {}
        for c in child_paths:
            p = self._make_path(c)
            if not os.path.exists(p):
                raise Exception(f"Group file {p} does not exist from group {grp_filename} ")
            result[c] = p
        return result

    # Note a group of pushed files as being part of a group
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record the sorted keys of ``group`` as the members of ``filename``.

        Only the keys are stored; the values of ``group`` are ignored.
        """
        if not self.cache_dir:
            return
        grp_contents = json.dumps({"child_paths": sorted(group)})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    def put(self, data, filename, binary=True) -> str:
        """Write ``data`` to ``filename`` in the cache and return its path.

        The ``binary`` argument is kept for interface compatibility but is
        overridden: bytes are written in binary mode, anything else is
        converted with ``str()`` and written as text. Returns ``None`` when
        the cache is disabled.
        """
        if not self.cache_dir:
            return
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = random.randint(0, 1000000)
        # we use the PID incase a bunch of these around so we can see what PID made it
        pid = os.getpid()
        # use tempfile to be robust against program interruptions
        temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
        mode = "wb" if binary else "w"
        try:
            with open(temp_path, mode) as f:
                f.write(data)
            # Replace is guaranteed to be atomic on POSIX systems if it succeeds
            # so filepath cannot see a partial write
            os.replace(temp_path, filepath)
        except BaseException:
            # Do not leave a stray temp file behind when the write or the
            # rename fails; cleanup is best-effort and the original error wins.
            try:
                os.remove(temp_path)
            except OSError:
                pass
            raise
        return filepath
|
||||
|
||||
|
||||
# Currently selected cache-manager class and the spec string that selected it.
__cache_cls = FileCacheManager
__cache_cls_nme = "DEFAULT"


def get_cache_manager(key) -> CacheManager:
    """Instantiate the active cache-manager class for ``key``.

    ``$TRITON_CACHE_MANAGER`` may name a replacement class as
    ``"module.path:ClassName"``. It is imported the first time the variable is
    seen (or whenever its value changes) and remembered in module globals.

    :raises ValueError: if the spec does not contain exactly one ``:``.
    """
    import os

    global __cache_cls
    global __cache_cls_nme

    requested = os.environ.get("TRITON_CACHE_MANAGER", None)
    needs_switch = not (requested is None or requested == __cache_cls_nme)
    if needs_switch:
        import importlib
        module_path, clz_nme = requested.split(":")
        # Update the class first: if the lookup fails the remembered spec
        # stays unchanged and the import is retried on the next call.
        __cache_cls = getattr(importlib.import_module(module_path), clz_nme)
        __cache_cls_nme = requested

    return __cache_cls(key)
|
||||
175
pkgs/triton/runtime/driver.py
Normal file
175
pkgs/triton/runtime/driver.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import abc
|
||||
import hashlib
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from ..common.build import _build
|
||||
from .cache import get_cache_manager
|
||||
|
||||
|
||||
class DriverBase(metaclass=abc.ABCMeta):
    """Common base of the backend drivers; also defines the backend ids."""

    # Backend identifiers stored in a driver's `backend` attribute.
    CUDA, HIP = 0, 1

    @staticmethod
    def third_party_dir():
        """Return ``<directory of this file>/../third_party`` (not normalised)."""
        here = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(here, "..", "third_party")

    def __init__(self) -> None:
        """Nothing to initialise in the base class."""
|
||||
# -----------------------------
|
||||
# CUDA
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class CudaUtils(object):
    """Singleton exposing the functions of the compiled ``cuda_utils`` extension.

    The extension is built from ``backends/cuda.c`` on first use and cached
    under a key derived from the MD5 of that source.
    """

    def __new__(cls):
        # Reuse the instance stored on the class; note that __init__ still
        # runs on every CudaUtils() call.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super().__new__(cls)
            return cls.instance

    def __init__(self):
        here = os.path.dirname(os.path.realpath(__file__))
        source = Path(os.path.join(here, "backends", "cuda.c")).read_text()
        digest = hashlib.md5(source.encode("utf-8")).hexdigest()
        cache = get_cache_manager(digest)
        lib_name = "cuda_utils.so"
        lib_path = cache.get_file(lib_name)
        if lib_path is None:
            # Cache miss: compile in a scratch directory, then store both the
            # source and the shared object in the cache.
            with tempfile.TemporaryDirectory() as build_dir:
                c_file = os.path.join(build_dir, "main.c")
                with open(c_file, "w") as out:
                    out.write(source)
                built_lib = _build("cuda_utils", c_file, build_dir)
                cache.put(source, "main.c", binary=False)
                with open(built_lib, "rb") as lib:
                    lib_path = cache.put(lib.read(), lib_name, binary=True)
        import importlib.util
        spec = importlib.util.spec_from_file_location("cuda_utils", lib_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        self.load_binary = module.load_binary
        self.get_device_properties = module.get_device_properties
|
||||
|
||||
|
||||
class CudaDriver(DriverBase):
    """Singleton driver for the CUDA backend."""

    def __new__(cls):
        # One shared instance per class; __init__ still runs on every call.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super().__new__(cls)
            return cls.instance

    def __init__(self):
        # Constructing CudaUtils builds or loads the cuda_utils extension.
        self.utils = CudaUtils()
        self.backend = self.CUDA
|
||||
|
||||
# -----------------------------
|
||||
# HIP
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class HIPUtils(object):
    """Singleton exposing the functions of the compiled ``hip_utils`` extension.

    The extension is built from ``backends/hip.c`` on first use and cached
    under a key derived from the MD5 of that source.
    """

    def __new__(cls):
        # Reuse the instance stored on the class; note that __init__ still
        # runs on every HIPUtils() call.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super().__new__(cls)
            return cls.instance

    def __init__(self):
        here = os.path.dirname(os.path.realpath(__file__))
        source = Path(os.path.join(here, "backends", "hip.c")).read_text()
        digest = hashlib.md5(source.encode("utf-8")).hexdigest()
        cache = get_cache_manager(digest)
        lib_name = "hip_utils.so"
        lib_path = cache.get_file(lib_name)
        if lib_path is None:
            # Cache miss: compile in a scratch directory and cache the result.
            with tempfile.TemporaryDirectory() as build_dir:
                c_file = os.path.join(build_dir, "main.c")
                with open(c_file, "w") as out:
                    out.write(source)
                built_lib = _build("hip_utils", c_file, build_dir)
                with open(built_lib, "rb") as lib:
                    lib_path = cache.put(lib.read(), lib_name, binary=True)
        import importlib.util
        spec = importlib.util.spec_from_file_location("hip_utils", lib_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        self.load_binary = module.load_binary
        self.get_device_properties = module.get_device_properties
|
||||
|
||||
|
||||
class HIPDriver(DriverBase):
    """Singleton driver for the HIP backend."""

    def __new__(cls):
        # One shared instance per class; __init__ still runs on every call.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super().__new__(cls)
            return cls.instance

    def __init__(self):
        # Constructing HIPUtils builds or loads the hip_utils extension.
        self.utils = HIPUtils()
        self.backend = self.HIP
|
||||
|
||||
|
||||
class UnsupportedDriver(DriverBase):
    """Singleton placeholder used when neither HIP nor CUDA is available."""

    def __new__(cls):
        # One shared instance per class; __init__ still runs on every call.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super().__new__(cls)
            return cls.instance

    def __init__(self):
        # No helper module and no backend id.
        self.utils, self.backend = None, None
|
||||
|
||||
# -----------------------------
|
||||
# Driver
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class LazyProxy:
    """Proxy that creates its target on first use and forwards attribute access.

    The target is produced by calling ``init_fn`` with no arguments the first
    time an attribute is read, written or deleted, or ``str()`` is requested.
    ``repr()`` does not trigger creation.
    """

    def __init__(self, init_fn):
        self._init_fn = init_fn
        self._obj = None

    def _initialize_obj(self):
        """Create the target unless one is already stored."""
        if self._obj is not None:
            return
        self._obj = self._init_fn()

    def __getattr__(self, name):
        # Only reached for names not found on the proxy itself.
        self._initialize_obj()
        return getattr(self._obj, name)

    def __setattr__(self, name, value):
        # The proxy's own two slots are stored locally; everything else is
        # written to the target.
        if name in ('_init_fn', '_obj'):
            super().__setattr__(name, value)
            return
        self._initialize_obj()
        setattr(self._obj, name, value)

    def __delattr__(self, name):
        self._initialize_obj()
        delattr(self._obj, name)

    def __repr__(self):
        target = self._obj
        if target is not None:
            return repr(target)
        return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"

    def __str__(self):
        self._initialize_obj()
        return str(self._obj)
|
||||
|
||||
|
||||
def initialize_driver():
    """Choose the driver matching the installed torch build.

    A HIP build of torch selects ``HIPDriver``; otherwise ``CudaDriver`` is used
    when CUDA is available, and ``UnsupportedDriver`` when it is not.
    """
    import torch

    if torch.version.hip is not None:
        return HIPDriver()
    if torch.cuda.is_available():
        return CudaDriver()
    return UnsupportedDriver()
|
||||
|
||||
|
||||
# Module-level driver handle. LazyProxy defers the initialize_driver() call
# until the first attribute access on `driver`.
driver = LazyProxy(initialize_driver)
|
||||
15
pkgs/triton/runtime/errors.py
Normal file
15
pkgs/triton/runtime/errors.py
Normal file
@@ -0,0 +1,15 @@
|
||||
|
||||
class OutOfResources(Exception):
    """Error reporting that a resource requirement exceeds the hardware limit.

    :param required: amount of the resource that was requested.
    :param limit: amount the hardware provides.
    :param name: name of the resource, used in the message.
    """

    def __init__(self, required, limit, name):
        self.required = required
        self.limit = limit
        self.name = name
        self.message = (
            f'out of resource: {name}, Required: {required}, Hardware limit: {limit}'
            '. Reducing block sizes or `num_stages` may help.'
        )
        super().__init__(self.message)

    def __reduce__(self):
        # Exception's default reduction would re-create the object from the
        # single formatted message; rebuild from the three constructor
        # arguments instead so this exception stays picklable.
        ctor_args = (self.required, self.limit, self.name)
        return (type(self), ctor_args)
|
||||
573
pkgs/triton/runtime/jit.py
Normal file
573
pkgs/triton/runtime/jit.py
Normal file
@@ -0,0 +1,573 @@
|
||||
from __future__ import annotations, division
|
||||
|
||||
import ast
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
|
||||
def get_disable_sme():
    """Return the SME-disable flag as a string.

    The result is ``"1"`` when the current device reports compute capability
    7.0, and otherwise the value of the ``TRITON_DISABLE_SME`` environment
    variable (``"0"`` when unset).
    """
    capability = torch.cuda.get_device_capability()
    if capability[0] * 10 + capability[1] == 70:  # for ivcore10
        return "1"
    return os.getenv("TRITON_DISABLE_SME", default="0")
|
||||
|
||||
|
||||
def get_corex_sme(args, tl_args, enable_sme=True):
    """Build the SME bitmask for a set of kernel arguments.

    :param args: mapping of argument name to value, in signature order.
    :param tl_args: positions (within ``args``) that are skipped entirely.
    :param enable_sme: when false, the result is always 0.
    :return: an int with one bit per considered argument; a bit is set when the
        argument is a tensor of a supported dtype with at least two dimensions
        whose relevant dimension is a multiple of ``64 / element_size``.
    """
    if not enable_sme:
        return 0
    import torch
    if not (hasattr(torch, "corex") and torch.corex == True):
        return 0
    if get_disable_sme() == "1":
        return 0

    sme_dtypes = [torch.float16, torch.float32, torch.bfloat16, torch.int8]
    mask = 0
    bit = 0
    for position, value in enumerate(args.values()):
        if position in tl_args:
            continue
        # Integer arguments equal to 1 do not consume a bit position.
        if isinstance(value, int) and value == 1:
            continue
        if torch.is_tensor(value) and value.dtype in sme_dtypes and value.dim() >= 2:
            granule = 64 / value.element_size()
            # Contiguous tensors are checked on the last dimension,
            # non-contiguous ones on the second-to-last.
            checked_dim = value.shape[-1] if value.is_contiguous() else value.shape[-2]
            if checked_dim % granule == 0:
                mask |= 1 << bit
        # NOTE(review): the bit position is assumed to advance for every
        # argument that was not skipped above, tensor or not -- confirm.
        bit += 1
    return mask
|
||||
|
||||
|
||||
def get_cuda_stream(idx=None):
    """Return the raw handle of the current CUDA stream for device ``idx``.

    When ``idx`` is None the index returned by ``get_current_device()`` is used.
    """
    device_index = get_current_device() if idx is None else idx
    try:
        # Preferred path: torch's private accessor for the raw stream.
        from torch._C import _cuda_getCurrentRawStream as raw_stream_of
        return raw_stream_of(device_index)
    except ImportError:
        # Fall back to the public API when the private accessor is missing.
        import torch
        return torch.cuda.current_stream(device_index).cuda_stream
|
||||
|
||||
|
||||
def get_current_device():
    """Return the value of ``torch.cuda.current_device()``."""
    # torch is imported at call time inside the function.
    import torch
    return torch.cuda.current_device()
|
||||
|
||||
|
||||
def set_current_device(idx):
    """Select device ``idx`` by calling ``torch.cuda.set_device(idx)``."""
    # torch is imported at call time inside the function.
    import torch
    torch.cuda.set_device(idx)
|
||||
|
||||
|
||||
def get_device_capability(idx):
    """Return ``torch.cuda.get_device_capability(idx)`` for device ``idx``."""
    # torch is imported at call time inside the function.
    import torch
    return torch.cuda.get_device_capability(idx)
|
||||
|
||||
|
||||
# Type variable for the wrapped function; parameterizes KernelInterface,
# JITFunction and the `jit` decorator signatures below.
T = TypeVar('T')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DependenciesFinder(ast.NodeVisitor):
    """
    AST visitor that folds the hashes of the JITFunctions called by a function
    into the hash of that function's own source. The result (``ret``) can be
    used to invalidate a JITFunction's hash when its source code -- or that of
    its dependencies -- changes.
    """

    def __init__(self, globals, src) -> None:
        super().__init__()
        self.globals = globals
        # Start from the hash of the function's own source text.
        self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()

    def visit_Name(self, node):
        # Resolve a bare name in the function's globals; unknown names give None.
        return self.globals.get(node.id)

    def visit_Attribute(self, node):
        base = self.visit(node.value)
        while isinstance(base, ast.Attribute):
            base = self.visit(base.value)
        # Unresolvable bases and attributes of the triton package itself are
        # not treated as dependencies.
        if base is not None and base is not triton:
            return getattr(base, node.attr)
        return None

    def visit_Call(self, node):
        callee = self.visit(node.func)
        if callee is None or inspect.isbuiltin(callee):
            return
        module_name = callee.__module__
        if module_name and module_name.startswith('triton.'):
            return
        assert isinstance(callee, JITFunction), f'Function "{callee.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'
        if callee.hash is None:
            # Compute the callee's own hash (which recursively covers its callees).
            nested = DependenciesFinder(callee.__globals__, callee.src)
            nested.visit(ast.parse(callee.src))
            callee.hash = nested.ret
        noinline_tag = str(getattr(callee, 'noinline', False))
        combined = (self.ret + callee.hash + noinline_tag).encode("utf-8")
        self.ret = hashlib.md5(combined).hexdigest()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# JITFunction
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@functools.lru_cache()
def version_key():
    """Return a string identifying the installed Triton build.

    The key combines ``triton.__version__``, a hash of the ``ptxas --version``
    output (empty when ptxas cannot be run) and the MD5 of this file, of every
    module under ``triton/compiler`` and ``triton/language``, and of the
    ``libtriton`` binary. The result is cached for the life of the process.
    """
    import pkgutil

    def file_md5(path):
        with open(path, "rb") as stream:
            return hashlib.md5(stream.read()).hexdigest()

    def package_md5s(subpackage):
        root = os.path.join(*triton.__path__, subpackage)
        return [file_md5(lib.module_finder.find_spec(lib.name).origin)
                for lib in pkgutil.iter_modules([root])]

    # frontend
    contents = [file_md5(__file__)]
    # compiler
    contents.extend(package_md5s('compiler'))
    # backend
    contents.append(file_md5(triton._C.libtriton.__file__))
    # language
    contents.extend(package_md5s('language'))
    # ptxas version
    try:
        ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
    except Exception:
        ptxas_version = ''
    # NOTE(review): if __version__ is a plain string, join() places '-' between
    # its individual characters -- confirm whether that is intended.
    return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
class KernelInterface(Generic[T]):
    """Base for objects launched as ``fn[grid](*args, **kwargs)``."""

    # Launcher callable; expected to accept a `grid` keyword argument.
    run: T

    def __getitem__(self, grid) -> T:
        """
        Bind ``grid`` to the launcher.

        ``fn[grid]`` returns ``self.run`` with ``grid=grid`` pre-applied, so
        that ``fn[grid](*args, **kwargs)`` launches with that grid.
        """
        launcher = cast(Callable, self.run)
        bound = functools.partial(launcher, grid=grid)
        return cast(T, bound)
|
||||
|
||||
|
||||
class JITFunction(KernelInterface[T]):
    """Wrapper around a Python function that is compiled and launched as a kernel.

    Construction inspects the function (argument names, annotations, source)
    and generates ``run``, a launcher specialised to the function's signature.
    Compiled binaries are cached per device in ``self.cache``.

    :param fn: the function to wrap.
    :param version: stored as ``self.version``.
    :param do_not_specialize: argument names or positions (within the full
        argument list) that must not be specialised.
    :param debug: debug flag; forced to True when ``TRITON_DEBUG=1``.
    :param noinline: stored as ``self.noinline``.
    """

    # Hook for inspecting compiled functions and modules
    cache_hook = None
    # Alignment (for pointers) / factor (for ints) that arguments are specialised on.
    divisibility = 16

    @staticmethod
    def _key_of(arg):
        """Return the type key of ``arg``: its ``dtype``, a type string, or None."""
        if hasattr(arg, "dtype"):
            return arg.dtype
        elif isinstance(arg, bool):
            # bool must be tested before int (bool is a subclass of int).
            return "i1"
        elif isinstance(arg, int):
            if -2**31 <= arg <= 2**31 - 1:
                return "i32"
            elif 2**63 <= arg <= 2**64 - 1:
                return "u64"
            else:
                return "i64"
        elif isinstance(arg, float):
            return 'fp32'
        elif arg is None:
            return None
        else:
            raise TypeError(f'Unsupported type {type(arg)} for {arg}')

    @staticmethod
    def _spec_of(arg):
        """Return the specialisation facts for ``arg`` (divisibility / equal-to-1 / is-None)."""
        if hasattr(arg, "data_ptr"):
            return (arg.data_ptr() % JITFunction.divisibility == 0)
        elif isinstance(arg, int):
            # Use the class-level constant, like every other divisibility check here.
            return (arg % JITFunction.divisibility == 0, arg == 1)
        return (arg is None, )

    def _get_config(self, *args):
        """Return an ``instance_descriptor`` namedtuple of argument positions.

        ``divisible_by_16`` lists positions whose pointer/int value is a
        multiple of ``divisibility`` (None counts as divisible);
        ``equal_to_1`` lists positions holding the non-bool int 1. Positions
        in ``do_not_specialize`` are excluded from both.
        """
        def is_divisible_by_16(x):
            if hasattr(x, "data_ptr"):
                return x.data_ptr() % JITFunction.divisibility == 0
            elif isinstance(x, int):
                return x % JITFunction.divisibility == 0
            if x is None:
                return True
            return False
        divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
        equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
        return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))

    @staticmethod
    def _type_of(key):
        """Translate a type key from ``_key_of`` into a signature type string."""
        # None are nullptr -- implicitly converted to *i8
        if key is None:
            return '*i8'
        dtype_str = str(key).split(".")[-1]
        tys = {
            "bool": "i1",
            "float8e5": "fp8e5",
            "float8e4": "fp8e4",
            "float16": "fp16",
            "bfloat16": "bf16",
            "float32": "fp32",
            "float64": "fp64",
            "int8": "i8",
            "int16": "i16",
            "int32": "i32",
            "int64": "i64",
            "uint8": "u8",
            "uint16": "u16",
            "uint32": "u32",
            "uint64": "u64",
        }
        # reinterpret can create triton type
        for v in list(tys.values()):
            tys[v] = v
        # String keys are scalar types and pass through; anything else is a pointer.
        return key if isinstance(key, str) else f"*{tys[dtype_str]}"

    def _make_signature(self, sig_key):
        """Join the type strings of ``sig_key`` with commas."""
        return ",".join(self._type_of(k) for k in sig_key)

    def _make_constants(self, constexpr_key):
        """Map each constexpr argument position to its value."""
        return dict(zip(self.constexprs, constexpr_key))

    def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
        """Invoke ``JITFunction.cache_hook`` if one is installed.

        :return: False when no hook is installed, otherwise the hook's result.
        """
        if JITFunction.cache_hook is None:
            return False
        name = self.fn.__name__
        module = self.fn.__module__
        arg_reprs = ', '.join([f'{arg_name}: {ty}' for arg_name, ty in zip(self.arg_names, key[1])])
        kernel_repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
        key = str(key)

        class LegacyCompiler:
            def __init__(self, module, name):
                self.module = module
                self.name = name

        kwargs = dict(signature=signature, device=device, constants=constants,
                      num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
                      configs=configs)

        return JITFunction.cache_hook(key=key, repr=kernel_repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)

    def _get_arg_specialization_key(self, arg) -> str:
        """Return source text of the expression computing ``arg``'s specialisation key."""
        div = JITFunction.divisibility
        arg_annotation = self.__annotations__.get(arg, '')
        if arg_annotation == '':
            # No annotation: decide at launch time from the runtime value.
            return (f'({arg}.data_ptr() % {div} == 0) if hasattr({arg}, "data_ptr") '
                    f'else ({arg} % {div} == 0, {arg} == 1) if isinstance({arg}, int) '
                    f'else (False,)')
        elif 'Tensor' in arg_annotation:
            return f'({arg}.data_ptr() % {div} == 0)'
        elif arg_annotation == 'int':
            return f'({arg} % {div} == 0, {arg} == 1)'
        else:
            return '(False,)'

    def _get_arg_sig_key(self, arg) -> str:
        """Return source text of the expression computing ``arg``'s type key."""
        arg_annotation = self.__annotations__.get(arg, '')
        if 'Tensor' in arg_annotation:
            return f'{arg}.dtype'
        elif arg_annotation == 'bool':
            return "i1"
        elif arg_annotation == 'float':
            return 'fp32'
        else:
            return f'_key_of({arg})'

    def _make_launcher(self):
        """Generate and return the launcher function for this kernel.

        The launcher is built as source text specialised to the wrapped
        function's argument names and then ``exec``-ed. At call time it
        computes the cache key, resolves the grid, device and stream, and
        either launches the cached binary or compiles, launches and caches it.
        """
        # (position in arg_names, name) for every non-constexpr argument
        regular = [(i, arg) for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
        regular_args = [arg for _, arg in regular]
        constexpr_args = [arg for i, arg in enumerate(self.arg_names) if i in self.constexprs]
        params = ', '.join(self.arg_names)
        args = ', '.join(regular_args)
        # cache key for regular argument type
        sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
        # cache key for constexpr argument values
        constexpr_keys = ', '.join(constexpr_args)
        # cache key for argument specialization; do_not_specialize holds
        # positions within arg_names, so test against that position rather
        # than the position within regular_args
        specializations = [self._get_arg_specialization_key(arg)
                           for i, arg in regular if i not in self.do_not_specialize]
        spec_keys = ', '.join(specializations)
        grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])

        # NOTE(review): in the generated code the current device is switched only
        # when the caller did not pass `device` -- confirm this is intended.
        src = f"""
def {self.fn.__name__}({params}, grid, num_warps=4, num_stages=3, enable_sme=True, extern_libs=None, stream=None, warmup=False, device=None):
    sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
    constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
    spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
    use_sme = get_corex_sme({{{grid_args}}}, self.constexprs, enable_sme)
    key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug, use_sme)
    if not extern_libs is None:
        key = (key, tuple(extern_libs.items()))
    assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
    if callable(grid):
        grid = grid({{{grid_args}}})
    grid_size = len(grid)
    grid_0 = grid[0]
    grid_1 = grid[1] if grid_size > 1 else 1
    grid_2 = grid[2] if grid_size > 2 else 1
    if device is None:
        device = get_current_device()
        set_current_device(device)
    if stream is None and not warmup:
        stream = get_cuda_stream(device)
    bin = cache[device].get(key, None)
    if bin is not None:
        if not warmup:
            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
        return bin
    # kernel not cached -- compile
    else:
        # build dict of constant values
        args = [{args}]
        all_args = {params},
        configs = self._get_config(*all_args),
        constants = self._make_constants(constexpr_key)
        constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
        constants.update({{i: 1 for i in configs[0].equal_to_1}})
        # build kernel signature -- doesn't include specialized arguments
        signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
        # build stub signature -- includes arguments that are specialized
        for i, arg in constants.items():
            if callable(arg):
                raise TypeError(f"Callable constexpr at index {{i}} is not supported")
        if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
            bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, use_sme=use_sme)
            if not warmup:
                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
            self.cache[device][key] = bin
            return bin
        return None
"""
        # Names the generated source refers to.
        scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
                 "self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
                 "cache": self.cache, "triton": triton,
                 "get_current_device": get_current_device,
                 "set_current_device": set_current_device,
                 "get_corex_sme": get_corex_sme}
        exec(src, scope)
        return scope[self.fn.__name__]

    def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
        import re

        self.fn = fn
        self.module = fn.__module__
        self.version = version
        # function signature information
        signature = inspect.signature(fn)
        self.arg_names = [v.name for v in signature.parameters.values()]
        # Identity test against the public sentinel: `!=` on an arbitrary
        # default (e.g. an array) need not return a plain bool.
        self.has_defaults = any(v.default is not inspect.Parameter.empty for v in signature.parameters.values())
        # specialization hints, normalised to positions within arg_names
        self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
        self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
        # function source code (without decorators): cut at the first line that
        # starts a `def`, so a "def" substring inside a decorator line is not
        # mistaken for the start of the function
        src = textwrap.dedent(inspect.getsource(fn))
        def_match = re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE)
        self.src = src[def_match.start():] if def_match is not None else src[src.find("def"):]
        # cache of just-in-time compiled kernels
        self.cache = defaultdict(dict)
        self.hash = None
        # JITFunction can be instantiated as kernel
        # when called with a grid using __getitem__
        self.kernel_decorators = []
        self.kernel = None
        self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
        self.noinline = noinline
        # annotations, normalised to strings so the substring tests on them
        # ('constexpr', 'Tensor') work for classes, strings and typing objects
        normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else str(ty)
        self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()}
        # index of constexprs
        self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
        # launcher
        self.run = self._make_launcher()
        # re-use docs of wrapped function
        self.__doc__ = fn.__doc__
        self.__name__ = fn.__name__
        self.__globals__ = fn.__globals__
        self.__module__ = fn.__module__

    @property
    def cache_key(self):
        """Hash of this function's source, its dependencies and ``version_key()``; computed lazily."""
        # TODO : hash should be attribute of `self`
        if self.hash is None:
            dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
            dependencies_finder.visit(self.parse())
            self.hash = dependencies_finder.ret + version_key()
        return self.hash

    def warmup(self, *args, **kwargs):
        """Run the launcher with ``warmup=True``, wrapping torch dtypes in MockTensor."""
        return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)

    # we do not parse `src` in the constructor because
    # the user might want to monkey-patch self.src dynamically.
    # Our unit tests do this, for example.
    def parse(self):
        """Parse ``self.src`` and return the module AST, which must hold exactly one function."""
        tree = ast.parse(self.src)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        assert isinstance(tree.body[0], ast.FunctionDef)
        return tree

    def __call__(self, *args, **kwargs):
        raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")

    def __setattr__(self, name, value):
        # - when kernel decorators change, cached kernel
        #   needs to be cleared
        if name == 'kernel_decorators':
            self.kernel = None
        super(JITFunction, self).__setattr__(name, value)
        # - when `.src` attribute is set, cache path needs
        #   to be reinitialized
        if name == 'src':
            self.hash = None

    def __repr__(self):
        return f"JITFunction({self.module}:{self.fn.__name__})"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# `jit` decorator
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@overload
def jit(fn: T) -> JITFunction[T]:
    ...


@overload
def jit(
    *,
    version=None,
    do_not_specialize: Optional[Iterable[Union[int, str]]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
    interpret: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
    ...


def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    do_not_specialize: Optional[Iterable[Union[int, str]]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
    interpret: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    Usable both bare (``@jit``) and with keyword options (``@jit(debug=True)``).

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()` method
        and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * builtins within the triton package,
           * arguments to this function,
           * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    :param version: forwarded to :code:`JITFunction`
    :param do_not_specialize: argument names or positions forwarded to :code:`JITFunction`
    :param debug: forwarded to :code:`JITFunction`
    :param noinline: forwarded to :code:`JITFunction`
    :param interpret: when truthy, wrap the function in the debugger's
        :code:`GridSelector` instead of a :code:`JITFunction`
    """

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        if interpret:
            # Imported only on this path.
            from ..debugger.debugger import GridSelector
            return GridSelector(fn)
        return JITFunction(
            fn,
            version=version,
            do_not_specialize=do_not_specialize,
            debug=debug,
            noinline=noinline,
        )

    # Bare use (@jit) receives the function directly; use with options
    # (@jit(...)) returns the decorator to apply next.
    if fn is not None:
        return decorator(fn)
    return decorator
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utilities for mocking tensors
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockTensor:
    """
    Stand-in for a real tensor that carries only a dtype. Usage:
        kernel.warmup(MockTensor(torch.float32), ...)
    """

    def __init__(self, dtype):
        self.dtype = dtype

    @staticmethod
    def wrap_dtype(arg):
        """Wrap ``arg`` in a MockTensor if it is a torch dtype; otherwise return it unchanged."""
        # Recognised by class name and module so that torch is not needed here.
        looks_like_torch_dtype = arg.__class__.__name__ == "dtype" and arg.__module__ == "torch"
        return MockTensor(arg) if looks_like_torch_dtype else arg

    @staticmethod
    def data_ptr():
        # optimistically assumes multiple of 16
        return 0
|
||||
|
||||
|
||||
class TensorWrapper:
    """View of ``base`` that reports a different ``dtype`` but shares its data pointer.

    ``is_cuda`` and ``device`` are copied from ``base`` at construction time.
    """

    def __init__(self, base, dtype):
        self.dtype = dtype
        self.base = base
        self.is_cuda, self.device = base.is_cuda, base.device

    def data_ptr(self):
        """Return the data pointer of the wrapped tensor."""
        pointer = self.base.data_ptr()
        return pointer

    def __str__(self) -> str:
        return 'TensorWrapper[' + f'{self.dtype}' + '](' + f'{self.base}' + ')'
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
    """Return ``tensor`` viewed as ``dtype`` without touching its data.

    :param tensor: a ``TensorWrapper`` or any object with a ``data_ptr`` attribute.
    :param dtype: the dtype the result should report.
    :return: the original base tensor when a wrapper is reinterpreted back to
        its base's dtype, otherwise a ``TensorWrapper`` around the base.
    :raises TypeError: if ``tensor`` is neither of the accepted kinds.
    """
    if isinstance(tensor, TensorWrapper):
        base = tensor.base
        # Back to the original interpretation: unwrap. Otherwise re-wrap the
        # base so wrappers never nest.
        return base if dtype == base.dtype else TensorWrapper(base, dtype)
    if hasattr(tensor, "data_ptr"):
        # A new wrapper is needed around an unwrapped tensor.
        return TensorWrapper(tensor, dtype)
    raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
||||
Reference in New Issue
Block a user