First commit
This commit is contained in:
122
pkgs/triton/tools/aot.py
Normal file
122
pkgs/triton/tools/aot.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import triton._C.libtriton.triton as libtriton
|
||||
import triton.compiler.compiler as tc
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# valid source and target formats
|
||||
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx', 'amdgcn']
|
||||
|
||||
# set up the argument parser
|
||||
# TODO: conditional requirements
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('src', help="Source file to compile")
|
||||
parser.add_argument('--target', required=True,
|
||||
help="Target format, one of: " + ', '.join(VALID_FORMATS))
|
||||
parser.add_argument('--sm', type=int, help="Compute capability to compile for")
|
||||
parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")
|
||||
parser.add_argument('--gfx', type=str, help="AMDGPU target to compile for")
|
||||
parser.add_argument('--triple', type=str, help="target triple, for example: amdgcn-amd-amdhsa")
|
||||
parser.add_argument('--features', type=str, help="target features, for example: +sramecc,-xnack")
|
||||
parser.add_argument('--num_warps', type=int, help="number of warps to compile ttgir for")
|
||||
|
||||
# parse the args
|
||||
args = parser.parse_args()
|
||||
|
||||
# TODO: clean-up and re-use triton.compiler primitive functions
|
||||
# check for validity of format arguments
|
||||
if args.target not in VALID_FORMATS:
|
||||
print("Invalid target format: " + args.target)
|
||||
sys.exit(0)
|
||||
|
||||
# parse source file to MLIR module
|
||||
context = libtriton.ir.context()
|
||||
module = libtriton.ir.parse_mlir_module(args.src, context)
|
||||
module.context = context
|
||||
|
||||
# optimizer triton-ir
|
||||
module = tc.optimize_ttir(module, arch=args.sm)
|
||||
if args.target == 'triton-ir':
|
||||
print(module.str())
|
||||
sys.exit(0)
|
||||
|
||||
if not args.num_warps:
|
||||
args.num_warps = 4
|
||||
|
||||
# llvm-ir -> amdgcn
|
||||
if args.target == 'amdgcn':
|
||||
# auto detect available architecture and features
|
||||
# if nothing detected, set with default values
|
||||
arch_details = tc.get_amdgpu_arch_fulldetails()
|
||||
if not arch_details:
|
||||
arch_name = ""
|
||||
arch_triple = "amdgcn-amd-amdhsa"
|
||||
arch_features = ""
|
||||
arch_warpsize = 64
|
||||
else:
|
||||
arch_triple, arch_name, arch_features, arch_warpsize = arch_details
|
||||
|
||||
# stop processing if architecture name is not automatically detected and is not set manually
|
||||
if not args.gfx and not arch_name:
|
||||
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
|
||||
|
||||
# rewrite default and automatically detected values with manually provided data
|
||||
if args.gfx:
|
||||
arch_name = args.gfx
|
||||
if args.triple:
|
||||
arch_triple = args.triple
|
||||
if args.features:
|
||||
arch_features = args.features
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
# use compute_capability == 80
|
||||
module = tc.ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=arch_warpsize) # num_stages=3, compute_capability=80)
|
||||
module = tc.optimize_ttgir(module, num_stages=3, arch=args.gfx)
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
# use compute_capability == 80
|
||||
module = tc.ttgir_to_llir(module, extern_libs=None, arch=args.gfx)
|
||||
# llvm-ir -> amdgcn asm, hsaco binary
|
||||
module, hsaco_path = tc.llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)
|
||||
|
||||
print(hsaco_path)
|
||||
print(module)
|
||||
sys.exit(0)
|
||||
|
||||
# set arch depending on platform
|
||||
if args.gfx:
|
||||
arch = args.gfx
|
||||
elif args.sm:
|
||||
arch = args.sm
|
||||
else:
|
||||
raise argparse.ArgumentError(None, "Must specify --sm or --gfx for ttgir compilation")
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
module = tc.ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=tc.CUDA_DEFAULT_WARP_SIZE)
|
||||
module = tc.optimize_ttgir(module, num_stages=3, arch=arch)
|
||||
if args.target == 'triton-gpu-ir':
|
||||
print(module.str())
|
||||
sys.exit(0)
|
||||
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
module = tc.ttgir_to_llir(module, extern_libs=None, arch=arch)
|
||||
if args.target == 'llvm-ir':
|
||||
print(module)
|
||||
sys.exit(0)
|
||||
|
||||
# llvm-ir -> ptx
|
||||
if args.target == 'ptx':
|
||||
if not args.sm:
|
||||
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
|
||||
if not args.ptx_version:
|
||||
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
|
||||
module = tc.llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version)
|
||||
|
||||
# llvm-ir -> amdgcn
|
||||
if args.target == 'amdgcn':
|
||||
if not args.gfx:
|
||||
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
|
||||
module, hsaco_path = tc.llir_to_amdgcn_and_hsaco(module, args.gfx)
|
||||
|
||||
print(module)
|
||||
Reference in New Issue
Block a user