First commit
This commit is contained in:
9
vllm/triton_utils/__init__.py
Normal file
9
vllm/triton_utils/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from vllm.triton_utils.importing import HAS_TRITON
|
||||
|
||||
__all__ = ["HAS_TRITON"]
|
||||
|
||||
#from vllm.triton_utils.custom_cache_manager import (
|
||||
# maybe_set_triton_cache_manager)
|
||||
#from vllm.triton_utils.libentry import libentry
|
||||
|
||||
__all__ += ["maybe_set_triton_cache_manager", "libentry"]
|
||||
BIN
vllm/triton_utils/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/triton_utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/triton_utils/__pycache__/importing.cpython-310.pyc
Normal file
BIN
vllm/triton_utils/__pycache__/importing.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/triton_utils/__pycache__/libentry.cpython-310.pyc
Normal file
BIN
vllm/triton_utils/__pycache__/libentry.cpython-310.pyc
Normal file
Binary file not shown.
53
vllm/triton_utils/custom_cache_manager.py
Normal file
53
vllm/triton_utils/custom_cache_manager.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
|
||||
from triton.runtime.cache import (FileCacheManager, default_cache_dir,
|
||||
default_dump_dir, default_override_dir)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def maybe_set_triton_cache_manager() -> None:
|
||||
"""Set environment variable to tell Triton to use a
|
||||
custom cache manager"""
|
||||
cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
|
||||
if cache_manger is None:
|
||||
manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
|
||||
logger.info("Setting Triton cache manager to: %s", manager)
|
||||
os.environ["TRITON_CACHE_MANAGER"] = manager
|
||||
|
||||
|
||||
class CustomCacheManager(FileCacheManager):
|
||||
"""Re-implements Triton's cache manager, ensuring that a
|
||||
unique cache directory is created for each process. This is
|
||||
needed to avoid collisions when running with tp>1 and
|
||||
using multi-processing as the distributed backend.
|
||||
|
||||
Note this issue was fixed by triton-lang/triton/pull/4295,
|
||||
but the fix is not yet included in triton==v3.0.0. However,
|
||||
it should be included in the subsequent version.
|
||||
"""
|
||||
|
||||
def __init__(self, key, override=False, dump=False):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
if dump:
|
||||
self.cache_dir = default_dump_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
elif override:
|
||||
self.cache_dir = default_override_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
else:
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.getenv("TRITON_CACHE_DIR",
|
||||
"").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
else:
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
11
vllm/triton_utils/importing.py
Normal file
11
vllm/triton_utils/importing.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from importlib.util import find_spec
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
HAS_TRITON = False
|
||||
|
||||
if not HAS_TRITON:
|
||||
logger.info("Triton not installed; certain GPU-related functions"
|
||||
" will not be available.")
|
||||
167
vllm/triton_utils/libentry.py
Normal file
167
vllm/triton_utils/libentry.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copied From https://github.com/FlagOpen/FlagGems
|
||||
|
||||
import inspect
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
class LibEntry(triton.KernelInterface):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn,
|
||||
):
|
||||
self.fn = fn
|
||||
self.arg_names = fn.arg_names
|
||||
self.divisibility = 16
|
||||
self.kernel_cache = dict()
|
||||
fn = self.fn
|
||||
while not isinstance(fn, triton.runtime.JITFunction):
|
||||
fn = fn.fn
|
||||
self.jit_function: triton.runtime.JITFunction = fn
|
||||
self.specialize_indices = [
|
||||
p.num for p in self.jit_function.params
|
||||
if not p.is_constexpr and not p.do_not_specialize
|
||||
]
|
||||
self.do_not_specialize_indices = [
|
||||
p.num for p in self.jit_function.params
|
||||
if not p.is_constexpr and p.do_not_specialize
|
||||
]
|
||||
|
||||
def key(self, spec_args, dns_args, const_args):
|
||||
spec_key = [(arg.dtype, arg.data_ptr() %
|
||||
self.divisibility == 0) if hasattr(arg, "data_ptr") else
|
||||
(type(arg), arg) for arg in spec_args]
|
||||
dns_key = [
|
||||
arg.dtype if hasattr(
|
||||
arg, "data_ptr") else type(arg) if not isinstance(arg, int)
|
||||
else "i32" if arg >= -(2**31) and arg <= 2**31 -
|
||||
1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64"
|
||||
for arg in dns_args
|
||||
]
|
||||
# const args passed by position
|
||||
return tuple(spec_key + dns_key + const_args)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
grid = kwargs["grid"]
|
||||
# collect all the arguments
|
||||
spec_args = [] # specialize arguments
|
||||
dns_args = [] # do not specialize arguments
|
||||
const_args = [] # constexpr arguments
|
||||
k_args = [] # kernel arguments
|
||||
for i, arg in enumerate(args):
|
||||
if i in self.specialize_indices:
|
||||
k_args.append(arg)
|
||||
spec_args.append(arg)
|
||||
elif i in self.do_not_specialize_indices:
|
||||
k_args.append(arg)
|
||||
dns_args.append(arg)
|
||||
else:
|
||||
const_args.append(arg)
|
||||
for p in self.jit_function.params[len(args):]:
|
||||
if p.name in kwargs:
|
||||
val = kwargs[p.name]
|
||||
elif p.default is inspect._empty:
|
||||
continue
|
||||
else:
|
||||
val = p.default
|
||||
|
||||
if p.is_constexpr:
|
||||
const_args.append(val)
|
||||
elif p.do_not_specialize:
|
||||
dns_args.append(val)
|
||||
k_args.append(val)
|
||||
else:
|
||||
spec_args.append(val)
|
||||
k_args.append(val)
|
||||
|
||||
entry_key = self.key(spec_args, dns_args, const_args)
|
||||
|
||||
if entry_key not in self.kernel_cache:
|
||||
# compile the kernel also completes the related computations
|
||||
kernel = self.fn.run(*args, **kwargs)
|
||||
fn = self.fn
|
||||
# collect constexpr arguments for grid computation
|
||||
constexprs = {}
|
||||
while not isinstance(fn, triton.runtime.JITFunction):
|
||||
if isinstance(fn, triton.runtime.Autotuner):
|
||||
config = fn.best_config
|
||||
constexprs["num_warps"] = config.num_warps
|
||||
constexprs["num_stages"] = config.num_stages
|
||||
constexprs["num_ctas"] = config.num_ctas
|
||||
constexprs = {**constexprs, **config.kwargs}
|
||||
elif isinstance(fn, triton.runtime.Heuristics):
|
||||
for v, heur in fn.values.items():
|
||||
constexprs[v] = heur({
|
||||
**dict(zip(fn.arg_names, args)),
|
||||
**kwargs,
|
||||
**constexprs,
|
||||
})
|
||||
else:
|
||||
raise RuntimeError("Invalid Runtime Function")
|
||||
fn = fn.fn
|
||||
# In vLLM, certain kernels like fused_moe_kernel get the
|
||||
# best_config(as kwargs) from a configuration json file, rather
|
||||
# than using Autotuner & Heuristics. Therefore, all their constexprs
|
||||
# (tl.constexpr) are assigned values through the following loop.
|
||||
for p in self.jit_function.params:
|
||||
if p.is_constexpr and p.name not in constexprs:
|
||||
constexprs[p.name] = p.default #default=inspect._empty
|
||||
self.kernel_cache[entry_key] = (kernel, constexprs)
|
||||
else:
|
||||
# load kernel from cache directly
|
||||
kernel, constexprs = self.kernel_cache[entry_key]
|
||||
|
||||
if callable(grid):
|
||||
# collect all arguments to the grid fn,ie:
|
||||
# 1. args,
|
||||
# 2. kwargs,
|
||||
# 3. all all other captured arguments in CompiledKernel from
|
||||
# Autotunner & Heuristics when kwargs & captured args conflict,
|
||||
# captured args have higher priority
|
||||
# 4. We must filter out captured args with default value firstly
|
||||
constexprs = {
|
||||
k: v
|
||||
for k, v in constexprs.items() if v is not inspect._empty
|
||||
}
|
||||
meta = {
|
||||
**dict(zip(self.arg_names, args)),
|
||||
**kwargs,
|
||||
**constexprs,
|
||||
}
|
||||
grid = grid(meta)
|
||||
if isinstance(grid, tuple):
|
||||
grid = grid + (1, 1)
|
||||
elif isinstance(grid, list):
|
||||
grid = grid + [1, 1]
|
||||
kernel[grid[0:3]](*k_args)
|
||||
# maintaining the same return type as the JITFunction.run
|
||||
return kernel
|
||||
|
||||
|
||||
def libentry():
|
||||
"""
|
||||
Decorator for triton library entries.
|
||||
Motivation:
|
||||
The runtime overhead of Triton kernels is the reason for the lower
|
||||
performance of small kernels, particularly evident with smaller models.
|
||||
Using this decorator can reduce Triton runtime overhead.
|
||||
How:
|
||||
The `run` function of JITFunction needs to accomplish:
|
||||
- Parameter binding using inspect
|
||||
- KernelArg type wrapping
|
||||
- Cache key calculation
|
||||
When dealing with small size, these steps can become bottlenecks in
|
||||
Triton runtime. Libentry simplifies these steps to reduce runtime
|
||||
overhead, thereby improving the runtime expenses of small kernels.
|
||||
NOTE:
|
||||
When Triton is upgraded to version 3.0.0, libentry can be removed,
|
||||
see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
return LibEntry(fn)
|
||||
|
||||
return decorator
|
||||
Reference in New Issue
Block a user