First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
from vllm.triton_utils.importing import HAS_TRITON

__all__ = ["HAS_TRITON"]

# Every name listed in ``__all__`` must exist in this namespace, otherwise
# ``from vllm.triton_utils import *`` fails with AttributeError. The
# Triton-backed helpers are therefore imported *and* advertised only when
# Triton is reported as available; with HAS_TRITON false neither import is
# attempted, so Triton itself is never required at import time.
if HAS_TRITON:
    from vllm.triton_utils.custom_cache_manager import (
        maybe_set_triton_cache_manager)
    from vllm.triton_utils.libentry import libentry

    __all__ += ["maybe_set_triton_cache_manager", "libentry"]

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,53 @@
import os
from triton.runtime.cache import (FileCacheManager, default_cache_dir,
default_dump_dir, default_override_dir)
from vllm.logger import init_logger
logger = init_logger(__name__)
def maybe_set_triton_cache_manager() -> None:
    """Select vLLM's custom Triton cache manager unless one is already set.

    If ``TRITON_CACHE_MANAGER`` is present in the environment (with any
    value, including an empty string) it is left untouched. Otherwise the
    variable is set to the import path of ``CustomCacheManager`` and the
    choice is logged.
    """
    if os.environ.get("TRITON_CACHE_MANAGER") is not None:
        # An explicit choice was already made; respect it.
        return

    manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
    logger.info("Setting Triton cache manager to: %s", manager)
    os.environ["TRITON_CACHE_MANAGER"] = manager
class CustomCacheManager(FileCacheManager):
    """Triton cache manager that gives each process its own cache directory.

    This re-implements the constructor of Triton's ``FileCacheManager`` so
    that the default cache location is suffixed with the current process id.
    That avoids collisions when running with tp>1 and multi-processing as
    the distributed backend.

    Note: the underlying issue was fixed by triton-lang/triton/pull/4295,
    but that fix is not yet included in triton==v3.0.0. It should be
    included in the subsequent version.
    """

    def __init__(self, key, override=False, dump=False):
        """Resolve ``cache_dir`` / ``lock_path`` for ``key``.

        Args:
            key: Sub-directory name appended to the selected root directory.
            override: When true (and ``dump`` is false), use the override
                directory. No lock path is set and nothing is created.
            dump: When true, use the dump directory; takes precedence over
                ``override``.

        Raises:
            RuntimeError: If neither ``TRITON_CACHE_DIR`` nor the default
                cache directory yields a non-empty path.
        """
        self.key = key
        self.lock_path = None

        if override and not dump:
            # Override location: path only, no lock file and no mkdir.
            self.cache_dir = os.path.join(default_override_dir(), self.key)
            return

        if dump:
            root = default_dump_dir()
        else:
            # A blank/whitespace-only TRITON_CACHE_DIR falls back to the
            # default cache directory.
            root = (os.getenv("TRITON_CACHE_DIR", "").strip()
                    or default_cache_dir())
            self.cache_dir = root
            if not root:
                raise RuntimeError("Could not create or locate cache dir")
            # Per-process suffix keeps concurrent workers apart.
            root = f"{root}_{os.getpid()}"

        # Shared tail for the dump and default locations: create the keyed
        # directory if it doesn't exist and remember where its lock lives.
        self.cache_dir = os.path.join(root, self.key)
        self.lock_path = os.path.join(self.cache_dir, "lock")
        os.makedirs(self.cache_dir, exist_ok=True)

View File

@@ -0,0 +1,11 @@
from importlib.util import find_spec
from vllm.logger import init_logger
# Triton availability flag consumed by the rest of the package. It is a
# hard-coded constant here: no runtime detection is performed.
# NOTE(review): ``find_spec`` is imported in this module but never used;
# presumably detection was replaced by this constant on purpose -- confirm.
HAS_TRITON = False

logger = init_logger(__name__)

if not HAS_TRITON:
    logger.info(
        "Triton not installed; certain GPU-related functions"
        " will not be available.")

View File

@@ -0,0 +1,167 @@
# Copied From https://github.com/FlagOpen/FlagGems
import inspect
import triton
class LibEntry(triton.KernelInterface):
    """Caching front-end placed in front of a Triton kernel launcher.

    The first call for a given argument signature is delegated to the
    wrapped object's ``run``; the returned kernel is stored in
    ``kernel_cache`` together with the constexpr values needed to evaluate a
    callable grid. Later calls with the same signature launch the cached
    kernel directly.
    """

    def __init__(
        self,
        fn,
    ):
        """Wrap ``fn`` and pre-compute which parameters are specialized.

        Args:
            fn: Object exposing ``arg_names`` and ``run``. Either a
                ``triton.runtime.JITFunction`` or a wrapper whose ``.fn``
                chain eventually reaches one.
        """
        self.fn = fn
        self.arg_names = fn.arg_names
        # Pointer alignment (data_ptr() % divisibility) tracked in the key.
        self.divisibility = 16
        # entry_key -> (kernel, constexprs)
        self.kernel_cache = dict()
        # Peel off wrapper layers until the underlying JITFunction is found.
        fn = self.fn
        while not isinstance(fn, triton.runtime.JITFunction):
            fn = fn.fn
        self.jit_function: triton.runtime.JITFunction = fn
        # Positions of non-constexpr parameters that are specialized ...
        self.specialize_indices = [
            p.num for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        # ... and of non-constexpr parameters flagged do_not_specialize.
        self.do_not_specialize_indices = [
            p.num for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]

    def key(self, spec_args, dns_args, const_args):
        """Build the hashable cache key for one call.

        Args:
            spec_args: Specialized arguments. Objects with ``data_ptr``
                contribute ``(dtype, is_aligned)``; anything else
                contributes ``(type(arg), arg)``, i.e. its value.
            dns_args: Do-not-specialize arguments. Objects with ``data_ptr``
                contribute their dtype, non-ints their type, and ints one of
                ``"i32"``, ``"u64"`` or ``"i64"`` depending on their range.
            const_args: Constexpr arguments (a list), used by value.

        Returns:
            A tuple concatenating the three parts in that order.
        """
        spec_key = [(arg.dtype, arg.data_ptr() %
                     self.divisibility == 0) if hasattr(arg, "data_ptr") else
                    (type(arg), arg) for arg in spec_args]
        # Int bucketing: [-2**31, 2**31 - 1] -> "i32",
        # [2**63, 2**64 - 1] -> "u64", every other int -> "i64".
        dns_key = [
            arg.dtype if hasattr(
                arg, "data_ptr") else type(arg) if not isinstance(arg, int)
            else "i32" if arg >= -(2**31) and arg <= 2**31 -
            1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64"
            for arg in dns_args
        ]
        # const args passed by position
        return tuple(spec_key + dns_key + const_args)

    def run(self, *args, **kwargs):
        """Launch the kernel, reusing a cached kernel when the key matches.

        Args:
            *args: Positional kernel arguments, in parameter order.
            **kwargs: Keyword kernel arguments. Must contain ``grid``, either
                a tuple/list or a callable taking a metadata dict.

        Returns:
            The kernel object, as returned by the wrapped ``run`` on a cache
            miss or taken from ``kernel_cache`` on a hit.

        Raises:
            KeyError: If ``grid`` is not supplied in ``kwargs``.
            RuntimeError: If a wrapper in the ``.fn`` chain is neither an
                ``Autotuner`` nor a ``Heuristics`` instance.
        """
        grid = kwargs["grid"]

        # collect all the arguments
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        k_args = []  # kernel arguments
        for i, arg in enumerate(args):
            if i in self.specialize_indices:
                k_args.append(arg)
                spec_args.append(arg)
            elif i in self.do_not_specialize_indices:
                k_args.append(arg)
                dns_args.append(arg)
            else:
                # Neither index list contains i: a constexpr parameter.
                const_args.append(arg)
        # Parameters not covered positionally: take the keyword value or the
        # declared default, and skip those that have neither.
        for p in self.jit_function.params[len(args):]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect._empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args.append(val)
            else:
                spec_args.append(val)
                k_args.append(val)

        entry_key = self.key(spec_args, dns_args, const_args)

        if entry_key not in self.kernel_cache:
            # compile the kernel also completes the related computations
            kernel = self.fn.run(*args, **kwargs)
            fn = self.fn
            # collect constexpr arguments for grid computation
            constexprs = {}
            while not isinstance(fn, triton.runtime.JITFunction):
                if isinstance(fn, triton.runtime.Autotuner):
                    config = fn.best_config
                    constexprs["num_warps"] = config.num_warps
                    constexprs["num_stages"] = config.num_stages
                    constexprs["num_ctas"] = config.num_ctas
                    constexprs = {**constexprs, **config.kwargs}
                elif isinstance(fn, triton.runtime.Heuristics):
                    for v, heur in fn.values.items():
                        constexprs[v] = heur({
                            **dict(zip(fn.arg_names, args)),
                            **kwargs,
                            **constexprs,
                        })
                else:
                    raise RuntimeError("Invalid Runtime Function")
                fn = fn.fn
            # In vLLM, certain kernels like fused_moe_kernel get the
            # best_config(as kwargs) from a configuration json file, rather
            # than using Autotuner & Heuristics. Therefore, all their constexprs
            # (tl.constexpr) are assigned values through the following loop.
            for p in self.jit_function.params:
                if p.is_constexpr and p.name not in constexprs:
                    constexprs[p.name] = p.default  # default=inspect._empty
            self.kernel_cache[entry_key] = (kernel, constexprs)
        else:
            # load kernel from cache directly
            kernel, constexprs = self.kernel_cache[entry_key]

            # NOTE(review): the grid evaluation and launch below are taken
            # to belong to the cache-hit branch only, because the miss
            # branch already executed the kernel through self.fn.run --
            # confirm against the original indentation.
            if callable(grid):
                # collect all arguments to the grid fn, i.e.:
                # 1. args,
                # 2. kwargs,
                # 3. all other captured arguments in CompiledKernel from
                # Autotuner & Heuristics; when kwargs & captured args
                # conflict, captured args have higher priority
                # 4. We must first filter out captured args that still hold
                # the "no default" sentinel
                constexprs = {
                    k: v
                    for k, v in constexprs.items() if v is not inspect._empty
                }
                meta = {
                    **dict(zip(self.arg_names, args)),
                    **kwargs,
                    **constexprs,
                }
                grid = grid(meta)
            # Pad with ones so that slicing to three entries always works.
            if isinstance(grid, tuple):
                grid = grid + (1, 1)
            elif isinstance(grid, list):
                grid = grid + [1, 1]
            kernel[grid[0:3]](*k_args)
        # maintaining the same return type as the JITFunction.run
        return kernel
def libentry():
    """Return a decorator that wraps a Triton kernel in ``LibEntry``.

    Motivation:
        Triton's runtime overhead is the reason small kernels perform
        poorly, which is particularly evident with smaller models. This
        decorator reduces that overhead.

    How:
        ``JITFunction.run`` has to perform, on every call:
            - parameter binding using inspect
            - KernelArg type wrapping
            - cache key calculation
        For small sizes these steps can become the bottleneck of the Triton
        runtime. Libentry simplifies them, lowering the runtime cost of
        small kernels.

    NOTE:
        When Triton is upgraded to version 3.0.0, libentry can be removed,
        see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245
    """

    def wrap(kernel_fn):
        # Hand the kernel (or its autotune/heuristics wrapper) to LibEntry.
        return LibEntry(kernel_fn)

    return wrap