First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
heuristics)
from .driver import driver
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
version_key)
__all__ = [
"driver",
"Config",
"Heuristics",
"autotune",
"heuristics",
"JITFunction",
"KernelInterface",
"version_key",
"reinterpret",
"TensorWrapper",
"OutOfResources",
"MockTensor",
"Autotuner",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,305 @@
from __future__ import annotations
import builtins
import time
from typing import Dict
import json
import os
import hashlib
from ..testing import do_bench
from .jit import KernelInterface
from .cache import default_cache_dir
def build_best_config_hash(args_names, key):
cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
hasher = hashlib.sha256()
hasher.update(f"{'_'.join(args_names) + str(key)}\n".encode())
cfg_hash = hasher.hexdigest()
cfg_hash_dir = os.path.join(cache_dir, cfg_hash)
cfg_hash_file = os.path.splitext(cfg_hash)[0] + ".best_config"
cfg_hash_file = os.path.join(cfg_hash_dir, cfg_hash_file)
return cfg_hash_dir, cfg_hash_file
def load_best_config(args_names, key):
_, cfg_hash_file = build_best_config_hash(args_names, key)
if os.path.exists(cfg_hash_file):
with open(cfg_hash_file) as fd:
best_config = json.loads(fd.read())
num_warps = best_config.pop('num_warps') if 'num_warps' in best_config else 4
num_stages = best_config.pop('num_stages') if 'num_stages' in best_config else 1
return best_config, num_warps, num_stages
return None
def save_best_config(cfg, args_names, key):
cfg_hash_dir, cfg_hash_file = build_best_config_hash(args_names, key)
if os.path.exists(cfg_hash_dir):
return
os.makedirs(cfg_hash_dir, exist_ok=True)
with open(cfg_hash_file, "w") as fd:
fd.write(
json.dumps(
{
**cfg.kwargs,
"num_warps": cfg.num_warps,
"num_stages": cfg.num_stages,
}
)
)
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.required = required
self.limit = limit
self.name = name
super().__init__(self.message)
def __reduce__(self):
# this is necessary to make CompilationError picklable
return (type(self), (self.required, self.limit, self.name))
class Autotuner(KernelInterface):
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
'''
if not configs:
self.configs = [Config({}, num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
try:
return do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
except OutOfResources:
return [float('inf'), float('inf'), float('inf')]
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
all_args = {**self.nargs, **kwargs}
_args = []
for name in self.arg_names:
if name in all_args:
_args.append(all_args[name])
key = [_args[i] for i in self.key_idx]
divisibility = 16
for arg in args:
if hasattr(arg, "data_ptr"):
key.append(arg.dtype)
key.append(arg.data_ptr() % divisibility == 0)
elif isinstance(arg, int):
key.append(arg)
key = tuple(key)
if key not in self.cache:
load_config = load_best_config(self.arg_names, key)
if load_config:
best_config, num_warps, num_stages = load_config
config = Config(best_config, num_warps, num_stages)
self.cache[key] = config
self.hook(args)
else:
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
save_best_config(self.cache[key], self.arg_names, key)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
self.nargs = None
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
def prune_configs(self, kwargs):
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
num_warps=config.num_warps)
for config in pruned_configs
}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
class Config:
"""
An object that represents a possible kernel configuration for the auto-tuner to try.
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
:type meta: dict[Str, Any]
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
`num_warps=8`, then each kernel instance will be automatically parallelized to
cooperatively execute using `8 * 32 = 256` threads.
:type num_warps: int
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
:type num_stages: int
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
function are args.
"""
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs
self.num_warps = num_warps
self.num_stages = num_stages
self.pre_hook = pre_hook
def __str__(self):
res = []
for k, v in self.kwargs.items():
res.append(f'{k}: {v}')
res.append(f'num_warps: {self.num_warps}')
res.append(f'num_stages: {self.num_stages}')
return ', '.join(res)
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
resets the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
return decorator
class Heuristics(KernelInterface):
def __init__(self, fn, arg_names, values) -> None:
self.fn = fn
self.values = values
self.arg_names = arg_names
def run(self, *args, **kwargs):
for v, heur in self.values.items():
kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
return self.fn.run(*args, **kwargs)
def heuristics(values):
"""
Decorator for specifying how the values of certain meta-parameters may be computed.
This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
.. highlight:: python
.. code-block:: python
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
"""
def decorator(fn):
return Heuristics(fn, fn.arg_names, values)
return decorator

View File

@@ -0,0 +1,131 @@
#include "cuda.h"
#define PY_SSIZE_T_CLEAN
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line) {
if (code != CUDA_SUCCESS) {
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyErr_SetString(PyExc_RuntimeError, err);
}
}
#define CUDA_CHECK(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
if (PyErr_Occurred()) \
return NULL; \
}
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int device_id;
if (!PyArg_ParseTuple(args, "i", &device_id))
return NULL;
// Get device handle
CUdevice device;
cuDeviceGet(&device, device_id);
// create a struct to hold device properties
int max_shared_mem;
int multiprocessor_count;
int sm_clock_rate;
int mem_clock_rate;
int mem_bus_width;
CUDA_CHECK(cuDeviceGetAttribute(
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
CUDA_CHECK(cuDeviceGetAttribute(
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
max_shared_mem, "multiprocessor_count",
multiprocessor_count, "sm_clock_rate", sm_clock_rate,
"mem_clock_rate", mem_clock_rate, "mem_bus_width",
mem_bus_width);
}
static PyObject *loadBinary(PyObject *self, PyObject *args) {
const char *name;
const char *data;
Py_ssize_t data_size;
int shared;
int device;
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
&device)) {
return NULL;
}
CUfunction fun;
CUmodule mod;
int32_t n_regs = 0;
int32_t n_spills = 0;
// create driver handles
CUcontext pctx = 0;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}
CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
// get allocated registers and spilled registers from the function
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK(
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
CUDA_CHECK(cuDeviceGetAttribute(
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
if (shared > 49152 && shared_optin > 49152) {
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total, shared_static;
CUDA_CHECK(cuDeviceGetAttribute(
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
device));
CUDA_CHECK(cuFuncGetAttribute(&shared_static,
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
}
if (PyErr_Occurred()) {
return NULL;
}
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
n_spills);
}
static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS,
"Load provided cubin into CUDA driver"},
{"get_device_properties", getDeviceProperties, METH_VARARGS,
"Get the properties for a given device"},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
NULL, // documentation
-1, // size
ModuleMethods};
PyMODINIT_FUNC PyInit_cuda_utils(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if (m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}

View File

@@ -0,0 +1,120 @@
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdio.h>
#include <stdlib.h>
static inline void gpuAssert(hipError_t code, const char *file, int line) {
{
if (code != HIP_SUCCESS) {
{
const char *prefix = "Triton Error [HIP]: ";
const char *str = hipGetErrorString(code);
char err[1024] = {0};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str);
PyErr_SetString(PyExc_RuntimeError, err);
}
}
}
}
#define HIP_CHECK(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
if (PyErr_Occurred()) \
return NULL; \
}
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int device_id;
if (!PyArg_ParseTuple(args, "i", &device_id))
return NULL;
hipDeviceProp_t props;
HIP_CHECK(hipGetDeviceProperties(&props, device_id));
// create a struct to hold device properties
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
props.sharedMemPerBlock, "multiprocessor_count",
props.multiProcessorCount, "sm_clock_rate",
props.clockRate, "mem_clock_rate", props.memoryClockRate,
"mem_bus_width", props.memoryBusWidth);
}
static PyObject *loadBinary(PyObject *self, PyObject *args) {
const char *name;
const char *data;
Py_ssize_t data_size;
int shared;
int device;
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
&device)) {
return NULL;
}
// Open HSACO file
FILE *hsaco_file;
if ((hsaco_file = fopen(data, "rb")) == NULL) {
return NULL;
}
// Read HSCAO file into Buffer
fseek(hsaco_file, 0L, SEEK_END);
size_t hsaco_file_size = ftell(hsaco_file);
unsigned char *hsaco =
(unsigned char *)malloc(hsaco_file_size * sizeof(unsigned char));
rewind(hsaco_file);
fread(hsaco, sizeof(unsigned char), hsaco_file_size, hsaco_file);
fclose(hsaco_file);
// set HIP options
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
const unsigned int errbufsize = 8192;
const unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
(void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};
// launch HIP Binary
hipModule_t mod;
hipFunction_t fun;
hipModuleLoadDataEx(&mod, hsaco, 5, opt, optval);
hipModuleGetFunction(&fun, mod, name);
free(hsaco);
// get allocated registers and spilled registers from the function
int n_regs = 0;
int n_spills = 0;
if (PyErr_Occurred()) {
return NULL;
}
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
n_spills);
}
static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS,
"Load provided hsaco into HIP driver"},
{"get_device_properties", getDeviceProperties, METH_VARARGS,
"Get the properties for a given device"},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils",
NULL, // documentation
-1, // size
ModuleMethods};
PyMODINIT_FUNC PyInit_hip_utils(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if (m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}

View File

@@ -0,0 +1,131 @@
import json
import os
import random
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Optional
def default_cache_dir():
return os.path.join(Path.home(), ".triton", "cache")
class CacheManager(ABC):
def __init__(self, key):
pass
@abstractmethod
def get_file(self, filename) -> Optional[str]:
pass
@abstractmethod
def has_file(self, filename) -> bool:
pass
@abstractmethod
def put(self, data, filename, binary=True) -> str:
pass
@abstractmethod
def get_group(self, filename: str) -> Optional[Dict[str, str]]:
pass
@abstractmethod
def put_group(self, filename: str, group: Dict[str, str]):
pass
class FileCacheManager(CacheManager):
def __init__(self, key):
self.key = key
self.lock_path = None
# create cache directory if it doesn't exist
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
def _make_path(self, filename) -> str:
return os.path.join(self.cache_dir, filename)
def has_file(self, filename):
if not self.cache_dir:
return False
return os.path.exists(self._make_path(filename))
def get_file(self, filename) -> Optional[str]:
if self.has_file(filename):
return self._make_path(filename)
else:
return None
def get_group(self, filename: str) -> Optional[Dict[str, str]]:
grp_filename = f"__grp__{filename}"
if not self.has_file(grp_filename):
return None
grp_filepath = self._make_path(grp_filename)
with open(grp_filepath) as f:
grp_data = json.load(f)
child_paths = grp_data.get("child_paths", None)
# Invalid group data.
if child_paths is None:
return None
result = {}
for c in child_paths:
p = self._make_path(c)
if not os.path.exists(p):
raise Exception(f"Group file {p} does not exist from group {grp_filename} ")
result[c] = p
return result
# Note a group of pushed files as being part of a group
def put_group(self, filename: str, group: Dict[str, str]):
if not self.cache_dir:
return
grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
grp_filename = f"__grp__{filename}"
return self.put(grp_contents, grp_filename, binary=False)
def put(self, data, filename, binary=True) -> str:
if not self.cache_dir:
return
binary = isinstance(data, bytes)
if not binary:
data = str(data)
assert self.lock_path is not None
filepath = self._make_path(filename)
# Random ID to avoid any collisions
rnd_id = random.randint(0, 1000000)
# we use the PID incase a bunch of these around so we can see what PID made it
pid = os.getpid()
# use tempfile to be robust against program interruptions
temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
mode = "wb" if binary else "w"
with open(temp_path, mode) as f:
f.write(data)
# Replace is guaranteed to be atomic on POSIX systems if it succeeds
# so filepath cannot see a partial write
os.replace(temp_path, filepath)
return filepath
__cache_cls = FileCacheManager
__cache_cls_nme = "DEFAULT"
def get_cache_manager(key) -> CacheManager:
import os
user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
global __cache_cls
global __cache_cls_nme
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
import importlib
module_path, clz_nme = user_cache_manager.split(":")
module = importlib.import_module(module_path)
__cache_cls = getattr(module, clz_nme)
__cache_cls_nme = user_cache_manager
return __cache_cls(key)

View File

@@ -0,0 +1,175 @@
import abc
import hashlib
import os
import tempfile
from pathlib import Path
from ..common.build import _build
from .cache import get_cache_manager
class DriverBase(metaclass=abc.ABCMeta):
CUDA = 0
HIP = 1
@staticmethod
def third_party_dir():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party")
def __init__(self) -> None:
pass
# -----------------------------
# CUDA
# -----------------------------
class CudaUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(CudaUtils, cls).__new__(cls)
return cls.instance
def __init__(self):
dirname = os.path.dirname(os.path.realpath(__file__))
src = Path(os.path.join(dirname, "backends", "cuda.c")).read_text()
key = hashlib.md5(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
fname = "cuda_utils.so"
cache_path = cache.get_file(fname)
if cache_path is None:
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build("cuda_utils", src_path, tmpdir)
cache.put(src, "main.c", binary=False)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.load_binary = mod.load_binary
self.get_device_properties = mod.get_device_properties
class CudaDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(CudaDriver, cls).__new__(cls)
return cls.instance
def __init__(self):
self.utils = CudaUtils()
self.backend = self.CUDA
# -----------------------------
# HIP
# -----------------------------
class HIPUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(HIPUtils, cls).__new__(cls)
return cls.instance
def __init__(self):
dirname = os.path.dirname(os.path.realpath(__file__))
src = Path(os.path.join(dirname, "backends", "hip.c")).read_text()
key = hashlib.md5(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
fname = "hip_utils.so"
cache_path = cache.get_file(fname)
if cache_path is None:
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build("hip_utils", src_path, tmpdir)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.load_binary = mod.load_binary
self.get_device_properties = mod.get_device_properties
class HIPDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(HIPDriver, cls).__new__(cls)
return cls.instance
def __init__(self):
self.utils = HIPUtils()
self.backend = self.HIP
class UnsupportedDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
return cls.instance
def __init__(self):
self.utils = None
self.backend = None
# -----------------------------
# Driver
# -----------------------------
class LazyProxy:
def __init__(self, init_fn):
self._init_fn = init_fn
self._obj = None
def _initialize_obj(self):
if self._obj is None:
self._obj = self._init_fn()
def __getattr__(self, name):
self._initialize_obj()
return getattr(self._obj, name)
def __setattr__(self, name, value):
if name in ['_init_fn', '_obj']:
super().__setattr__(name, value)
else:
self._initialize_obj()
setattr(self._obj, name, value)
def __delattr__(self, name):
self._initialize_obj()
delattr(self._obj, name)
def __repr__(self):
if self._obj is None:
return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
return repr(self._obj)
def __str__(self):
self._initialize_obj()
return str(self._obj)
def initialize_driver():
import torch
if torch.version.hip is not None:
return HIPDriver()
elif torch.cuda.is_available():
return CudaDriver()
else:
return UnsupportedDriver()
driver = LazyProxy(initialize_driver)

View File

@@ -0,0 +1,15 @@
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.required = required
self.limit = limit
self.name = name
super().__init__(self.message)
def __reduce__(self):
# this is necessary to make CompilationError picklable
return (type(self), (self.required, self.limit, self.name))

573
pkgs/triton/runtime/jit.py Normal file
View File

@@ -0,0 +1,573 @@
from __future__ import annotations, division
import ast
import functools
import hashlib
import inspect
import os
import subprocess
import textwrap
from collections import defaultdict, namedtuple
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
import torch
import triton
def get_disable_sme():
disable_sme = os.getenv("TRITON_DISABLE_SME", default="0")
cc = torch.cuda.get_device_capability()
cc = cc[0] * 10 + cc[1]
if cc == 70: # for ivcore10
disable_sme = "1"
return disable_sme
def get_corex_sme(args, tl_args, enable_sme=True):
can_use_sme = 0
if not enable_sme:
return can_use_sme
import torch
if not (hasattr(torch, "corex") and torch.corex == True):
return can_use_sme
close_sme = get_disable_sme()
if close_sme == "1":
return can_use_sme
index = 0
for i, arg_name in enumerate(args):
arg = args.get(arg_name)
if (i in tl_args):
continue
if (isinstance(arg, int) and arg == 1):
continue
if torch.is_tensor(arg) and arg.dtype in [torch.float16, torch.float32, torch.bfloat16, torch.int8] and arg.dim() >= 2:
dim_M = arg.shape[-2]
dim_K = arg.shape[-1]
sme_dim = 64 / arg.element_size()
if (arg.is_contiguous() and dim_K % sme_dim == 0) or \
(not arg.is_contiguous() and dim_M % sme_dim == 0):
can_use_sme = (1 << index) | can_use_sme
index += 1
return can_use_sme
def get_cuda_stream(idx=None):
if idx is None:
idx = get_current_device()
try:
from torch._C import _cuda_getCurrentRawStream
return _cuda_getCurrentRawStream(idx)
except ImportError:
import torch
return torch.cuda.current_stream(idx).cuda_stream
def get_current_device():
import torch
return torch.cuda.current_device()
def set_current_device(idx):
import torch
torch.cuda.set_device(idx)
def get_device_capability(idx):
import torch
return torch.cuda.get_device_capability(idx)
T = TypeVar('T')
# -----------------------------------------------------------------------------
# Dependencies Finder
# -----------------------------------------------------------------------------
class DependenciesFinder(ast.NodeVisitor):
"""
This AST visitor is used to find dependencies of a JITFunction. This can
be used to invalidate a JITFunction's hash when its source code -- or
that of its dependencies -- changes.
"""
def __init__(self, globals, src) -> None:
super().__init__()
self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
self.globals = globals
def visit_Name(self, node):
return self.globals.get(node.id, None)
def visit_Attribute(self, node):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or lhs is triton:
return None
return getattr(lhs, node.attr)
def visit_Call(self, node):
func = self.visit(node.func)
if func is None:
return
if inspect.isbuiltin(func):
return
if func.__module__ and func.__module__.startswith('triton.'):
return
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
if func.hash is None:
tree = ast.parse(func.src)
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
noinline = str(getattr(func, 'noinline', False))
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest()
# -----------------------------------------------------------------------------
# JITFunction
# -----------------------------------------------------------------------------
@functools.lru_cache()
def version_key():
import pkgutil
contents = []
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(*triton.__path__, 'compiler')
for lib in pkgutil.iter_modules([compiler_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# backend
with open(triton._C.libtriton.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# language
language_path = os.path.join(*triton.__path__, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# ptxas version
try:
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
except Exception:
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
class KernelInterface(Generic[T]):
run: T
def __getitem__(self, grid) -> T:
"""
A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that
memorizes the grid.
"""
return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
class JITFunction(KernelInterface[T]):
# Hook for inspecting compiled functions and modules
cache_hook = None
divisibility = 16
@staticmethod
def _key_of(arg):
if hasattr(arg, "dtype"):
return arg.dtype
elif isinstance(arg, bool):
return "i1"
elif isinstance(arg, int):
if -2**31 <= arg and arg <= 2**31 - 1:
return "i32"
elif 2**63 <= arg and arg <= 2**64 - 1:
return "u64"
else:
return "i64"
elif isinstance(arg, float):
return 'fp32'
elif arg is None:
return None
else:
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
@staticmethod
def _spec_of(arg):
if hasattr(arg, "data_ptr"):
return (arg.data_ptr() % JITFunction.divisibility == 0)
elif isinstance(arg, int):
return (arg % 16 == 0, arg == 1)
return (arg is None, )
def _get_config(self, *args):
def is_divisible_by_16(x):
if hasattr(x, "data_ptr"):
return x.data_ptr() % JITFunction.divisibility == 0
elif isinstance(x, int):
return x % JITFunction.divisibility == 0
if x is None:
return True
return False
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
@staticmethod
def _type_of(key):
# None are nullptr -- implicitly converted to *i8
if key is None:
return '*i8'
dtype_str = str(key).split(".")[-1]
tys = {
"bool": "i1",
"float8e5": "fp8e5",
"float8e4": "fp8e4",
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"float64": "fp64",
"int8": "i8",
"int16": "i16",
"int32": "i32",
"int64": "i64",
"uint8": "u8",
"uint16": "u16",
"uint32": "u32",
"uint64": "u64",
}
# reinterpret can create triton type
for v in list(tys.values()):
tys[v] = v
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
def _make_signature(self, sig_key):
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
return signature
def _make_constants(self, constexpr_key):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
key = str(key)
class LegacyCompiler:
def __init__(self, module, name):
self.module = module
self.name = name
pass
kwargs = dict(signature=signature, device=device, constants=constants,
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
configs=configs)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
def _get_arg_specialization_key(self, arg) -> str:
arg_annotation = self.__annotations__.get(arg, '')
if arg_annotation == '':
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \
else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) \
else (False,)'
elif 'Tensor' in arg_annotation:
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
elif arg_annotation == 'int':
return f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)'
else:
return '(False,)'
def _get_arg_sig_key(self, arg) -> str:
arg_annotation = self.__annotations__.get(arg, '')
if 'Tensor' in arg_annotation:
return f'{arg}.dtype'
elif arg_annotation == 'bool':
return "i1"
elif arg_annotation == 'float':
return 'fp32'
else:
return f'_key_of({arg})'
def _make_launcher(self):
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
args = ', '.join(regular_args)
# cache key for regular argument type
sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
# cache key for constexpr argument values
constexpr_keys = ', '.join(constexpr_args)
# cache key for argument specialization
specializations = []
for i, arg in enumerate(regular_args):
if i in self.do_not_specialize:
continue
specializations += [self._get_arg_specialization_key(arg)]
spec_keys = ', '.join(specializations)
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
src = f"""
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, enable_sme=True, extern_libs=None, stream=None, warmup=False, device=None):
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
use_sme = get_corex_sme({{{grid_args}}}, self.constexprs, enable_sme)
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug, use_sme)
if not extern_libs is None:
key = (key, tuple(extern_libs.items()))
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
if callable(grid):
grid = grid({{{grid_args}}})
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
if device is None:
device = get_current_device()
set_current_device(device)
if stream is None and not warmup:
stream = get_cuda_stream(device)
bin = cache[device].get(key, None)
if bin is not None:
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
return bin
# kernel not cached -- compile
else:
# build dict of constant values
args = [{args}]
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
configs = self._get_config(*all_args),
constants = self._make_constants(constexpr_key)
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
constants.update({{i: 1 for i in configs[0].equal_to_1}})
# build kernel signature -- doesn't include specialized arguments
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, use_sme=use_sme)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
self.cache[device][key] = bin
return bin
return None
"""
scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
"self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
"cache": self.cache, "triton": triton,
"get_current_device": get_current_device,
"set_current_device": set_current_device,
"get_corex_sme": get_corex_sme}
exec(src, scope)
return scope[self.fn.__name__]
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
self.fn = fn
self.module = fn.__module__
self.version = version
# function signature information
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.has_defaults = any(v.default != inspect._empty for v in signature.parameters.values())
# specialization hints
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
# cache of just-in-time compiled kernels
self.cache = defaultdict(dict)
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
self.noinline = noinline
# annotations
normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty
self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()}
# index of constexprs
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
# launcher
self.run = self._make_launcher()
# re-use docs of wrapped function
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
self.__globals__ = fn.__globals__
self.__module__ = fn.__module__
@property
def cache_key(self):
# TODO : hash should be attribute of `self`
if self.hash is None:
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.hash = dependencies_finder.ret + version_key()
return self.hash
def warmup(self, *args, **kwargs):
return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)
# we do not parse `src` in the constructor because
# the user might want to monkey-patch self.src dynamically.
# Our unit tests do this, for example.
def parse(self):
tree = ast.parse(self.src)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
assert isinstance(tree.body[0], ast.FunctionDef)
return tree
def __call__(self, *args, **kwargs):
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
def __setattr__(self, name, value):
# - when kernel decorators change, cached kernel
# needs to be cleared
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
# - when `.src` attribute is set, cache path needs
# to be reinitialized
if name == 'src':
self.hash = None
def __repr__(self):
return f"JITFunction({self.module}:{self.fn.__name__})"
# -----------------------------------------------------------------------------
# `jit` decorator
# -----------------------------------------------------------------------------
@overload
def jit(fn: T) -> JITFunction[T]:
...
@overload
def jit(
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
...
def jit(
fn: Optional[T] = None,
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
interpret: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, arguments are
implicitly converted to pointers if they have a :code:`.data_ptr()` method
and a `.dtype` attribute.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* builtins within the triton package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
"""
def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
if interpret:
from ..debugger.debugger import GridSelector
return GridSelector(fn)
else:
return JITFunction(
fn,
version=version,
do_not_specialize=do_not_specialize,
debug=debug,
noinline=noinline,
)
if fn is not None:
return decorator(fn)
else:
return decorator
# -----------------------------------------------------------------------------
# Utilities for mocking tensors
# -----------------------------------------------------------------------------
class MockTensor:
"""
Can be used in place of real tensors when calling:
kernel.warmup(MockTensor(torch.float32), ...)
"""
@staticmethod
def wrap_dtype(arg):
if arg.__class__.__name__ == "dtype" and\
arg.__module__ == "torch":
return MockTensor(arg)
return arg
def __init__(self, dtype):
self.dtype = dtype
@staticmethod
def data_ptr():
return 0 # optimistically assumes multiple of 16
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
self.is_cuda = base.is_cuda
self.device = base.device
def data_ptr(self):
return self.base.data_ptr()
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
def reinterpret(tensor, dtype):
if isinstance(tensor, TensorWrapper):
if dtype == tensor.base.dtype:
# Reinterpreting to the original interpretation; return the base.
return tensor.base
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif hasattr(tensor, "data_ptr"):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')