First commit
This commit is contained in:
173
pkgs/xformers/triton/vararg_kernel.py
Normal file
173
pkgs/xformers/triton/vararg_kernel.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import ast
|
||||
import copy
|
||||
import functools
|
||||
import linecache
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
class _ForLoopUnroller(ast.NodeTransformer):
|
||||
def __init__(self, target, inline_variables, loop_iter):
|
||||
self.loop_iter = loop_iter
|
||||
self.target = target
|
||||
self.inline_variables = inline_variables
|
||||
|
||||
def visit_Name(self, node):
|
||||
if node.id != self.target:
|
||||
return node
|
||||
return ast.Name(str(self.loop_iter))
|
||||
|
||||
def visit_Subscript(self, node):
|
||||
# Pattern-matching `value[slice]`
|
||||
if (
|
||||
isinstance(node.slice, ast.Name)
|
||||
and node.slice.id == self.target
|
||||
and isinstance(node.value, ast.Name)
|
||||
and node.value.id in self.inline_variables
|
||||
):
|
||||
return ast.Name(f"{node.value.id}{self.loop_iter}")
|
||||
return node
|
||||
|
||||
|
||||
class _VisitorUnrollKernel(ast.NodeTransformer):
|
||||
def __init__(self, N):
|
||||
self.inline_variables = set()
|
||||
self.N = N
|
||||
|
||||
def visit_AnnAssign(self, node):
|
||||
# Pattern-matching:
|
||||
# var_name: "VAR_ARGS_ARRAY"
|
||||
if (
|
||||
node.value is None
|
||||
and node.simple == 1
|
||||
and isinstance(node.target, ast.Name)
|
||||
and isinstance(node.annotation, ast.Constant)
|
||||
and node.annotation.value == "VAR_ARGS_ARRAY"
|
||||
):
|
||||
self.inline_variables.add(node.target.id)
|
||||
return []
|
||||
if node.value is not None:
|
||||
node.value = self.visit(node.value)
|
||||
if node.annotation is not None:
|
||||
node.annotation = self.visit(node.annotation)
|
||||
if node.target is not None:
|
||||
node.target = self.visit(node.target)
|
||||
return node
|
||||
|
||||
def visit_arguments(self, node):
|
||||
# Replace `args` annotated with `VAR_ARGS_ARRAY`
|
||||
new_args = []
|
||||
for arg in node.args:
|
||||
if (
|
||||
arg.annotation is not None
|
||||
and isinstance(arg.annotation, ast.Constant)
|
||||
and arg.annotation.value == "VAR_ARGS_ARRAY"
|
||||
):
|
||||
self.inline_variables.add(arg.arg)
|
||||
new_args += [ast.arg(f"{arg.arg}{i}") for i in range(self.N)]
|
||||
continue
|
||||
new_args.append(arg)
|
||||
if node.vararg is not None:
|
||||
self.inline_variables.add(node.vararg.arg)
|
||||
new_args += [ast.arg(f"{node.vararg.arg}{i}") for i in range(self.N)]
|
||||
node.vararg = None
|
||||
new_args += node.kwonlyargs
|
||||
node.kwonlyargs = []
|
||||
node.args = new_args
|
||||
return node
|
||||
|
||||
def visit_For(self, node):
|
||||
if (
|
||||
not isinstance(node.iter, ast.Call)
|
||||
or node.iter.func.id != "range"
|
||||
or len(node.iter.args) != 1
|
||||
or not isinstance(node.iter.args[0], ast.Call)
|
||||
or node.iter.args[0].func.id != "len"
|
||||
or len(node.iter.args[0].args) != 1
|
||||
or node.iter.args[0].args[0].id not in self.inline_variables
|
||||
):
|
||||
node.body = [self.visit(x) for x in node.body]
|
||||
return node
|
||||
# We know we have to modify this loop
|
||||
new_nodes = []
|
||||
for i in range(self.N):
|
||||
unroller = _ForLoopUnroller(
|
||||
target=node.target.id,
|
||||
inline_variables=self.inline_variables,
|
||||
loop_iter=i,
|
||||
)
|
||||
for body in node.body:
|
||||
body = copy.deepcopy(body)
|
||||
new_node = ast.fix_missing_locations(unroller.visit(body))
|
||||
new_node = self.visit(new_node)
|
||||
new_nodes.append(new_node)
|
||||
return new_nodes
|
||||
|
||||
|
||||
# Hackfix to get access to get source-code for
|
||||
# `exec`-created functions - see https://stackoverflow.com/a/69668999
|
||||
_getlines_orig = None
|
||||
_FILENAME_TO_SRC: Dict[str, str] = {}
|
||||
|
||||
|
||||
def _monkey_patched_getlines(filename, module_globals=None):
|
||||
if filename in _FILENAME_TO_SRC:
|
||||
return _FILENAME_TO_SRC[filename]
|
||||
else:
|
||||
return _getlines_orig(filename, module_globals) # type: ignore
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def unroll_varargs(kernel, N: int):
|
||||
"""
|
||||
Specializes a triton kernel with variable number of inputs
|
||||
to a specific number of inputs `N`.
|
||||
NOTE: Because it's quite costly to call `triton.jit`,
|
||||
we cache the returned value with `lru_cache`
|
||||
"""
|
||||
global _FILENAME_TO_SRC, _getlines_orig
|
||||
|
||||
k = triton.JITFunction(kernel.fn)
|
||||
parsed = ast.parse(k.src)
|
||||
nodeVisitor = _VisitorUnrollKernel(N=N)
|
||||
parsed = nodeVisitor.visit(parsed)
|
||||
parsed = ast.fix_missing_locations(parsed)
|
||||
|
||||
# NOTE: `ast.unparse` requires python 3.9+
|
||||
if (sys.version_info.major, sys.version_info.minor) <= (3, 8):
|
||||
raise RuntimeError("Error: This functionality requires python 3.9 or above")
|
||||
new_src = ast.unparse(parsed) # type: ignore
|
||||
|
||||
# Now we want to `eval` the function, but we need all this
|
||||
# boilerplate code to make sure triton can run `inspect.getsource`
|
||||
|
||||
fn_filename = f"<unroll_varargs-{kernel.fn.__name__}-{N}>"
|
||||
|
||||
# Create function given source
|
||||
code = compile(new_src, fn_filename, "exec")
|
||||
|
||||
_locals: Dict[str, Any] = {}
|
||||
exec(code, kernel.fn.__globals__, _locals)
|
||||
assert len(_locals) == 1, len(_locals)
|
||||
fn = next(iter(_locals.values()))
|
||||
# Patch `getlines` only the first time
|
||||
if not _FILENAME_TO_SRC:
|
||||
_getlines_orig = linecache.getlines
|
||||
linecache.getlines = _monkey_patched_getlines
|
||||
_FILENAME_TO_SRC[fn_filename] = new_src
|
||||
|
||||
jitted_fn = triton.jit(fn)
|
||||
jitted_fn.src = new_src
|
||||
return jitted_fn
|
||||
|
||||
|
||||
# Note: just import this to make mypy happy
|
||||
# when annotating variables with `VAR_ARGS_ARRAY`
|
||||
VAR_ARGS_ARRAY = List[Any]
|
||||
Reference in New Issue
Block a user