diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py index 951393929..1767c5ee4 100644 --- a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py @@ -5,7 +5,7 @@ import triton import triton.language as tl from sglang.srt.lora.utils import LoRABatchInfo -from sglang.utils import cached_triton_kernel +from sglang.srt.utils import cached_triton_kernel @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"])) diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py index 8b170bfa4..e0ef41fb7 100644 --- a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py @@ -3,7 +3,7 @@ import triton import triton.language as tl from sglang.srt.lora.utils import LoRABatchInfo -from sglang.utils import cached_triton_kernel +from sglang.srt.utils import cached_triton_kernel @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"])) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 812d72a08..dd01c67b7 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -22,6 +22,7 @@ import ctypes import dataclasses import functools import importlib +import inspect import io import ipaddress import itertools @@ -3224,3 +3225,120 @@ def get_extend_input_len_swa_limit( # and we can only free out-of-sliding-window kv indices after each prefill. # 3. page_size is because we want to have 1 token extra for generated tokens. return page_size + 2 * max(sliding_window_size, chunked_prefill_size) + + +class CachedKernel: + """ + Wrapper that allows kernel[grid](...) syntax with caching based on a key function. + + This wrapper caches compiled Triton kernels based on keys extracted by a + user-provided key function to avoid redundant compilations. + """ + + def __init__(self, fn, key_fn=None): + self.fn = fn + assert isinstance(fn, triton.runtime.jit.JITFunction) + + original_fn = fn.fn + self.signature = inspect.signature(original_fn) + self.param_names = tuple(self.signature.parameters.keys()) + self.num_args = len(self.param_names) + + # Check that no parameters have default values + for name, param in self.signature.parameters.items(): + assert ( + param.default is inspect.Parameter.empty + ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels." + + functools.update_wrapper(self, original_fn) + self.kernel_cache = {} + + # Store the key function + self.key_fn = key_fn + + def __getitem__(self, grid): + """ + Index with grid to get a launcher function. + Returns a launcher that will handle caching based on the key function. + """ + assert ( + isinstance(grid, tuple) and len(grid) <= 3 + ), "Grid must be a tuple with at most 3 dimensions." + + # Normalize grid once + if len(grid) < 3: + grid = grid + (1,) * (3 - len(grid)) + + def launcher(*args, **kwargs): + cache_key = self.key_fn(args, kwargs) + + cached_kernel = self.kernel_cache.get(cache_key) + + if cached_kernel is None: + # First time: compile and cache the kernel + cached_kernel = self.fn[grid](*args, **kwargs) + self.kernel_cache[cache_key] = cached_kernel + return cached_kernel + else: + # Use cached kernel + all_args = self._build_args(args, kwargs) + cached_kernel[grid](*all_args) + return cached_kernel + + return launcher + + def _build_args(self, args, kwargs): + """ + Build the complete argument list for kernel invocation. + """ + complete_args = list(args) + + for i in range(len(args), self.num_args): + name = self.param_names[i] + value = kwargs.get(name, inspect.Parameter.empty) + if value is not inspect.Parameter.empty: + complete_args.append(value) + else: + raise ValueError(f"Missing argument: {name}") + + return complete_args + + def _clear_cache(self): + """ + Clear the kernel cache for testing purposes. + """ + self.kernel_cache.clear() + + +def cached_triton_kernel(key_fn=None): + """ + Decorator that enables key-based caching for Triton kernels using a key function. + + It essentially bypasses Triton's built-in caching mechanism, allowing users to + define their own caching strategy based on kernel parameters. This helps reduce + the heavy overheads of Triton kernel launch when the kernel specialization dispatch + is simple. + + Usage: + @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024)) + @triton.jit + def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr): + ... + + # Invoke normally + my_kernel[grid](x, y, BLOCK_SIZE=1024) + + Args: + key_fn: A function that takes (args, kwargs) and returns the cache key(s). + The key can be a single value or a tuple of values. + + Returns: + A decorator that wraps the kernel with caching functionality. + + Note: Kernels with default parameter values are not supported and will raise an assertion error. + """ + + def decorator(fn): + return CachedKernel(fn, key_fn) + + return decorator diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 15d93774b..1d62c5df8 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -1,8 +1,6 @@ """Common utilities""" -import functools import importlib -import inspect import json import logging import os @@ -24,7 +22,6 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union import numpy as np import pybase64 import requests -import triton from IPython.display import HTML, display from pydantic import BaseModel from tqdm import tqdm @@ -552,120 +549,3 @@ def resolve_obj_by_qualname(qualname: str) -> Any: module_name, obj_name = qualname.rsplit(".", 1) module = importlib.import_module(module_name) return getattr(module, obj_name) - - -class CachedKernel: - """ - Wrapper that allows kernel[grid](...) syntax with caching based on a key function. - - This wrapper caches compiled Triton kernels based on keys extracted by a - user-provided key function to avoid redundant compilations. - """ - - def __init__(self, fn, key_fn=None): - self.fn = fn - assert isinstance(fn, triton.runtime.jit.JITFunction) - - original_fn = fn.fn - self.signature = inspect.signature(original_fn) - self.param_names = tuple(self.signature.parameters.keys()) - self.num_args = len(self.param_names) - - # Check that no parameters have default values - for name, param in self.signature.parameters.items(): - assert ( - param.default is inspect.Parameter.empty - ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels." - - functools.update_wrapper(self, original_fn) - self.kernel_cache = {} - - # Store the key function - self.key_fn = key_fn - - def __getitem__(self, grid): - """ - Index with grid to get a launcher function. - Returns a launcher that will handle caching based on the key function. - """ - assert ( - isinstance(grid, tuple) and len(grid) <= 3 - ), "Grid must be a tuple with at most 3 dimensions." - - # Normalize grid once - if len(grid) < 3: - grid = grid + (1,) * (3 - len(grid)) - - def launcher(*args, **kwargs): - cache_key = self.key_fn(args, kwargs) - - cached_kernel = self.kernel_cache.get(cache_key) - - if cached_kernel is None: - # First time: compile and cache the kernel - cached_kernel = self.fn[grid](*args, **kwargs) - self.kernel_cache[cache_key] = cached_kernel - return cached_kernel - else: - # Use cached kernel - all_args = self._build_args(args, kwargs) - cached_kernel[grid](*all_args) - return cached_kernel - - return launcher - - def _build_args(self, args, kwargs): - """ - Build the complete argument list for kernel invocation. - """ - complete_args = list(args) - - for i in range(len(args), self.num_args): - name = self.param_names[i] - value = kwargs.get(name, inspect.Parameter.empty) - if value is not inspect.Parameter.empty: - complete_args.append(value) - else: - raise ValueError(f"Missing argument: {name}") - - return complete_args - - def _clear_cache(self): - """ - Clear the kernel cache for testing purposes. - """ - self.kernel_cache.clear() - - -def cached_triton_kernel(key_fn=None): - """ - Decorator that enables key-based caching for Triton kernels using a key function. - - It essentially bypasses Triton's built-in caching mechanism, allowing users to - define their own caching strategy based on kernel parameters. This helps reduce - the heavy overheads of Triton kernel launch when the kernel specialization dispatch - is simple. - - Usage: - @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024)) - @triton.jit - def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr): - ... - - # Invoke normally - my_kernel[grid](x, y, BLOCK_SIZE=1024) - - Args: - key_fn: A function that takes (args, kwargs) and returns the cache key(s). - The key can be a single value or a tuple of values. - - Returns: - A decorator that wraps the kernel with caching functionality. - - Note: Kernels with default parameter values are not supported and will raise an assertion error. - """ - - def decorator(fn): - return CachedKernel(fn, key_fn) - - return decorator