# Reconstructed from a whitespace-mangled unified diff. The patch adds a
# key-based Triton kernel cache to python/sglang/utils.py and applies it via
#   @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
# to _chunked_lora_expand_kernel and _chunked_lora_shrink_kernel in the
# chunked SGMV Triton ops. Module-level dependencies added by the patch:
# functools, inspect, triton.


class CachedKernel:
    """Grid-indexable wrapper that caches compiled Triton kernels by a user key.

    Supports the usual ``kernel[grid](...)`` launch syntax. The first launch
    for a given key goes through Triton's normal dispatch (specialization and
    compilation) and the resulting kernel object is memoized; subsequent
    launches with the same key bypass Triton's dispatch and invoke the cached
    kernel directly.
    """

    def __init__(self, fn, key_fn=None):
        # Only plain @triton.jit functions are supported.
        assert isinstance(fn, triton.runtime.jit.JITFunction)
        self.fn = fn

        # Introspect the wrapped Python function once, up front.
        wrapped = fn.fn
        self.signature = inspect.signature(wrapped)
        self.param_names = tuple(self.signature.parameters)
        self.num_args = len(self.param_names)

        # Default values would make the positional-argument reconstruction in
        # _build_args ambiguous, so they are rejected eagerly.
        for name, param in self.signature.parameters.items():
            assert (
                param.default is inspect.Parameter.empty
            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."

        functools.update_wrapper(self, wrapped)
        # Maps key_fn(args, kwargs) -> compiled Triton kernel.
        self.kernel_cache = {}
        self.key_fn = key_fn

    def __getitem__(self, grid):
        """Return a launcher bound to ``grid`` that performs cached dispatch."""
        assert (
            isinstance(grid, tuple) and len(grid) <= 3
        ), "Grid must be a tuple with at most 3 dimensions."

        # Normalize once to a full 3-tuple by padding with 1s.
        grid = (grid + (1, 1, 1))[:3]

        def launcher(*args, **kwargs):
            key = self.key_fn(args, kwargs)
            kernel = self.kernel_cache.get(key)

            if kernel is not None:
                # Fast path: relaunch the memoized kernel, skipping Triton's
                # (comparatively heavy) specialization dispatch.
                kernel[grid](*self._build_args(args, kwargs))
                return kernel

            # Slow path (first sighting of this key): let Triton specialize
            # and compile, then memoize the result for future launches.
            kernel = self.fn[grid](*args, **kwargs)
            self.kernel_cache[key] = kernel
            return kernel

        return launcher

    def _build_args(self, args, kwargs):
        """Flatten ``args``/``kwargs`` into the complete positional list.

        Positional arguments are kept in place; every remaining declared
        parameter is looked up by name in ``kwargs``.

        Raises:
            ValueError: if a declared parameter is supplied neither
                positionally nor by keyword.
        """
        missing = inspect.Parameter.empty
        full_args = list(args)
        for name in self.param_names[len(args) : self.num_args]:
            value = kwargs.get(name, missing)
            if value is missing:
                raise ValueError(f"Missing argument: {name}")
            full_args.append(value)
        return full_args


def cached_triton_kernel(key_fn=None):
    """Decorator enabling key-based caching for Triton kernels.

    Bypasses Triton's built-in caching mechanism, letting callers define
    their own caching strategy from the kernel arguments. This reduces the
    kernel-launch overhead when the specialization dispatch is simple.

    Usage:
        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
        @triton.jit
        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
            ...

        # Invoke normally
        my_kernel[grid](x, y, BLOCK_SIZE=1024)

    Args:
        key_fn: Function mapping (args, kwargs) to the cache key(s) — a
            single value or a tuple of values.

    Returns:
        A decorator that wraps the kernel in a CachedKernel.

    Note:
        Kernels with default parameter values are not supported and will
        raise an assertion error.
    """
    return lambda fn: CachedKernel(fn, key_fn)