174 lines
5.7 KiB
Python
174 lines
5.7 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import ast
|
|
import copy
|
|
import functools
|
|
import linecache
|
|
import sys
|
|
from typing import Any, Dict, List
|
|
|
|
import triton
|
|
|
|
|
|
class _ForLoopUnroller(ast.NodeTransformer):
    """Rewrite one copy of a loop body for a single, fixed iteration.

    Two substitutions are performed on the visited tree:

    * every ``Name`` whose identifier equals ``target`` (the loop variable)
      is replaced by the iteration index ``loop_iter``;
    * every ``value[target]`` where ``value`` is one of ``inline_variables``
      is replaced by the flat name ``f"{value}{loop_iter}"``.

    Args:
        target: identifier of the loop variable being eliminated.
        inline_variables: names of the array-like variables that are
            expanded into numbered scalars.
        loop_iter: index of the iteration this copy of the body stands for.
    """

    def __init__(self, target, inline_variables, loop_iter):
        self.loop_iter = loop_iter
        self.target = target
        self.inline_variables = inline_variables

    def visit_Name(self, node):
        """Replace the loop variable by the iteration index."""
        if node.id != self.target:
            return node
        # The index is emitted as a `Name` whose identifier is the index
        # text; this is only meaningful once the tree is turned back into
        # source (e.g. with `ast.unparse`), where it prints as a literal.
        replacement = ast.Name(str(self.loop_iter), ctx=node.ctx)
        return ast.copy_location(replacement, node)

    def visit_Subscript(self, node):
        """Replace ``inline_var[target]`` by ``inline_var<loop_iter>``."""
        # Pattern-matching `value[slice]`
        if (
            isinstance(node.slice, ast.Name)
            and node.slice.id == self.target
            and isinstance(node.value, ast.Name)
            and node.value.id in self.inline_variables
        ):
            replacement = ast.Name(f"{node.value.id}{self.loop_iter}", ctx=node.ctx)
            return ast.copy_location(replacement, node)
        # Any other subscript may still mention the loop variable (e.g.
        # `x[i + 1]` or `a[b[i]]`), so its children must be visited too.
        # Returning `node` directly would leave those references behind.
        return self.generic_visit(node)
|
|
|
|
|
|
class _VisitorUnrollKernel(ast.NodeTransformer):
    """AST pass that specializes a variadic kernel to exactly ``N`` inputs.

    The pass:

    * drops bare declarations of the form ``name: "VAR_ARGS_ARRAY"`` and
      remembers ``name`` as an inline variable;
    * expands parameters annotated ``"VAR_ARGS_ARRAY"`` and the ``*vararg``
      parameter into ``N`` numbered parameters (``name0`` .. ``name{N-1}``);
    * unrolls ``for i in range(len(<inline variable>))`` loops into ``N``
      copies of their body.

    Attributes are plain instance state:
        inline_variables: names collected so far that must be expanded.
        N: number of copies to generate.
    """

    def __init__(self, N):
        self.inline_variables = set()
        self.N = N

    def visit_AnnAssign(self, node):
        """Remove ``var_name: "VAR_ARGS_ARRAY"`` declarations."""
        # Pattern-matching:
        # var_name: "VAR_ARGS_ARRAY"
        if (
            node.value is None
            and node.simple == 1
            and isinstance(node.target, ast.Name)
            and isinstance(node.annotation, ast.Constant)
            and node.annotation.value == "VAR_ARGS_ARRAY"
        ):
            self.inline_variables.add(node.target.id)
            # An empty list tells `NodeTransformer` to delete the statement.
            return []
        return self.generic_visit(node)

    def visit_arguments(self, node):
        """Expand variadic parameters into ``N`` numbered parameters.

        NOTE(review): keyword-only parameters are appended to the positional
        list and `defaults` / `kw_defaults` are not adjusted here - confirm
        that kernels passed through this pass do not rely on default values.
        """
        # Replace `args` annotated with `VAR_ARGS_ARRAY`
        new_args = []
        for arg in node.args:
            if (
                arg.annotation is not None
                and isinstance(arg.annotation, ast.Constant)
                and arg.annotation.value == "VAR_ARGS_ARRAY"
            ):
                self.inline_variables.add(arg.arg)
                new_args += [ast.arg(f"{arg.arg}{i}") for i in range(self.N)]
                continue
            new_args.append(arg)
        if node.vararg is not None:
            self.inline_variables.add(node.vararg.arg)
            new_args += [ast.arg(f"{node.vararg.arg}{i}") for i in range(self.N)]
            node.vararg = None
            new_args += node.kwonlyargs
            node.kwonlyargs = []
        node.args = new_args
        return node

    def _is_unrollable(self, node):
        """Return True for ``for <name> in range(len(<inline variable>))``."""
        loop_iter = node.iter
        if not (
            isinstance(node.target, ast.Name)
            and isinstance(loop_iter, ast.Call)
            # `func` can be an `ast.Attribute` (e.g. `tl.static_range`),
            # which has no `.id`, so the type must be checked first.
            and isinstance(loop_iter.func, ast.Name)
            and loop_iter.func.id == "range"
            and len(loop_iter.args) == 1
        ):
            return False
        len_call = loop_iter.args[0]
        if not (
            isinstance(len_call, ast.Call)
            and isinstance(len_call.func, ast.Name)
            and len_call.func.id == "len"
            and len(len_call.args) == 1
        ):
            return False
        measured = len_call.args[0]
        return isinstance(measured, ast.Name) and measured.id in self.inline_variables

    def visit_For(self, node):
        """Unroll loops over an inline variable; recurse into all others."""
        if not self._is_unrollable(node):
            # `generic_visit` splices list results (nested unrolled loops,
            # removed declarations) into the statement lists correctly.
            return self.generic_visit(node)
        # We know we have to modify this loop
        new_nodes = []
        for i in range(self.N):
            unroller = _ForLoopUnroller(
                target=node.target.id,
                inline_variables=self.inline_variables,
                loop_iter=i,
            )
            for body in node.body:
                # Each iteration needs its own copy: the unroller mutates.
                body = copy.deepcopy(body)
                new_node = ast.fix_missing_locations(unroller.visit(body))
                new_node = self.visit(new_node)
                # A nested unrollable loop comes back as a list of
                # statements; keep the result flat.
                if isinstance(new_node, list):
                    new_nodes.extend(new_node)
                elif new_node is not None:
                    new_nodes.append(new_node)
        return new_nodes
|
|
|
|
|
|
# Hackfix to get access to the source code of `exec`-created functions:
# their source is not on disk, so it is kept in `_FILENAME_TO_SRC` (keyed by
# the synthetic filename) and served through a patched `linecache.getlines`
# - see https://stackoverflow.com/a/69668999
# `_getlines_orig` holds the original `linecache.getlines`; it is assigned
# by `unroll_varargs` when the patch is installed.
_getlines_orig = None
_FILENAME_TO_SRC: Dict[str, str] = {}


def _monkey_patched_getlines(filename, module_globals=None):
    """Replacement for `linecache.getlines` that knows generated sources.

    Args:
        filename: name of the file whose lines are requested.
        module_globals: forwarded unchanged to the original `getlines`.

    Returns:
        A list of source lines (line endings kept) when `filename` is
        registered in `_FILENAME_TO_SRC`; otherwise whatever the original
        `linecache.getlines` returns.
    """
    if filename in _FILENAME_TO_SRC:
        src = _FILENAME_TO_SRC[filename]
        # Callers of `linecache.getlines` index the result line by line, so
        # the stored source text must be handed back as a list of lines,
        # not as one string (indexing a string yields single characters).
        if isinstance(src, str):
            return src.splitlines(keepends=True)
        return list(src)
    return _getlines_orig(filename, module_globals)  # type: ignore
|
|
|
|
|
|
@functools.lru_cache(None)
def unroll_varargs(kernel, N: int):
    """
    Return a copy of a variadic triton kernel specialized for `N` inputs.

    The source of `kernel.fn` is parsed, rewritten by
    `_VisitorUnrollKernel`, compiled under a synthetic filename and
    handed to `triton.jit`.

    NOTE: `triton.jit` is quite costly to call, which is why results
    are memoized with `lru_cache`.
    """
    global _FILENAME_TO_SRC, _getlines_orig

    # Rewrite the kernel's AST for the requested number of inputs.
    original_src = triton.JITFunction(kernel.fn).src
    tree = ast.parse(original_src)
    tree = _VisitorUnrollKernel(N=N).visit(tree)
    tree = ast.fix_missing_locations(tree)

    # NOTE: `ast.unparse` only exists in python 3.9+
    if sys.version_info < (3, 9):
        raise RuntimeError("Error: This functionality requires python 3.9 or above")
    new_src = ast.unparse(tree)  # type: ignore

    # Evaluating the generated source is not enough: triton calls
    # `inspect.getsource` on the function, hence the synthetic filename
    # and the `linecache` registration further down.
    fn_filename = f"<unroll_varargs-{kernel.fn.__name__}-{N}>"

    # Materialize the function object from the generated source.
    namespace: Dict[str, Any] = {}
    exec(compile(new_src, fn_filename, "exec"), kernel.fn.__globals__, namespace)
    assert len(namespace) == 1, len(namespace)
    fn = next(iter(namespace.values()))

    # An empty registry means `linecache.getlines` was never patched.
    if not _FILENAME_TO_SRC:
        _getlines_orig = linecache.getlines
        linecache.getlines = _monkey_patched_getlines
    _FILENAME_TO_SRC[fn_filename] = new_src

    specialized = triton.jit(fn)
    specialized.src = new_src
    return specialized
|
|
|
|
|
|
# Note: just import this to make mypy happy
|
|
# when annotating variables with `VAR_ARGS_ARRAY`
|
|
# The AST visitor in this file matches the *string* annotation
# "VAR_ARGS_ARRAY"; this alias only gives that name a real type.
VAR_ARGS_ARRAY = List[Any]
|