First commit
This commit is contained in:
4
pkgs/triton/compiler/__init__.py
Normal file
4
pkgs/triton/compiler/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .compiler import CompiledKernel, compile
|
||||
from .errors import CompilationError
|
||||
|
||||
__all__ = ["compile", "CompiledKernel", "CompilationError"]
|
||||
BIN
pkgs/triton/compiler/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/compiler/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/compiler/__pycache__/code_generator.cpython-310.pyc
Normal file
BIN
pkgs/triton/compiler/__pycache__/code_generator.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/compiler/__pycache__/compiler.cpython-310.pyc
Normal file
BIN
pkgs/triton/compiler/__pycache__/compiler.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/compiler/__pycache__/errors.cpython-310.pyc
Normal file
BIN
pkgs/triton/compiler/__pycache__/errors.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/compiler/__pycache__/make_launcher.cpython-310.pyc
Normal file
BIN
pkgs/triton/compiler/__pycache__/make_launcher.cpython-310.pyc
Normal file
Binary file not shown.
1133
pkgs/triton/compiler/code_generator.py
Normal file
1133
pkgs/triton/compiler/code_generator.py
Normal file
File diff suppressed because it is too large
Load Diff
631
pkgs/triton/compiler/compiler.py
Normal file
631
pkgs/triton/compiler/compiler.py
Normal file
@@ -0,0 +1,631 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from ..runtime import driver
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..tools.disasm import extract
|
||||
from .code_generator import ast_to_ttir
|
||||
from .make_launcher import make_stub
|
||||
|
||||
|
||||
def is_corex():
|
||||
import torch
|
||||
return hasattr(torch, "corex") and torch.corex == True
|
||||
|
||||
CUDA_DEFAULT_WARP_SIZE = 64 if is_corex() else 32
|
||||
|
||||
def inline_triton_ir(mod):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_inliner_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def ttir_compute_capability_rewrite(mod, arch):
|
||||
# For hardware without support, we must rewrite all load/store
|
||||
# with block (tensor) pointers into tensors of pointers
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
if _is_cuda(arch):
|
||||
pm.add_rewrite_tensor_pointer_pass(arch)
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_ttir(mod, arch):
|
||||
mod = inline_triton_ir(mod)
|
||||
mod = ttir_compute_capability_rewrite(mod, arch)
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_inliner_pass()
|
||||
pm.add_triton_combine_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_licm_pass()
|
||||
pm.add_symbol_dce_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def ttir_to_ttgir(mod, num_warps, warpsize):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize)
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_ttgir(mod, num_stages, arch, use_sme = 0):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_tritongpu_coalesce_pass()
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
if _is_cuda(arch):
|
||||
pm.add_tritongpu_accelerate_matmul_pass(arch, use_sme)
|
||||
# TODO change interface of accelerate_matmul_pass
|
||||
if is_hip() and gpu_has_mfma():
|
||||
pm.add_tritongpu_accelerate_matmul_pass(80)
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
if is_corex():
|
||||
pm.add_tritongpu_matmul_smeload_pass(arch) #BI 70 MR > 71,only MR support sme
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
# TODO enable this pass for AMD GPU when it is ready
|
||||
if not is_hip():
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
if not is_corex():
|
||||
pm.add_tritongpu_prefetch_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
pm.add_tritongpu_decompose_conversions_pass()
|
||||
pm.add_tritongpu_reorder_instructions_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_symbol_dce_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def _add_external_libs(mod, libs):
|
||||
for name, path in libs.items():
|
||||
if len(name) == 0 or len(path) == 0:
|
||||
return
|
||||
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
||||
|
||||
|
||||
def ttgir_to_llir(mod, extern_libs, arch):
|
||||
if extern_libs:
|
||||
_add_external_libs(mod, extern_libs)
|
||||
# TODO: separate tritongpu_to_llvmir for different backends
|
||||
if _is_cuda(arch):
|
||||
return _triton.translate_triton_gpu_to_llvmir(mod, arch, False)
|
||||
else:
|
||||
return _triton.translate_triton_gpu_to_llvmir(mod, 0, True)
|
||||
|
||||
|
||||
# PTX translation
|
||||
|
||||
@functools.lru_cache()
|
||||
def ptx_get_version(cuda_version) -> int:
|
||||
'''
|
||||
Get the highest PTX version supported by the current CUDA driver.
|
||||
'''
|
||||
assert isinstance(cuda_version, str)
|
||||
major, minor = map(int, cuda_version.split('.'))
|
||||
if major == 12:
|
||||
return 80 + minor
|
||||
if major == 11:
|
||||
return 70 + minor
|
||||
if major == 10:
|
||||
return 63 + minor
|
||||
raise RuntimeError("Triton only support CUDA 10.0 or higher")
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def path_to_ptxas():
|
||||
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
|
||||
paths = [
|
||||
os.environ.get("TRITON_PTXAS_PATH", ""),
|
||||
os.path.join(base_dir, "third_party", "cuda", "bin", "ptxas")
|
||||
]
|
||||
|
||||
for ptxas in paths:
|
||||
if os.path.exists(ptxas) and os.path.isfile(ptxas):
|
||||
result = subprocess.check_output([ptxas, "--version"], stderr=subprocess.STDOUT)
|
||||
if result is not None:
|
||||
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
|
||||
if version is not None:
|
||||
return ptxas, version.group(1)
|
||||
raise RuntimeError("Cannot find ptxas")
|
||||
|
||||
|
||||
def llir_to_cubin(mod: Any, arch: int):
|
||||
'''
|
||||
Compile LLVM module to cubin.
|
||||
:param mod: a LLVM module
|
||||
:return: str
|
||||
'''
|
||||
return _triton.translate_llvmir_to_cubin(mod, arch)
|
||||
|
||||
|
||||
def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
|
||||
'''
|
||||
Translate TritonGPU module to PTX code.
|
||||
:param mod: a TritonGPU dialect module
|
||||
:return: PTX code
|
||||
'''
|
||||
if ptx_version is None:
|
||||
_, cuda_version = path_to_ptxas()
|
||||
ptx_version = ptx_get_version(cuda_version)
|
||||
return _triton.translate_llvmir_to_ptx(mod, arch, ptx_version)
|
||||
|
||||
|
||||
def ptx_to_cubin(ptx: str, arch: int):
|
||||
'''
|
||||
Compile TritonGPU module to cubin.
|
||||
:param ptx: ptx code
|
||||
:param compute_capability: compute capability
|
||||
:return: str
|
||||
'''
|
||||
ptxas, _ = path_to_ptxas()
|
||||
return _triton.compile_ptx_to_cubin(ptx, ptxas, arch)
|
||||
|
||||
|
||||
# AMDGCN translation
|
||||
|
||||
def get_amdgcn_bitcode_paths(arch):
|
||||
gpu_arch_agnostic_bitcode_libraries = ["opencl.bc",
|
||||
"ocml.bc",
|
||||
"ockl.bc",
|
||||
"oclc_finite_only_off.bc",
|
||||
"oclc_daz_opt_off.bc",
|
||||
"oclc_correctly_rounded_sqrt_on.bc",
|
||||
"oclc_unsafe_math_off.bc",
|
||||
"oclc_wavefrontsize64_on.bc",
|
||||
"oclc_abi_version_400.bc",]
|
||||
|
||||
gfx_arch = arch[1]
|
||||
gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
|
||||
|
||||
gpu_arch_specific_bitcode_library = 'oclc_isa_version_' + gfx_arch_id + ".bc"
|
||||
bitcode_path_dir = os.path.join(Path(__file__).parent.parent.resolve(), "third_party/rocm/lib/bitcode/")
|
||||
|
||||
amdgcn_bitcode_paths = {}
|
||||
i = 0
|
||||
for bc_lib in gpu_arch_agnostic_bitcode_libraries:
|
||||
bc_path = bitcode_path_dir + bc_lib
|
||||
if os.path.exists(bc_path):
|
||||
amdgcn_bitcode_paths['library_' + str(i)] = bc_path
|
||||
i += 1
|
||||
bc_gfx_path = bitcode_path_dir + gpu_arch_specific_bitcode_library
|
||||
if os.path.exists(bc_gfx_path):
|
||||
amdgcn_bitcode_paths['library_' + str(i)] = bc_gfx_path
|
||||
|
||||
return amdgcn_bitcode_paths
|
||||
|
||||
|
||||
def get_amdgpu_arch_fulldetails():
|
||||
"""
|
||||
get the amdgpu fulll ISA details for compiling:
|
||||
i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack-
|
||||
"""
|
||||
try:
|
||||
# TODO: package rocm.cc with Triton
|
||||
arch_info = _triton.get_arch_info()
|
||||
warpsize = _triton.get_warp_size()
|
||||
gfx_arch_details = re.search('amd.*', arch_info).group(0).strip().split('--')
|
||||
arch_triple = gfx_arch_details[0]
|
||||
arch_name_features = gfx_arch_details[1].split(':')
|
||||
arch_name = arch_name_features[0]
|
||||
arch_features = ""
|
||||
|
||||
if (len(arch_name_features) == 3):
|
||||
arch_features = "+" + re.search('\\w+', arch_name_features[1]).group(0) + ","\
|
||||
"-" + re.search('\\w+', arch_name_features[2]).group(0)
|
||||
return [arch_triple, arch_name, arch_features, warpsize]
|
||||
except BaseException:
|
||||
return None
|
||||
|
||||
|
||||
def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]:
|
||||
'''
|
||||
Translate TritonGPU module to HSACO code based on full details of gpu architecture.
|
||||
:param mod: a TritonGPU dialect module
|
||||
:return:
|
||||
- AMDGCN code
|
||||
- Path to HSACO object
|
||||
'''
|
||||
return _triton.translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# compiler
|
||||
# ------------------------------------------------------------------------------
|
||||
def get_kernel_name(src: str, pattern: str, llir: bool = False) -> str:
|
||||
'''
|
||||
Get kernel name from PTX code.
|
||||
This Kernel name is required when launching the kernel.
|
||||
'''
|
||||
# There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin.
|
||||
assert src
|
||||
for line in src.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith(pattern):
|
||||
if not llir:
|
||||
return line.split()[-1]
|
||||
return line.split("(")[0].split("@")[-1]
|
||||
|
||||
|
||||
def convert_type_repr(x):
|
||||
match = re.search(r'!tt\.ptr<(.*)>', x)
|
||||
if match is not None:
|
||||
return '*' + convert_type_repr(match.group(1))
|
||||
return x
|
||||
|
||||
|
||||
def make_hash(fn, arch, **kwargs):
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
configs = kwargs["configs"]
|
||||
signature = kwargs["signature"]
|
||||
constants = kwargs.get("constants", dict())
|
||||
num_warps = kwargs.get("num_warps", 4)
|
||||
num_stages = kwargs.get("num_stages", 3)
|
||||
debug = kwargs.get("debug", False)
|
||||
use_sme = kwargs.get("use_sme", 0)
|
||||
# Get unique key for the compiled code
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}-{use_sme}"
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
# and any following whitespace
|
||||
# - (public\s+)? : optionally match the keyword public and any following whitespace
|
||||
# - (@\w+) : match an @ symbol followed by one or more word characters
|
||||
# (letters, digits, or underscores), and capture it as group 1 (the function name)
|
||||
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
|
||||
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
|
||||
mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
|
||||
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
|
||||
prototype_pattern = {
|
||||
"ttir": mlir_prototype_pattern,
|
||||
"ttgir": mlir_prototype_pattern,
|
||||
"ptx": ptx_prototype_pattern,
|
||||
}
|
||||
|
||||
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
|
||||
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
|
||||
arg_type_pattern = {
|
||||
"ttir": mlir_arg_type_pattern,
|
||||
"ttgir": mlir_arg_type_pattern,
|
||||
"ptx": ptx_arg_type_pattern,
|
||||
}
|
||||
|
||||
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
|
||||
|
||||
|
||||
def _get_jsonable_constants(constants):
|
||||
def _is_jsonable(x):
|
||||
try:
|
||||
json.dumps(x)
|
||||
return True
|
||||
except (TypeError, OverflowError):
|
||||
return False
|
||||
serialized_constants = {}
|
||||
for constant in constants:
|
||||
if _is_jsonable(constants[constant]):
|
||||
serialized_constants[constant] = constants[constant]
|
||||
return serialized_constants
|
||||
|
||||
|
||||
def parse_mlir_module(path, context):
|
||||
module = _triton.ir.parse_mlir_module(path, context)
|
||||
# module takes ownership of the context
|
||||
module.context = context
|
||||
return module
|
||||
|
||||
|
||||
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
|
||||
|
||||
|
||||
# TODO: architecture descriptor class
|
||||
def _is_cuda(arch):
|
||||
return isinstance(arch, int)
|
||||
|
||||
def is_hip():
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("Triton requires PyTorch to be installed")
|
||||
return torch.version.hip is not None
|
||||
|
||||
from ..language.semantic import gpu_has_mfma
|
||||
|
||||
def get_architecture_descriptor(capability):
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("Triton requires PyTorch to be installed")
|
||||
if capability is None:
|
||||
if torch.version.hip is None:
|
||||
device = triton.runtime.jit.get_current_device()
|
||||
capability = triton.runtime.jit.get_device_capability(device)
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
else:
|
||||
capability = get_amdgpu_arch_fulldetails()
|
||||
return capability
|
||||
|
||||
|
||||
def add_rocm_stages(arch, extern_libs, stages):
|
||||
extern_libs.update(get_amdgcn_bitcode_paths(arch))
|
||||
|
||||
for key in list(extern_libs):
|
||||
if extern_libs[key] == '' or extern_libs[key] is None:
|
||||
extern_libs.pop(key)
|
||||
|
||||
gfx_arch_full_details = arch
|
||||
gfx_arch = os.environ.get('MI_GPU_ARCH', gfx_arch_full_details[1])
|
||||
if gfx_arch is None:
|
||||
raise RuntimeError('gfx_arch is None (not specified)')
|
||||
stages["amdgcn"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
|
||||
gfx_arch_full_details[0],
|
||||
gfx_arch_full_details[2]))
|
||||
|
||||
|
||||
def add_cuda_stages(arch, extern_libs, stages):
|
||||
stages["ptx"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, arch))
|
||||
stages["cubin"] = (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, arch))
|
||||
|
||||
|
||||
def add_iluvatar_stages(arch, extern_libs, stages):
|
||||
stages["cubin"] = (lambda path: Path(path).read_bytes(),
|
||||
lambda src: llir_to_cubin(src, arch))
|
||||
|
||||
|
||||
def compile(fn, **kwargs):
|
||||
if is_hip():
|
||||
capability = None
|
||||
else:
|
||||
capability = kwargs.get("cc", None)
|
||||
arch = get_architecture_descriptor(capability)
|
||||
is_cuda = _is_cuda(arch)
|
||||
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
|
||||
context = _triton.ir.context()
|
||||
asm = dict()
|
||||
constants = kwargs.get("constants", dict())
|
||||
num_warps = kwargs.get("num_warps", 4)
|
||||
num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else 2)
|
||||
extern_libs = kwargs.get("extern_libs", dict())
|
||||
use_sme = kwargs.get("use_sme", 0)
|
||||
if extern_libs is None:
|
||||
extern_libs = dict()
|
||||
debug = kwargs.get("debug", False)
|
||||
# build compilation stages
|
||||
stages = dict()
|
||||
stages["ast"] = (lambda path: fn, None)
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug), arch))
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size), num_stages, arch, use_sme))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch))
|
||||
if is_cuda:
|
||||
if is_corex():
|
||||
add_iluvatar_stages(arch, extern_libs, stages)
|
||||
else:
|
||||
add_cuda_stages(arch, extern_libs, stages)
|
||||
else:
|
||||
add_rocm_stages(arch, extern_libs, stages)
|
||||
|
||||
# find out the signature of the function
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
configs = kwargs.get("configs", None)
|
||||
signature = kwargs["signature"]
|
||||
if configs is None:
|
||||
configs = [instance_descriptor()]
|
||||
assert len(configs) == 1
|
||||
kwargs["configs"] = configs
|
||||
name = fn.__name__
|
||||
first_stage = 0
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
kwargs["signature"] = signature
|
||||
else:
|
||||
assert isinstance(fn, str)
|
||||
_, ir = os.path.basename(fn).split(".")
|
||||
src = Path(fn).read_text()
|
||||
import re
|
||||
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
|
||||
name, signature = match.group(1), match.group(2)
|
||||
types = re.findall(arg_type_pattern[ir], signature)
|
||||
if ir == 'ttgir':
|
||||
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
|
||||
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
|
||||
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
|
||||
num_warps = int(num_warps_matches[0])
|
||||
param_tys = [convert_type_repr(ty) for ty in types]
|
||||
signature = {k: v for k, v in enumerate(param_tys)}
|
||||
first_stage = list(stages.keys()).index(ir)
|
||||
|
||||
# cache manager
|
||||
so_path = make_stub(name, signature, constants)
|
||||
# create cache manager
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs))
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
name, ext = fn.__name__, "ast"
|
||||
else:
|
||||
name, ext = os.path.basename(fn).split(".")
|
||||
|
||||
# load metadata if any
|
||||
metadata = None
|
||||
metadata_filename = f"{name}.json"
|
||||
|
||||
# The group is addressed by the metadata
|
||||
metadata_group = fn_cache_manager.get_group(
|
||||
metadata_filename
|
||||
) or {}
|
||||
|
||||
metadata_path = metadata_group.get(metadata_filename)
|
||||
|
||||
if metadata_path is not None:
|
||||
with open(metadata_path) as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
metadata = {"num_warps": num_warps,
|
||||
"warp_size": warp_size,
|
||||
"num_stages": num_stages,
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug,
|
||||
"use_sme": use_sme}
|
||||
if ext == "ptx":
|
||||
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
|
||||
metadata["shared"] = kwargs["shared"]
|
||||
|
||||
first_stage = list(stages.keys()).index(ext)
|
||||
asm = dict()
|
||||
module = fn
|
||||
# run compilation pipeline and populate metadata
|
||||
for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]:
|
||||
ir_filename = f"{name}.{ir}"
|
||||
if ir == ext:
|
||||
next_module = parse(fn)
|
||||
else:
|
||||
path = metadata_group.get(ir_filename)
|
||||
if path is None:
|
||||
next_module = compile_kernel(module)
|
||||
if ir == "amdgcn":
|
||||
extra_file_name = f"{name}.hsaco_path"
|
||||
metadata_group[ir_filename] = fn_cache_manager.put(next_module[0], ir_filename)
|
||||
metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name)
|
||||
else:
|
||||
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
|
||||
fn_cache_manager.put(next_module, ir_filename)
|
||||
else:
|
||||
if ir == "amdgcn":
|
||||
extra_file_name = f"{name}.hsaco_path"
|
||||
hasco_path = metadata_group.get(extra_file_name)
|
||||
assert hasco_path is not None, "Expected to have hsaco_path in metadata when we have the amdgcn"
|
||||
next_module = (parse(path), parse(hasco_path))
|
||||
else:
|
||||
next_module = parse(path)
|
||||
|
||||
if ir == "llir":
|
||||
metadata["name"] = get_kernel_name(next_module, pattern="define iluvatar_kernel void", llir=True)
|
||||
if ir == "cubin":
|
||||
asm[ir] = next_module
|
||||
elif ir == "amdgcn":
|
||||
asm[ir] = str(next_module[0])
|
||||
else:
|
||||
asm[ir] = str(next_module)
|
||||
if ir == "llir" and "shared" not in metadata:
|
||||
metadata["shared"] = _triton.get_shared_memory_size(module)
|
||||
if ir == "ptx":
|
||||
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
|
||||
if ir == "amdgcn":
|
||||
metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
|
||||
asm["hsaco_path"] = next_module[1]
|
||||
module = next_module
|
||||
# write-back metadata, if it didn't come from the cache
|
||||
if metadata_path is None:
|
||||
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False)
|
||||
fn_cache_manager.put_group(metadata_filename, metadata_group)
|
||||
|
||||
# return handle to compiled kernel
|
||||
return CompiledKernel(fn, so_path, metadata, asm)
|
||||
|
||||
|
||||
class CompiledKernel:
|
||||
|
||||
# Hooks for external tools to monitor the execution of triton kernels
|
||||
launch_enter_hook = None
|
||||
launch_exit_hook = None
|
||||
|
||||
def __init__(self, fn, so_path, metadata, asm):
|
||||
# initialize launcher
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
self.fn = fn
|
||||
spec.loader.exec_module(mod)
|
||||
self.c_wrapper = getattr(mod, "launch")
|
||||
# initialize metadata
|
||||
self.shared = metadata["shared"]
|
||||
self.num_warps = metadata["num_warps"]
|
||||
self.warp_size = metadata["warp_size"]
|
||||
self.num_stages = metadata["num_stages"]
|
||||
self.constants = metadata["constants"]
|
||||
# initialize asm dict
|
||||
self.asm = asm
|
||||
# binaries are lazily initialized
|
||||
# because it involves doing runtime things
|
||||
# (e.g., checking amount of shared memory on current device)
|
||||
self.metadata = metadata
|
||||
self.cu_module = None
|
||||
self.cu_function = None
|
||||
|
||||
def _init_handles(self):
|
||||
if self.cu_module is not None:
|
||||
return
|
||||
device = triton.runtime.jit.get_current_device()
|
||||
bin_path = {
|
||||
driver.HIP: "hsaco_path",
|
||||
driver.CUDA: "cubin"
|
||||
}[driver.backend]
|
||||
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
|
||||
if self.shared > max_shared:
|
||||
raise OutOfResources(self.shared, max_shared, "shared memory")
|
||||
mod, func, n_regs, n_spills = driver.utils.load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
|
||||
|
||||
self.n_spills = n_spills
|
||||
self.n_regs = n_regs
|
||||
self.cu_module = mod
|
||||
self.cu_function = func
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name == 'c_wrapper':
|
||||
self._init_handles()
|
||||
return super().__getattribute__(name)
|
||||
|
||||
def __getitem__(self, grid):
|
||||
self._init_handles()
|
||||
|
||||
def runner(*args, stream=None):
|
||||
if stream is None:
|
||||
stream = triton.runtime.jit.get_cuda_stream()
|
||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
|
||||
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
|
||||
return runner
|
||||
|
||||
def get_sass(self, fun=None):
|
||||
if 'sass' in self.asm:
|
||||
return self.asm['sass']
|
||||
fd, path = tempfile.mkstemp()
|
||||
try:
|
||||
with open(fd, 'wb') as cubin:
|
||||
cubin.write(self.asm['cubin'])
|
||||
self.sass = extract(path, fun)
|
||||
finally:
|
||||
os.remove(path)
|
||||
self.asm['sass'] = self.sass
|
||||
return self.sass
|
||||
52
pkgs/triton/compiler/errors.py
Normal file
52
pkgs/triton/compiler/errors.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import ast
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class CompilationError(Exception):
|
||||
source_line_count_max_in_message = 12
|
||||
|
||||
def _format_message(self) -> str:
|
||||
node = self.node
|
||||
if self.src is None:
|
||||
source_excerpt = " <source unavailable>"
|
||||
else:
|
||||
source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
|
||||
if source_excerpt:
|
||||
source_excerpt.append(' ' * node.col_offset + '^')
|
||||
source_excerpt = '\n'.join(source_excerpt)
|
||||
else:
|
||||
source_excerpt = " <source empty>"
|
||||
|
||||
message = "at {}:{}:{}".format(node.lineno, node.col_offset, source_excerpt)
|
||||
if self.error_message:
|
||||
message += '\n' + self.error_message
|
||||
return message
|
||||
|
||||
def __init__(self, src: Optional[str], node: ast.AST, error_message: Union[str, None]):
|
||||
self.src = src
|
||||
self.node = node
|
||||
self.error_message = error_message
|
||||
self.message = self._format_message()
|
||||
|
||||
def set_source_code(self, src: Optional[str]):
|
||||
self.src = src
|
||||
self.message = self._format_message()
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
def __repr__(self):
|
||||
return "{}({!r})".format(type(self).__name__, self.message)
|
||||
|
||||
def __reduce__(self):
|
||||
# this is necessary to make CompilationError picklable
|
||||
return type(self), (self.src, self.node, self.error_message)
|
||||
|
||||
|
||||
class CompileTimeAssertionFailure(CompilationError):
|
||||
"""Specific exception for failed tests in `static_assert` invocations"""
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedLanguageConstruct(CompilationError):
|
||||
pass
|
||||
392
pkgs/triton/compiler/make_launcher.py
Normal file
392
pkgs/triton/compiler/make_launcher.py
Normal file
@@ -0,0 +1,392 @@
|
||||
import hashlib
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from ..common import _build
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..runtime.jit import version_key
|
||||
|
||||
|
||||
def is_hip():
|
||||
import torch
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
def is_corex():
|
||||
import torch
|
||||
return hasattr(torch, "corex") and torch.corex == True
|
||||
|
||||
|
||||
# ----- stub --------
|
||||
|
||||
|
||||
def make_so_cache_key(version_hash, signature, constants):
|
||||
# Get unique key for the compiled code
|
||||
signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
|
||||
key = f"{version_hash}-{''.join(signature.values())}{constants}"
|
||||
key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
return key
|
||||
|
||||
|
||||
def make_stub(name, signature, constants):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
cache_path = so_cache_manager.get_file(so_name)
|
||||
if cache_path is None:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
src = generate_launcher(constants, signature)
|
||||
src_path = os.path.join(tmpdir, "main.c")
|
||||
with open(src_path, "w") as f:
|
||||
f.write(src)
|
||||
so = _build(name, src_path, tmpdir)
|
||||
so_cache_manager.put(src, f"{name}.c", binary=False)
|
||||
with open(so, "rb") as f:
|
||||
return so_cache_manager.put(f.read(), so_name, binary=True)
|
||||
else:
|
||||
return cache_path
|
||||
|
||||
# ----- source code generation --------
|
||||
|
||||
|
||||
def ty_to_cpp(ty):
|
||||
if ty[0] == '*':
|
||||
return "hipDeviceptr_t" if is_hip() else "CUdeviceptr"
|
||||
return {
|
||||
"i1": "int32_t",
|
||||
"i8": "int8_t",
|
||||
"i16": "int16_t",
|
||||
"i32": "int32_t",
|
||||
"i64": "int64_t",
|
||||
"u32": "uint32_t",
|
||||
"u64": "uint64_t",
|
||||
"fp16": "float",
|
||||
"bf16": "float",
|
||||
"fp32": "float",
|
||||
"f32": "float",
|
||||
"fp64": "double",
|
||||
}[ty]
|
||||
|
||||
|
||||
def generate_launcher(constants, signature):
|
||||
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
|
||||
|
||||
def _extracted_type(ty):
|
||||
if ty[0] == '*':
|
||||
return "PyObject*"
|
||||
return {
|
||||
'i1': 'int32_t',
|
||||
'i32': 'int32_t',
|
||||
'i64': 'int64_t',
|
||||
'u32': 'uint32_t',
|
||||
'u64': 'uint64_t',
|
||||
'fp16': 'float',
|
||||
'bf16': 'float',
|
||||
'fp32': 'float',
|
||||
'f32': 'float',
|
||||
'fp64': 'double',
|
||||
}[ty]
|
||||
|
||||
def format_of(ty):
|
||||
return {
|
||||
"PyObject*": "O",
|
||||
"float": "f",
|
||||
"double": "d",
|
||||
"long": "l",
|
||||
"uint32_t": "I",
|
||||
"int32_t": "i",
|
||||
"uint64_t": "K",
|
||||
"int64_t": "L",
|
||||
}[ty]
|
||||
|
||||
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
|
||||
|
||||
# generate glue code
|
||||
if is_hip():
|
||||
src = f"""
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <Python.h>
|
||||
#include <stdio.h>
|
||||
|
||||
static inline void gpuAssert(hipError_t code, const char *file, int line)
|
||||
{{
|
||||
if (code != HIP_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [HIP]: ";
|
||||
const char* str = hipGetErrorString(code);
|
||||
char err[1024] = {{0}};
|
||||
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
}}
|
||||
|
||||
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
|
||||
static int getWarpSize(hipStream_t stream)
|
||||
{{
|
||||
int device_id = hipGetStreamDeviceId(stream);
|
||||
gpuAssert(device_id >= 0 ? hipSuccess : hipErrorInvalidDevice, __FILE__, __LINE__);
|
||||
hipDeviceProp_t prop;
|
||||
HIP_CHECK(hipGetDeviceProperties(&prop, device_id));
|
||||
return prop.warpSize;
|
||||
}}
|
||||
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||
if (gridX*gridY*gridZ > 0) {{
|
||||
int warp_size = getWarpSize(stream);
|
||||
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, num_warps * warp_size, 1, 1, shared_memory, stream, params, 0));
|
||||
}}
|
||||
}}
|
||||
|
||||
typedef struct _DevicePtrInfo {{
|
||||
hipDeviceptr_t dev_ptr;
|
||||
bool valid;
|
||||
}} DevicePtrInfo;
|
||||
|
||||
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
||||
DevicePtrInfo ptr_info;
|
||||
ptr_info.dev_ptr = 0;
|
||||
ptr_info.valid = true;
|
||||
|
||||
if (PyLong_Check(obj)) {{
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
if (obj == Py_None) {{
|
||||
// valid nullptr
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||
|
||||
if (ptr) {{
|
||||
PyObject *empty_tuple = PyTuple_New(0);
|
||||
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||
Py_DECREF(empty_tuple);
|
||||
Py_DECREF(ptr);
|
||||
|
||||
if (!PyLong_Check(ret)) {{
|
||||
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||
ptr_info.valid = false;
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
||||
|
||||
if (!ptr_info.dev_ptr)
|
||||
return ptr_info;
|
||||
|
||||
uint64_t dev_ptr;
|
||||
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
||||
if (status == hipErrorInvalidValue) {{
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
||||
ptr_info.valid = false;
|
||||
}}
|
||||
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
int num_warps;
|
||||
int shared_memory;
|
||||
PyObject *launch_enter_hook = NULL;
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
if (launch_enter_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_enter_hook, args);
|
||||
}}
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())});
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
}}
|
||||
if (PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
// return None
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {{
|
||||
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||
{{NULL, NULL, 0, NULL}} // sentinel
|
||||
}};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {{
|
||||
PyModuleDef_HEAD_INIT,
|
||||
\"__triton_launcher\",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
}};
|
||||
|
||||
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if(m == NULL) {{
|
||||
return NULL;
|
||||
}}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
else:
|
||||
warp_size = 64 if is_corex() else 32
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <stdbool.h>
|
||||
#include <Python.h>
|
||||
|
||||
static inline void gpuAssert(CUresult code, const char *file, int line)
|
||||
{{
|
||||
if (code != CUDA_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [CUDA]: ";
|
||||
const char* str;
|
||||
cuGetErrorString(code, &str);
|
||||
char err[1024] = {{0}};
|
||||
strcat(err, prefix);
|
||||
strcat(err, str);
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
}}
|
||||
|
||||
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||
if(gridX*gridY*gridZ > 0){{
|
||||
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, num_warps * {warp_size}, 1, 1, shared_memory, stream, params, 0));
|
||||
}}
|
||||
}}
|
||||
|
||||
typedef struct _DevicePtrInfo {{
|
||||
CUdeviceptr dev_ptr;
|
||||
bool valid;
|
||||
}} DevicePtrInfo;
|
||||
|
||||
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
||||
DevicePtrInfo ptr_info;
|
||||
ptr_info.dev_ptr = 0;
|
||||
ptr_info.valid = true;
|
||||
if (PyLong_Check(obj)) {{
|
||||
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
|
||||
return ptr_info;
|
||||
}}
|
||||
if (obj == Py_None) {{
|
||||
// valid nullptr
|
||||
return ptr_info;
|
||||
}}
|
||||
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||
if(ptr){{
|
||||
PyObject *empty_tuple = PyTuple_New(0);
|
||||
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||
Py_DECREF(empty_tuple);
|
||||
Py_DECREF(ptr);
|
||||
if (!PyLong_Check(ret)) {{
|
||||
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||
ptr_info.valid = false;
|
||||
return ptr_info;
|
||||
}}
|
||||
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
|
||||
if(!ptr_info.dev_ptr)
|
||||
return ptr_info;
|
||||
/*
|
||||
uint64_t dev_ptr;
|
||||
int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
||||
if (status == CUDA_ERROR_INVALID_VALUE) {{
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
||||
ptr_info.valid = false;
|
||||
}}
|
||||
ptr_info.dev_ptr = dev_ptr;
|
||||
*/
|
||||
Py_DECREF(ret); // Thanks ChatGPT!
|
||||
return ptr_info;
|
||||
}}
|
||||
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
int num_warps;
|
||||
int shared_memory;
|
||||
PyObject *launch_enter_hook = NULL;
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
if (launch_enter_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_enter_hook, args);
|
||||
}}
|
||||
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
|
||||
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
}}
|
||||
|
||||
if(PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
// return None
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {{
|
||||
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||
{{NULL, NULL, 0, NULL}} // sentinel
|
||||
}};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {{
|
||||
PyModuleDef_HEAD_INIT,
|
||||
\"__triton_launcher\",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
}};
|
||||
|
||||
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if(m == NULL) {{
|
||||
return NULL;
|
||||
}}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
return src
|
||||
Reference in New Issue
Block a user