import hashlib
import os
import tempfile

from ..common import _build
from ..runtime.cache import get_cache_manager
from ..runtime.jit import version_key


def is_hip():
    """Return True when the installed torch build reports a HIP version."""
    # torch is imported inside the function, so importing this module does
    # not itself import torch.
    import torch

    return torch.version.hip is not None


def is_corex():
    """Return True when torch has a ``corex`` attribute that equals True."""
    import torch

    return hasattr(torch, "corex") and torch.corex == True  # noqa: E712


# ----- stub --------


def make_so_cache_key(version_hash, signature, constants):
    """Build the cache key for a compiled launcher stub.

    Every pointer type (a type string starting with ``*``) is collapsed to
    ``'ptr'`` so that launchers differing only in pointee type share a key.

    Args:
        version_hash: version string mixed into the key.
        signature: mapping of argument index -> type string.
        constants: constants of the kernel; its ``str()`` form is hashed.

    Returns:
        Hex MD5 digest identifying the launcher.
    """
    # Get unique key for the compiled code
    normalized = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
    raw_key = f"{version_hash}-{''.join(normalized.values())}{constants}"
    return hashlib.md5(raw_key.encode("utf-8")).hexdigest()


def make_stub(name, signature, constants):
    """Return the cached launcher shared object for ``name``, building it if absent.

    On a cache miss the launcher C source is generated, compiled with
    ``_build`` inside a temporary directory, and both the source
    (``{name}.c``) and the binary (``{name}.so``) are stored in the cache.

    Returns:
        Whatever the cache manager returns from ``get_file`` / ``put`` for the
        ``.so`` entry (presumably its path - confirm against the cache
        manager implementation).
    """
    # name of files that are cached
    so_cache_key = make_so_cache_key(version_key(), signature, constants)
    so_cache_manager = get_cache_manager(so_cache_key)
    so_name = f"{name}.so"
    # retrieve stub from cache if it exists
    cache_path = so_cache_manager.get_file(so_name)
    if cache_path is not None:
        return cache_path

    with tempfile.TemporaryDirectory() as tmpdir:
        src = generate_launcher(constants, signature)
        src_path = os.path.join(tmpdir, "main.c")
        with open(src_path, "w") as f:
            f.write(src)
        so = _build(name, src_path, tmpdir)
        so_cache_manager.put(src, f"{name}.c", binary=False)
        with open(so, "rb") as f:
            return so_cache_manager.put(f.read(), so_name, binary=True)


# ----- source code generation --------


def ty_to_cpp(ty):
    """Map a signature type string to the C type used in ``_launch``'s parameters.

    Raises:
        KeyError: if ``ty`` is not a pointer type and not in the table.
    """
    if ty[0] == '*':
        return "hipDeviceptr_t" if is_hip() else "CUdeviceptr"
    return {
        "i1": "int32_t",
        "i8": "int8_t",
        "i16": "int16_t",
        "i32": "int32_t",
        "i64": "int64_t",
        "u32": "uint32_t",
        "u64": "uint64_t",
        "fp16": "float",
        "bf16": "float",
        "fp32": "float",
        "f32": "float",
        "fp64": "double",
    }[ty]


def generate_launcher(constants, signature):
    """Generate the C source of the Python extension module that launches a kernel.

    The generated module ``__triton_launcher`` exposes one function,
    ``launch``, which parses the grid/stream/function/hook arguments followed
    by one value per ``signature`` entry, resolves pointer arguments via
    ``getPointer`` and calls the backend launch API (HIP or CUDA, chosen by
    ``is_hip()``).

    Args:
        constants: container of argument indices that are constants; those
            arguments are parsed but not forwarded in the kernel ``params``.
        signature: mapping of argument index -> type string.

    Returns:
        The C source as a string.

    Raises:
        KeyError: if a non-pointer type in ``signature`` has no C mapping.
    """

    def _extracted_type(ty):
        # C type of the local that PyArg_ParseTuple fills for this argument.
        if ty[0] == '*':
            return "PyObject*"
        return {
            'i1': 'int32_t',
            'i32': 'int32_t',
            'i64': 'int64_t',
            'u32': 'uint32_t',
            'u64': 'uint64_t',
            'fp16': 'float',
            'bf16': 'float',
            'fp32': 'float',
            'f32': 'float',
            'fp64': 'double',
        }[ty]

    def format_of(ty):
        # PyArg_ParseTuple format unit for a C type.
        return {
            "PyObject*": "O",
            "float": "f",
            "double": "d",
            "long": "l",
            "uint32_t": "I",
            "int32_t": "i",
            "uint64_t": "K",
            "int64_t": "L",
        }[ty]

    def _comma_prefixed(fragment):
        # Avoid a dangling ", )" in the generated C when there are no arguments.
        return f", {fragment}" if fragment else ""

    # Fixed prefix: gridX, gridY, gridZ, num_warps, shared_memory (i),
    # stream, function (K), enter hook, exit hook, compiled kernel (O).
    arg_format = "iiiiiKKOOO" + ''.join(format_of(_extracted_type(ty)) for ty in signature.values())

    # Python-side fragments shared by both backend templates.
    decl_suffix = _comma_prefixed(
        ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()))
    params = ', '.join(f"&arg{i}" for i in signature if i not in constants)
    arg_locals = ' '.join(f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items())
    parse_targets = _comma_prefixed(', '.join(f"&_arg{i}" for i in signature))
    ptr_checks = "\n      ".join(
        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
        for i, ty in signature.items() if ty[0] == "*")
    launch_args = _comma_prefixed(
        ', '.join(f"ptr_info{i}.dev_ptr" if ty[0] == "*" else f"_arg{i}"
                  for i, ty in signature.items()))

    # NOTE(review): the header names of the angle-bracket #include lines were
    # missing from the source this file was restored from. They are
    # reconstructed here from what the generated code uses (HIP runtime,
    # CPython API, snprintf, bool) - confirm against the project history.

    # generate glue code
    if is_hip():
        stream_ty, function_ty = "hipStream_t", "hipFunction_t"
        backend_src = f"""
    #define __HIP_PLATFORM_AMD__
    #include <hip/hip_runtime.h>
    #include <Python.h>
    #include <stdio.h>

    static inline void gpuAssert(hipError_t code, const char *file, int line) {{
      if (code != HIP_SUCCESS) {{
        const char* prefix = "Triton Error [HIP]: ";
        const char* str = hipGetErrorString(code);
        char err[1024] = {{0}};
        snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
        PyErr_SetString(PyExc_RuntimeError, err);
      }}
    }}

    #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

    static int getWarpSize(hipStream_t stream) {{
      int device_id = hipGetStreamDeviceId(stream);
      gpuAssert(device_id >= 0 ? hipSuccess : hipErrorInvalidDevice, __FILE__, __LINE__);
      hipDeviceProp_t prop;
      HIP_CHECK(hipGetDeviceProperties(&prop, device_id));
      return prop.warpSize;
    }}

    static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function{decl_suffix}) {{
      void *params[] = {{ {params} }};
      if (gridX*gridY*gridZ > 0) {{
        int warp_size = getWarpSize(stream);
        HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, num_warps * warp_size, 1, 1, shared_memory, stream, params, 0));
      }}
    }}

    typedef struct _DevicePtrInfo {{
      hipDeviceptr_t dev_ptr;
      bool valid;
    }} DevicePtrInfo;

    static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
      DevicePtrInfo ptr_info;
      ptr_info.dev_ptr = 0;
      ptr_info.valid = true;
      if (PyLong_Check(obj)) {{
        ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
        return ptr_info;
      }}
      if (obj == Py_None) {{
        // valid nullptr
        return ptr_info;
      }}
      PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
      if (ptr) {{
        PyObject *empty_tuple = PyTuple_New(0);
        PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
        Py_DECREF(empty_tuple);
        Py_DECREF(ptr);
        if (!ret) {{
          // data_ptr() raised: keep its exception and report failure
          ptr_info.valid = false;
          return ptr_info;
        }}
        if (!PyLong_Check(ret)) {{
          PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
          Py_DECREF(ret);
          ptr_info.valid = false;
          return ptr_info;
        }}
        ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
        // the integer value has been copied out; release the call result
        Py_DECREF(ret);
        if (!ptr_info.dev_ptr)
          return ptr_info;
        uint64_t dev_ptr;
        hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
        if (status == hipErrorInvalidValue) {{
          PyErr_Format(PyExc_ValueError, "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
          ptr_info.valid = false;
        }}
        ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
        return ptr_info;
      }}
      PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
      // an exception is set, so the caller must not launch with this pointer
      ptr_info.valid = false;
      return ptr_info;
    }}
    """
    else:
        warp_size = 64 if is_corex() else 32
        stream_ty, function_ty = "CUstream", "CUfunction"
        backend_src = f"""
    #include "cuda.h"
    #include <stdbool.h>
    #include <Python.h>

    static inline void gpuAssert(CUresult code, const char *file, int line) {{
      if (code != CUDA_SUCCESS) {{
        const char* prefix = "Triton Error [CUDA]: ";
        const char* str = NULL;
        cuGetErrorString(code, &str);
        char err[1024] = {{0}};
        // bounded write; str stays NULL for error codes the driver does not know
        snprintf(err, sizeof(err), "%s%s", prefix, str ? str : "unknown error");
        PyErr_SetString(PyExc_RuntimeError, err);
      }}
    }}

    #define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

    static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function{decl_suffix}) {{
      void *params[] = {{ {params} }};
      if (gridX*gridY*gridZ > 0) {{
        CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, num_warps * {warp_size}, 1, 1, shared_memory, stream, params, 0));
      }}
    }}

    typedef struct _DevicePtrInfo {{
      CUdeviceptr dev_ptr;
      bool valid;
    }} DevicePtrInfo;

    static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
      DevicePtrInfo ptr_info;
      ptr_info.dev_ptr = 0;
      ptr_info.valid = true;
      if (PyLong_Check(obj)) {{
        ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
        return ptr_info;
      }}
      if (obj == Py_None) {{
        // valid nullptr
        return ptr_info;
      }}
      PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
      if (ptr) {{
        PyObject *empty_tuple = PyTuple_New(0);
        PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
        Py_DECREF(empty_tuple);
        Py_DECREF(ptr);
        if (!ret) {{
          // data_ptr() raised: keep its exception and report failure
          ptr_info.valid = false;
          return ptr_info;
        }}
        if (!PyLong_Check(ret)) {{
          PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
          Py_DECREF(ret);
          ptr_info.valid = false;
          return ptr_info;
        }}
        ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
        // the integer value has been copied out; release the call result
        Py_DECREF(ret);
        if (!ptr_info.dev_ptr)
          return ptr_info;
        /*
        uint64_t dev_ptr;
        int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
        if (status == CUDA_ERROR_INVALID_VALUE) {{
          PyErr_Format(PyExc_ValueError, "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
          ptr_info.valid = false;
        }}
        ptr_info.dev_ptr = dev_ptr;
        */
        return ptr_info;
      }}
      PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
      // an exception is set, so the caller must not launch with this pointer
      ptr_info.valid = false;
      return ptr_info;
    }}
    """

    # Entry point and module boilerplate; identical for both backends apart
    # from the stream/function handle types used in the casts.
    common_src = f"""
    static PyObject* launch(PyObject* self, PyObject* args) {{
      int gridX, gridY, gridZ;
      uint64_t _stream;
      uint64_t _function;
      int num_warps;
      int shared_memory;
      PyObject *launch_enter_hook = NULL;
      PyObject *launch_exit_hook = NULL;
      PyObject *compiled_kernel = NULL;
      {arg_locals}
      if (!PyArg_ParseTuple(args, "{arg_format}", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{parse_targets})) {{
        return NULL;
      }}

      if (launch_enter_hook != Py_None) {{
        PyObject *hook_ret = PyObject_CallObject(launch_enter_hook, args);
        Py_XDECREF(hook_ret);
      }}

      // raise exception asap
      {ptr_checks}
      _launch(gridX, gridY, gridZ, num_warps, shared_memory, ({stream_ty})_stream, ({function_ty})_function{launch_args});

      if (launch_exit_hook != Py_None) {{
        PyObject *hook_ret = PyObject_CallObject(launch_exit_hook, args);
        Py_XDECREF(hook_ret);
      }}

      if (PyErr_Occurred()) {{
        return NULL;
      }}
      // return None
      Py_INCREF(Py_None);
      return Py_None;
    }}

    static PyMethodDef ModuleMethods[] = {{
      {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
      {{NULL, NULL, 0, NULL}} // sentinel
    }};

    static struct PyModuleDef ModuleDef = {{
      PyModuleDef_HEAD_INIT,
      "__triton_launcher",
      NULL, //documentation
      -1, //size
      ModuleMethods
    }};

    PyMODINIT_FUNC PyInit___triton_launcher(void) {{
      PyObject *m = PyModule_Create(&ModuleDef);
      if(m == NULL) {{
        return NULL;
      }}
      PyModule_AddFunctions(m, ModuleMethods);
      return m;
    }}
    """
    return backend_src + common_src