Files
2025-08-05 19:02:46 +08:00

632 lines
23 KiB
Python

from __future__ import annotations
import functools
import hashlib
import json
import os
import re
import subprocess
import tempfile
from collections import namedtuple
from pathlib import Path
from typing import Any, Tuple
import triton
import triton._C.libtriton.triton as _triton
from ..runtime import driver
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager
from ..tools.disasm import extract
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
def is_corex():
    """Report whether the installed PyTorch module flags itself as CoreX.

    :return: the result of ``torch.corex == True`` when ``torch`` exposes a
        ``corex`` attribute, otherwise ``False``
    """
    import torch

    # A missing attribute falls back to False, which never compares equal to True.
    corex_flag = getattr(torch, "corex", False)
    return corex_flag == True  # noqa: E712 - equality (not identity) is intentional
# Warp size that compile() uses whenever the architecture descriptor is an
# integer (see _is_cuda): 64 if is_corex() is truthy at import time, else 32.
CUDA_DEFAULT_WARP_SIZE = 64 if is_corex() else 32
def inline_triton_ir(mod):
    '''
    Run the inliner pass over a module.
    :param mod: module to transform; ``mod.context`` is used to build the pass manager
    :return: the same ``mod`` object that was passed in
    '''
    pass_manager = _triton.ir.pass_manager(mod.context)
    pass_manager.enable_debug()
    pass_manager.add_inliner_pass()
    pass_manager.run(mod)
    return mod
def ttir_compute_capability_rewrite(mod, arch):
    '''
    Rewrite block-pointer loads/stores for targets that lack support for them.
    :param mod: module to transform; ``mod.context`` is used to build the pass manager
    :param arch: architecture descriptor; the rewrite pass is only scheduled
        when ``_is_cuda(arch)`` is true
    :return: the same ``mod`` object that was passed in
    '''
    # For hardware without support, we must rewrite all load/store
    # with block (tensor) pointers into tensors of pointers
    pass_manager = _triton.ir.pass_manager(mod.context)
    pass_manager.enable_debug()
    schedule_rewrite = _is_cuda(arch)
    if schedule_rewrite:
        pass_manager.add_rewrite_tensor_pointer_pass(arch)
    # The pass manager is run even when no pass was scheduled.
    pass_manager.run(mod)
    return mod
def optimize_ttir(mod, arch):
    '''
    Apply the Triton IR optimisation pipeline to a module.
    :param mod: module to optimise
    :param arch: architecture descriptor forwarded to ttir_compute_capability_rewrite
    :return: the module returned by the preliminary rewrites, after the
        pipeline below has run on it
    '''
    # Preliminary steps: inline first, then the capability-dependent rewrite.
    mod = ttir_compute_capability_rewrite(inline_triton_ir(mod), arch)

    pass_manager = _triton.ir.pass_manager(mod.context)
    pass_manager.enable_debug()
    pass_manager.add_inliner_pass()
    pass_manager.add_triton_combine_pass()
    pass_manager.add_canonicalizer_pass()
    pass_manager.add_cse_pass()
    pass_manager.add_licm_pass()
    pass_manager.add_symbol_dce_pass()
    pass_manager.run(mod)
    return mod
def ttir_to_ttgir(mod, num_warps, warpsize):
    '''
    Run the Triton -> TritonGPU conversion pass on a module.
    :param mod: module to convert; ``mod.context`` is used to build the pass manager
    :param num_warps: forwarded to the conversion pass
    :param warpsize: forwarded to the conversion pass
    :return: the same ``mod`` object that was passed in
    '''
    pass_manager = _triton.ir.pass_manager(mod.context)
    pass_manager.enable_debug()
    pass_manager.add_convert_triton_to_tritongpu_pass(num_warps, warpsize)
    pass_manager.run(mod)
    return mod
# Build and run the TritonGPU optimisation pipeline on `mod`, then return `mod`.
#   mod        - module to optimise; mod.context is used to create the pass manager
#   num_stages - forwarded to the pipeline pass
#   arch       - architecture descriptor; tested with _is_cuda() and forwarded to the
#                accelerate-matmul and matmul-smeload passes
#   use_sme    - forwarded to the accelerate-matmul pass on the _is_cuda() path
# NOTE(review): the order in which passes are added presumably matters - do not
# reorder these statements without checking the pass manager's semantics.
def optimize_ttgir(mod, num_stages, arch, use_sme = 0):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_tritongpu_coalesce_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
# Integer architecture descriptors get the matmul acceleration pass with the
# caller's arch and use_sme values.
if _is_cuda(arch):
pm.add_tritongpu_accelerate_matmul_pass(arch, use_sme)
# TODO change interface of accelerate_matmul_pass
# On HIP with MFMA available the same pass is added with a hard-coded 80 and
# without a use_sme argument.
if is_hip() and gpu_has_mfma():
pm.add_tritongpu_accelerate_matmul_pass(80)
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
# CoreX-only pass; the trailing comment below is the original author's note on
# which hardware supports SME.
if is_corex():
pm.add_tritongpu_matmul_smeload_pass(arch) #BI 70 MR > 71,only MR support sme
pm.add_tritongpu_remove_layout_conversions_pass()
# TODO enable this pass for AMD GPU when it is ready
if not is_hip():
pm.add_tritongpu_pipeline_pass(num_stages)
# The prefetch pass is skipped when is_corex() is true.
if not is_corex():
pm.add_tritongpu_prefetch_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
pm.run(mod)
return mod
def _add_external_libs(mod, libs):
for name, path in libs.items():
if len(name) == 0 or len(path) == 0:
return
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
def ttgir_to_llir(mod, extern_libs, arch):
    '''
    Translate a TritonGPU module to LLVM IR.
    :param mod: module to translate
    :param extern_libs: optional mapping of library name to path; attached to
        ``mod`` first when non-empty
    :param arch: architecture descriptor; integer descriptors are forwarded
        as-is, anything else is translated with ``0`` and the flag set to True
    :return: whatever ``_triton.translate_triton_gpu_to_llvmir`` returns
    '''
    if extern_libs:
        _add_external_libs(mod, extern_libs)
    # TODO: separate tritongpu_to_llvmir for different backends
    target, non_cuda = (arch, False) if _is_cuda(arch) else (0, True)
    return _triton.translate_triton_gpu_to_llvmir(mod, target, non_cuda)
# PTX translation
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
    '''
    Map a CUDA version string to the PTX version number to target.
    :param cuda_version: string of the form "<major>.<minor>"
    :return: base PTX version for the CUDA major release plus the minor number
        (12.x -> 80 + x, 11.x -> 70 + x, 10.x -> 63 + x)
    :raises RuntimeError: for any other major release
    '''
    assert isinstance(cuda_version, str)
    major, minor = (int(part) for part in cuda_version.split('.'))
    ptx_base_by_major = {12: 80, 11: 70, 10: 63}
    base = ptx_base_by_major.get(major)
    if base is None:
        raise RuntimeError("Triton only support CUDA 10.0 or higher")
    return base + minor
@functools.lru_cache()
def path_to_ptxas():
    '''
    Locate a usable ptxas binary and report its release version.

    Candidates are tried in order: the TRITON_PTXAS_PATH environment variable,
    then third_party/cuda/bin/ptxas under the parent of this file's directory.
    :return: tuple of (path to ptxas, "<major>.<minor>" release string)
    :raises RuntimeError: if no candidate exists and reports a release version
    '''
    package_root = os.path.join(os.path.dirname(__file__), os.pardir)
    env_candidate = os.environ.get("TRITON_PTXAS_PATH", "")
    bundled_candidate = os.path.join(package_root, "third_party", "cuda", "bin", "ptxas")

    for candidate in (env_candidate, bundled_candidate):
        if not (os.path.exists(candidate) and os.path.isfile(candidate)):
            continue
        output = subprocess.check_output([candidate, "--version"], stderr=subprocess.STDOUT)
        if output is None:
            continue
        release = re.search(r".*release (\d+\.\d+).*", output.decode("utf-8"), flags=re.MULTILINE)
        if release is not None:
            return candidate, release.group(1)
    raise RuntimeError("Cannot find ptxas")
def llir_to_cubin(mod: Any, arch: int):
    '''
    Compile an LLVM module straight to a cubin.
    :param mod: a LLVM module
    :param arch: architecture value forwarded to the translator
    :return: whatever ``_triton.translate_llvmir_to_cubin`` returns
    '''
    cubin = _triton.translate_llvmir_to_cubin(mod, arch)
    return cubin
def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
    '''
    Translate a module to PTX code.
    :param mod: module handed to ``_triton.translate_llvmir_to_ptx``
    :param arch: architecture value forwarded to the translator
    :param ptx_version: PTX version to target; when None it is derived from
        the release version of the ptxas found by path_to_ptxas()
    :return: PTX code
    '''
    if ptx_version is not None:
        return _triton.translate_llvmir_to_ptx(mod, arch, ptx_version)
    _, cuda_version = path_to_ptxas()
    derived_version = ptx_get_version(cuda_version)
    return _triton.translate_llvmir_to_ptx(mod, arch, derived_version)
def ptx_to_cubin(ptx: str, arch: int):
    '''
    Compile PTX code to a cubin using the ptxas found by path_to_ptxas().
    :param ptx: ptx code
    :param arch: architecture value forwarded to the compiler
    :return: whatever ``_triton.compile_ptx_to_cubin`` returns
    '''
    ptxas_path, _version = path_to_ptxas()
    return _triton.compile_ptx_to_cubin(ptx, ptxas_path, arch)
# AMDGCN translation
def get_amdgcn_bitcode_paths(arch):
    '''
    Collect the ROCm bitcode libraries that exist on disk for an architecture.
    :param arch: sequence whose element 1 is a string containing "gfx<id>"
    :return: dict mapping "library_<n>" to the path of each bitcode file found
        under third_party/rocm/lib/bitcode/, numbered consecutively from 0 in
        lookup order (the arch-agnostic libraries first, then the ISA-specific one)
    '''
    arch_agnostic_libs = (
        "opencl.bc",
        "ocml.bc",
        "ockl.bc",
        "oclc_finite_only_off.bc",
        "oclc_daz_opt_off.bc",
        "oclc_correctly_rounded_sqrt_on.bc",
        "oclc_unsafe_math_off.bc",
        "oclc_wavefrontsize64_on.bc",
        "oclc_abi_version_400.bc",
    )

    isa_id = re.search('gfx(\\w+)', arch[1]).group(1).strip()
    isa_specific_lib = 'oclc_isa_version_' + isa_id + ".bc"
    bitcode_dir = os.path.join(Path(__file__).parent.parent.resolve(), "third_party/rocm/lib/bitcode/")

    found = {}
    for lib_name in (*arch_agnostic_libs, isa_specific_lib):
        candidate = bitcode_dir + lib_name
        if os.path.exists(candidate):
            # The next free index always equals the number of libraries found so far.
            found['library_' + str(len(found))] = candidate
    return found
def get_amdgpu_arch_fulldetails():
    """
    Get the full AMDGPU ISA details needed for compiling,
    i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack-

    :return: ``[arch_triple, arch_name, arch_features, warpsize]`` on success,
        or ``None`` if the architecture information cannot be queried or parsed.
        ``arch_features`` is "" unless the name/feature string has exactly
        three ':'-separated parts, in which case it is "+<second>,-<third>".
    """
    try:
        # TODO: package rocm.cc with Triton
        arch_info = _triton.get_arch_info()
        warpsize = _triton.get_warp_size()
        gfx_arch_details = re.search('amd.*', arch_info).group(0).strip().split('--')
        arch_triple = gfx_arch_details[0]
        arch_name_features = gfx_arch_details[1].split(':')
        arch_name = arch_name_features[0]
        arch_features = ""
        if len(arch_name_features) == 3:
            enabled = re.search('\\w+', arch_name_features[1]).group(0)
            disabled = re.search('\\w+', arch_name_features[2]).group(0)
            arch_features = "+" + enabled + ",-" + disabled
        return [arch_triple, arch_name, arch_features, warpsize]
    except Exception:
        # Best-effort probe: any ordinary failure means "no AMD GPU details".
        # Deliberately not BaseException, so KeyboardInterrupt and SystemExit
        # still propagate to the caller.
        return None
def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]:
'''
Translate TritonGPU module to HSACO code based on full details of gpu architecture.
:param mod: a TritonGPU dialect module
:return:
- AMDGCN code
- Path to HSACO object
'''
return _triton.translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
# ------------------------------------------------------------------------------
# compiler
# ------------------------------------------------------------------------------
def get_kernel_name(src: str, pattern: str, llir: bool = False) -> str:
    '''
    Extract the kernel name from generated code.
    This name is required when launching the kernel.
    :param src: generated code, searched line by line
    :param pattern: prefix identifying the line that carries the name
    :param llir: when True the name is the text after the last "@" and before
        the first "(" of the line; otherwise it is the line's last
        whitespace-separated token
    :return: the name from the first matching line, or None if no line matches
    '''
    # There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin.
    assert src
    for raw_line in src.split('\n'):
        stripped = raw_line.strip()
        if not stripped.startswith(pattern):
            continue
        if llir:
            before_args = stripped.partition("(")[0]
            return before_args.rpartition("@")[2]
        return stripped.split()[-1]
    return None
def convert_type_repr(x):
    '''
    Convert an MLIR type string to the signature notation used by compile().
    Each "!tt.ptr<...>" wrapper found in the string becomes a leading "*".
    :param x: type string
    :return: ``x`` unchanged when it contains no "!tt.ptr<...>", otherwise one
        "*" per nesting level followed by the innermost captured text
    '''
    stars = ''
    while True:
        pointer = re.search(r'!tt\.ptr<(.*)>', x)
        if pointer is None:
            return stars + x
        stars += '*'
        x = pointer.group(1)
def make_hash(fn, arch, **kwargs):
    '''
    Compute the MD5 cache key for a compilation request.
    :param fn: a JITFunction, or a path to an IR file
    :param arch: architecture descriptor, included in the key for JITFunctions
    :param kwargs: compile options; "configs" and "signature" are required for
        a JITFunction, the others fall back to defaults
    :return: hex digest string
    '''
    if not isinstance(fn, triton.runtime.JITFunction):
        # IR file: the key covers the file contents plus the Triton version key.
        assert isinstance(fn, str)
        file_key = Path(fn).read_text() + triton.runtime.jit.version_key()
        return hashlib.md5(file_key.encode("utf-8")).hexdigest()

    configs = kwargs["configs"]
    signature = kwargs["signature"]
    constants = kwargs.get("constants", dict())
    num_warps = kwargs.get("num_warps", 4)
    num_stages = kwargs.get("num_stages", 3)
    debug = kwargs.get("debug", False)
    use_sme = kwargs.get("use_sme", 0)

    # Get unique key for the compiled code
    def _conf_key(conf):
        return (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))

    configs_key = [_conf_key(conf) for conf in configs]
    key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}-{use_sme}"
    return hashlib.md5(key.encode("utf-8")).hexdigest()
# Regular expressions used by compile() to recover a kernel's name and argument
# types when it is given a path to an IR file instead of a JITFunction.
#
# mlir_prototype_pattern, piece by piece:
# - ^\s*tt\.func\s+ : match the start of a line (compile() searches with re.MULTILINE),
# any leading whitespace, the keyword tt.func, and any following whitespace
# - (?:public\s+)? : optionally match the keyword public and any following whitespace
# (non-capturing, so it does not shift the group numbers)
# - (@\w+) : match an @ symbol followed by one or more word characters
# (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
# - \s*\{\s*$ : the opening brace of the function body must be the last token on the line
mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
# PTX prototype: group 1 is the entry/function name, group 2 the raw parameter list.
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
# Prototype regex per input file extension (the text after the "." in the file name).
prototype_pattern = {
"ttir": mlir_prototype_pattern,
"ttgir": mlir_prototype_pattern,
"ptx": ptx_prototype_pattern,
}
# Argument-type regexes, applied with re.findall to the argument list captured above;
# the single capture group is the type of one argument.
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
# Argument-type regex per input file extension.
arg_type_pattern = {
"ttir": mlir_arg_type_pattern,
"ttgir": mlir_arg_type_pattern,
"ptx": ptx_arg_type_pattern,
}
# Captures the integer value of the "triton_gpu.num-warps" attribute in TTGIR text.
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
def _get_jsonable_constants(constants):
def _is_jsonable(x):
try:
json.dumps(x)
return True
except (TypeError, OverflowError):
return False
serialized_constants = {}
for constant in constants:
if _is_jsonable(constants[constant]):
serialized_constants[constant] = constants[constant]
return serialized_constants
def parse_mlir_module(path, context):
    '''
    Parse an MLIR file into a module bound to the given context.
    :param path: path of the MLIR file
    :param context: context handed to the parser
    :return: the parsed module, with its ``context`` attribute set to ``context``
    '''
    parsed = _triton.ir.parse_mlir_module(path, context)
    # module takes ownership of the context
    parsed.context = context
    return parsed
# Specialisation descriptor with two fields, divisible_by_16 and equal_to_1;
# make_hash() sorts both when building the cache key.
# NOTE(review): the two default set() objects are created once, here, and are shared
# by every instance built without explicit arguments - never mutate them in place.
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
# TODO: architecture descriptor class
def _is_cuda(arch):
return isinstance(arch, int)
def is_hip():
    '''
    Tell whether the installed PyTorch reports a HIP version.
    :return: True when ``torch.version.hip`` is not None
    :raises ImportError: if PyTorch cannot be imported
    '''
    try:
        import torch
    except ImportError:
        raise ImportError("Triton requires PyTorch to be installed")
    hip_version = torch.version.hip
    return hip_version is not None
from ..language.semantic import gpu_has_mfma
def get_architecture_descriptor(capability):
    '''
    Resolve the architecture descriptor to compile for.
    :param capability: explicit descriptor, or None to detect one
    :return: ``capability`` unchanged when it is not None; otherwise
        ``major * 10 + minor`` of the current device when ``torch.version.hip``
        is None, else the result of get_amdgpu_arch_fulldetails()
    :raises ImportError: if PyTorch cannot be imported
    '''
    try:
        import torch
    except ImportError:
        raise ImportError("Triton requires PyTorch to be installed")

    if capability is not None:
        return capability

    if torch.version.hip is not None:
        return get_amdgpu_arch_fulldetails()

    device = triton.runtime.jit.get_current_device()
    device_capability = triton.runtime.jit.get_device_capability(device)
    return device_capability[0] * 10 + device_capability[1]
def add_rocm_stages(arch, extern_libs, stages):
    '''
    Register the "amdgcn" compilation stage and prepare the external libraries.
    :param arch: sequence indexed as [0] triple, [1] gfx arch name, [2] features
    :param extern_libs: dict of library name to path; updated in place with the
        ROCm bitcode paths, then stripped of entries whose value is '' or None
    :param stages: stage dict to receive the "amdgcn" (parse, compile) pair
    :raises RuntimeError: if neither MI_GPU_ARCH nor ``arch[1]`` provides a gfx arch
    '''
    extern_libs.update(get_amdgcn_bitcode_paths(arch))

    blank_keys = [key for key in list(extern_libs)
                  if extern_libs[key] == '' or extern_libs[key] is None]
    for key in blank_keys:
        extern_libs.pop(key)

    # The MI_GPU_ARCH environment variable overrides the detected arch name.
    gfx_arch = os.environ.get('MI_GPU_ARCH', arch[1])
    if gfx_arch is None:
        raise RuntimeError('gfx_arch is None (not specified)')

    def _read_amdgcn(path):
        return Path(path).read_text()

    def _compile_amdgcn(src):
        return llir_to_amdgcn_and_hsaco(src, gfx_arch, arch[0], arch[2])

    stages["amdgcn"] = (_read_amdgcn, _compile_amdgcn)
def add_cuda_stages(arch, extern_libs, stages):
    '''
    Register the "ptx" and "cubin" compilation stages, in that order.
    :param arch: architecture value captured by both compile callbacks
    :param extern_libs: unused here; kept for a uniform add_*_stages signature
    :param stages: stage dict to receive the (parse, compile) pairs
    '''
    def _read_ptx(path):
        return Path(path).read_text()

    def _compile_ptx(src):
        return llir_to_ptx(src, arch)

    def _read_cubin(path):
        return Path(path).read_bytes()

    def _compile_cubin(src):
        return ptx_to_cubin(src, arch)

    stages["ptx"] = (_read_ptx, _compile_ptx)
    stages["cubin"] = (_read_cubin, _compile_cubin)
def add_iluvatar_stages(arch, extern_libs, stages):
    '''
    Register the "cubin" compilation stage that compiles LLVM IR directly.
    :param arch: architecture value captured by the compile callback
    :param extern_libs: unused here; kept for a uniform add_*_stages signature
    :param stages: stage dict to receive the "cubin" (parse, compile) pair
    '''
    def _read_cubin(path):
        return Path(path).read_bytes()

    def _compile_cubin(src):
        return llir_to_cubin(src, arch)

    stages["cubin"] = (_read_cubin, _compile_cubin)
# Compile `fn` through the staged pipeline and return a CompiledKernel.
#   fn     - either a triton.runtime.JITFunction, or a path (str) to an IR file
#            whose extension names the stage to start from
#   kwargs - options read below: "cc", "constants", "num_warps", "num_stages",
#            "extern_libs", "use_sme", "debug", "configs", "signature", "shared"
# Intermediate artefacts are stored through the cache manager keyed by make_hash(),
# and reused from there when present.
def compile(fn, **kwargs):
# On HIP any caller-supplied "cc" is ignored and the architecture is always detected.
if is_hip():
capability = None
else:
capability = kwargs.get("cc", None)
arch = get_architecture_descriptor(capability)
is_cuda = _is_cuda(arch)
# Non-integer descriptors carry their own warp size at index 3.
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
context = _triton.ir.context()
asm = dict()
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", 4)
# Default stage count: 3 for integer architectures >= 75, otherwise 2.
num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else 2)
extern_libs = kwargs.get("extern_libs", dict())
use_sme = kwargs.get("use_sme", 0)
# An explicit extern_libs=None is treated like "no libraries".
if extern_libs is None:
extern_libs = dict()
debug = kwargs.get("debug", False)
# build compilation stages
# Each entry maps a stage name to (parse-from-file, compile-from-previous-stage).
# The lambdas look up `signature`, `configs`, `num_warps`, ... only when called,
# so they see the values assigned further down in this function.
stages = dict()
stages["ast"] = (lambda path: fn, None)
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug), arch))
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size), num_stages, arch, use_sme))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch))
# Backend-specific tail of the pipeline: CoreX, plain CUDA, or ROCm.
if is_cuda:
if is_corex():
add_iluvatar_stages(arch, extern_libs, stages)
else:
add_cuda_stages(arch, extern_libs, stages)
else:
add_rocm_stages(arch, extern_libs, stages)
# find out the signature of the function
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs.get("configs", None)
signature = kwargs["signature"]
if configs is None:
configs = [instance_descriptor()]
assert len(configs) == 1
# The normalised values are written back into kwargs because make_hash()
# reads them from there.
kwargs["configs"] = configs
name = fn.__name__
first_stage = 0
# A comma-separated signature string becomes {position: type}.
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
kwargs["signature"] = signature
else:
assert isinstance(fn, str)
# The file name must contain exactly one "."; the part after it names the IR.
_, ir = os.path.basename(fn).split(".")
src = Path(fn).read_text()
# NOTE(review): `re` is already imported at module level; this local import looks redundant.
import re
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
name, signature = match.group(1), match.group(2)
types = re.findall(arg_type_pattern[ir], signature)
if ir == 'ttgir':
# For TTGIR input the warp count recorded in the file wins, and must agree
# with any num_warps the caller passed explicitly.
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
num_warps = int(num_warps_matches[0])
param_tys = [convert_type_repr(ty) for ty in types]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = list(stages.keys()).index(ir)
# cache manager
# Build the launcher stub; so_path is later handed to CompiledKernel.
so_path = make_stub(name, signature, constants)
# create cache manager
fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs))
# determine name and extension type of provided function
if isinstance(fn, triton.runtime.JITFunction):
name, ext = fn.__name__, "ast"
else:
name, ext = os.path.basename(fn).split(".")
# load metadata if any
metadata = None
metadata_filename = f"{name}.json"
# The group is addressed by the metadata
metadata_group = fn_cache_manager.get_group(
metadata_filename
) or {}
metadata_path = metadata_group.get(metadata_filename)
if metadata_path is not None:
with open(metadata_path) as f:
metadata = json.load(f)
else:
# No cached metadata: start from the options resolved above.
metadata = {"num_warps": num_warps,
"warp_size": warp_size,
"num_stages": num_stages,
"constants": _get_jsonable_constants(constants),
"debug": debug,
"use_sme": use_sme}
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
metadata["shared"] = kwargs["shared"]
# Recomputed from the extension; this supersedes the first_stage values set above.
first_stage = list(stages.keys()).index(ext)
asm = dict()
module = fn
# run compilation pipeline and populate metadata
for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]:
ir_filename = f"{name}.{ir}"
# The starting stage is parsed straight from `fn`; later stages are either
# compiled from the previous module or re-read from the cache.
if ir == ext:
next_module = parse(fn)
else:
path = metadata_group.get(ir_filename)
if path is None:
next_module = compile_kernel(module)
if ir == "amdgcn":
# The amdgcn stage yields a pair; both elements are cached under separate names.
extra_file_name = f"{name}.hsaco_path"
metadata_group[ir_filename] = fn_cache_manager.put(next_module[0], ir_filename)
metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name)
else:
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
# NOTE(review): same module stored again under the same file name and the result
# is discarded - confirm whether this second put is redundant.
fn_cache_manager.put(next_module, ir_filename)
else:
if ir == "amdgcn":
extra_file_name = f"{name}.hsaco_path"
hasco_path = metadata_group.get(extra_file_name)
assert hasco_path is not None, "Expected to have hsaco_path in metadata when we have the amdgcn"
next_module = (parse(path), parse(hasco_path))
else:
next_module = parse(path)
# For LLVM IR the kernel name is taken from the line starting with
# "define iluvatar_kernel void".
if ir == "llir":
metadata["name"] = get_kernel_name(next_module, pattern="define iluvatar_kernel void", llir=True)
# cubin is kept as-is; every other stage is stored in asm as a string.
if ir == "cubin":
asm[ir] = next_module
elif ir == "amdgcn":
asm[ir] = str(next_module[0])
else:
asm[ir] = str(next_module)
# `module` is only advanced at the end of the iteration, so at this point it is
# still the previous stage's output.
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
if ir == "amdgcn":
metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
asm["hsaco_path"] = next_module[1]
module = next_module
# write-back metadata, if it didn't come from the cache
if metadata_path is None:
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False)
fn_cache_manager.put_group(metadata_filename, metadata_group)
# return handle to compiled kernel
return CompiledKernel(fn, so_path, metadata, asm)
class CompiledKernel:
    """Handle to a compiled kernel: launcher stub, metadata and per-stage assembly.

    Indexing an instance with a 3-element grid returns a callable that launches
    the kernel through the launcher's ``launch`` entry point.
    """

    # Hooks for external tools to monitor the execution of triton kernels
    launch_enter_hook = None
    launch_exit_hook = None

    def __init__(self, fn, so_path, metadata, asm):
        """Load the launcher from ``so_path`` and record metadata and assembly.

        :param fn: the function or IR path this kernel was compiled from
        :param so_path: path of the launcher module exposing ``launch``
        :param metadata: dict providing at least "shared", "num_warps",
            "warp_size", "num_stages" and "constants"
        :param asm: dict of stage name to generated code
        """
        # initialize launcher
        import importlib.util
        launcher_spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
        launcher = importlib.util.module_from_spec(launcher_spec)
        self.fn = fn
        launcher_spec.loader.exec_module(launcher)
        self.c_wrapper = launcher.launch
        # initialize metadata
        for field in ("shared", "num_warps", "warp_size", "num_stages", "constants"):
            setattr(self, field, metadata[field])
        # initialize asm dict
        self.asm = asm
        # binaries are lazily initialized
        # because it involves doing runtime things
        # (e.g., checking amount of shared memory on current device)
        self.metadata = metadata
        self.cu_module = None
        self.cu_function = None

    def _init_handles(self):
        """Load the binary on the current device, once.

        :raises OutOfResources: if the kernel needs more shared memory than
            the device property "max_shared_mem" allows
        """
        if self.cu_module is not None:
            return
        device = triton.runtime.jit.get_current_device()
        asm_key_by_backend = {
            driver.HIP: "hsaco_path",
            driver.CUDA: "cubin",
        }
        bin_path = asm_key_by_backend[driver.backend]
        max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
        if self.shared > max_shared:
            raise OutOfResources(self.shared, max_shared, "shared memory")
        loaded_module, loaded_function, n_regs, n_spills = driver.utils.load_binary(
            self.metadata["name"], self.asm[bin_path], self.shared, device)
        self.n_spills = n_spills
        self.n_regs = n_regs
        self.cu_module = loaded_module
        self.cu_function = loaded_function

    def __getattribute__(self, name):
        # Reading c_wrapper is what precedes a launch, so make sure the
        # binary handles exist before handing it out.
        if name == 'c_wrapper':
            self._init_handles()
        return super().__getattribute__(name)

    def __getitem__(self, grid):
        """Return a launcher bound to ``grid`` (indexed as grid[0..2])."""
        self._init_handles()

        def runner(*args, stream=None):
            launch_stream = stream
            if launch_stream is None:
                launch_stream = triton.runtime.jit.get_cuda_stream()
            self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, launch_stream,
                           self.cu_function, CompiledKernel.launch_enter_hook,
                           CompiledKernel.launch_exit_hook, self, *args)

        return runner

    def get_sass(self, fun=None):
        """Return the disassembly of the cubin, computing and caching it on first use.

        :param fun: forwarded to ``extract``
        :return: the value stored under ``asm['sass']``
        """
        if 'sass' not in self.asm:
            fd, path = tempfile.mkstemp()
            try:
                with open(fd, 'wb') as cubin_file:
                    cubin_file.write(self.asm['cubin'])
                self.sass = extract(path, fun)
            finally:
                # The temporary cubin is removed whether or not extraction succeeded.
                os.remove(path)
            self.asm['sass'] = self.sass
        return self.asm['sass']