Move cached kernel to srt.utils (#10776)
This commit is contained in:
@@ -5,7 +5,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
from sglang.utils import cached_triton_kernel
|
||||
from sglang.srt.utils import cached_triton_kernel
|
||||
|
||||
|
||||
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
||||
|
||||
@@ -3,7 +3,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
from sglang.utils import cached_triton_kernel
|
||||
from sglang.srt.utils import cached_triton_kernel
|
||||
|
||||
|
||||
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
||||
|
||||
@@ -22,6 +22,7 @@ import ctypes
|
||||
import dataclasses
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import io
|
||||
import ipaddress
|
||||
import itertools
|
||||
@@ -3224,3 +3225,120 @@ def get_extend_input_len_swa_limit(
|
||||
# and we can only free out-of-sliding-window kv indices after each prefill.
|
||||
# 3. page_size is because we want to have 1 token extra for generated tokens.
|
||||
return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
|
||||
|
||||
|
||||
class CachedKernel:
    """
    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.

    The first launch for a given cache key goes through the wrapped Triton
    ``JITFunction`` (``fn[grid](*args, **kwargs)``) and the object it returns
    is stored in ``kernel_cache``. Later launches with the same key skip that
    path and invoke the stored object directly with a purely positional
    argument list.

    Notes:
        * The cache is keyed only by the value returned by ``key_fn``; the
          grid is not part of the key.
        * On the cached path only the kernel's declared parameters are
          forwarded. Any other keyword arguments are ignored there.
        * If ``key_fn`` is ``None`` no caching is performed and every launch
          is delegated to the wrapped ``JITFunction``.
    """

    def __init__(self, fn, key_fn=None):
        """
        Args:
            fn: The ``triton.jit``-decorated kernel to wrap.
            key_fn: Optional callable ``key_fn(args, kwargs)`` returning a
                hashable cache key. ``None`` disables caching.

        Raises:
            AssertionError: If ``fn`` is not a Triton ``JITFunction`` or any of
                its parameters declares a default value.
            TypeError: If ``key_fn`` is neither ``None`` nor callable.
        """
        self.fn = fn
        assert isinstance(fn, triton.runtime.jit.JITFunction)

        if key_fn is not None and not callable(key_fn):
            # Fail at decoration time instead of on the first kernel launch.
            raise TypeError("key_fn must be a callable taking (args, kwargs) or None.")

        original_fn = fn.fn
        self.signature = inspect.signature(original_fn)
        self.param_names = tuple(self.signature.parameters.keys())
        self.num_args = len(self.param_names)

        # Check that no parameters have default values: the cached path must
        # rebuild the full positional argument list from the call site alone.
        for name, param in self.signature.parameters.items():
            assert (
                param.default is inspect.Parameter.empty
            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."

        functools.update_wrapper(self, original_fn)
        self.kernel_cache = {}

        # Store the key function
        self.key_fn = key_fn

    def __getitem__(self, grid):
        """
        Index with grid to get a launcher function.

        Args:
            grid: Tuple of at most three launch dimensions. Shorter tuples are
                padded with 1 up to three entries.

        Returns:
            A ``launcher(*args, **kwargs)`` callable that launches the kernel,
            using the cache when a key function is configured, and returns the
            object produced by the wrapped ``JITFunction`` launch.
        """
        assert (
            isinstance(grid, tuple) and len(grid) <= 3
        ), "Grid must be a tuple with at most 3 dimensions."

        # Normalize grid once so the same 3-tuple is used on both paths.
        grid = grid + (1,) * (3 - len(grid))

        def launcher(*args, **kwargs):
            if self.key_fn is None:
                # No key function configured: nothing to key the cache on, so
                # defer entirely to the wrapped JITFunction.
                return self.fn[grid](*args, **kwargs)

            cache_key = self.key_fn(args, kwargs)
            cached_kernel = self.kernel_cache.get(cache_key)

            if cached_kernel is None:
                # First time for this key: launch through the JITFunction and
                # remember what it returned.
                cached_kernel = self.fn[grid](*args, **kwargs)
                self.kernel_cache[cache_key] = cached_kernel
                return cached_kernel

            # Cache hit: launch the stored object with positional arguments.
            all_args = self._build_args(args, kwargs)
            cached_kernel[grid](*all_args)
            return cached_kernel

        return launcher

    def _build_args(self, args, kwargs):
        """
        Build the complete argument list for kernel invocation.

        Positional arguments are kept as given; the remaining kernel
        parameters are looked up by name in ``kwargs``, in signature order.

        Raises:
            ValueError: If a remaining kernel parameter is absent from ``kwargs``.
        """
        complete_args = list(args)

        for name in self.param_names[len(args):]:
            if name not in kwargs:
                raise ValueError(f"Missing argument: {name}")
            complete_args.append(kwargs[name])

        return complete_args

    def _clear_cache(self):
        """
        Clear the kernel cache for testing purposes.
        """
        self.kernel_cache.clear()
|
||||
|
||||
|
||||
def cached_triton_kernel(key_fn=None):
    """
    Build a decorator that wraps a Triton kernel in a ``CachedKernel``.

    The wrapper sidesteps Triton's built-in caching mechanism and instead
    looks kernels up by a user-defined key derived from the call arguments.
    This reduces the heavy overhead of a Triton kernel launch when the
    kernel specialization dispatch is simple.

    Usage:
        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
        @triton.jit
        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
            ...

        # Invoke normally
        my_kernel[grid](x, y, BLOCK_SIZE=1024)

    Args:
        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
                The key can be a single value or a tuple of values.

    Returns:
        A decorator that takes the jitted kernel and returns a ``CachedKernel``.

    Note: Kernels with default parameter values are not supported and will raise an assertion error.
    """

    def wrap_kernel(jit_fn):
        # key_fn is captured from the enclosing call and shared by the wrapper.
        wrapped = CachedKernel(jit_fn, key_fn)
        return wrapped

    return wrap_kernel
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""Common utilities"""
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -24,7 +22,6 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union
|
||||
import numpy as np
|
||||
import pybase64
|
||||
import requests
|
||||
import triton
|
||||
from IPython.display import HTML, display
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
@@ -552,120 +549,3 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
|
||||
module_name, obj_name = qualname.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, obj_name)
|
||||
|
||||
|
||||
class CachedKernel:
    """
    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.

    The first launch for a given cache key goes through the wrapped Triton
    ``JITFunction`` (``fn[grid](*args, **kwargs)``) and the object it returns
    is stored in ``kernel_cache``. Later launches with the same key skip that
    path and invoke the stored object directly with a purely positional
    argument list.

    Notes:
        * The cache is keyed only by the value returned by ``key_fn``; the
          grid is not part of the key.
        * On the cached path only the kernel's declared parameters are
          forwarded. Any other keyword arguments are ignored there.
        * If ``key_fn`` is ``None`` no caching is performed and every launch
          is delegated to the wrapped ``JITFunction``.
    """

    def __init__(self, fn, key_fn=None):
        """
        Args:
            fn: The ``triton.jit``-decorated kernel to wrap.
            key_fn: Optional callable ``key_fn(args, kwargs)`` returning a
                hashable cache key. ``None`` disables caching.

        Raises:
            AssertionError: If ``fn`` is not a Triton ``JITFunction`` or any of
                its parameters declares a default value.
            TypeError: If ``key_fn`` is neither ``None`` nor callable.
        """
        self.fn = fn
        assert isinstance(fn, triton.runtime.jit.JITFunction)

        if key_fn is not None and not callable(key_fn):
            # Fail at decoration time instead of on the first kernel launch.
            raise TypeError("key_fn must be a callable taking (args, kwargs) or None.")

        original_fn = fn.fn
        self.signature = inspect.signature(original_fn)
        self.param_names = tuple(self.signature.parameters.keys())
        self.num_args = len(self.param_names)

        # Check that no parameters have default values: the cached path must
        # rebuild the full positional argument list from the call site alone.
        for name, param in self.signature.parameters.items():
            assert (
                param.default is inspect.Parameter.empty
            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."

        functools.update_wrapper(self, original_fn)
        self.kernel_cache = {}

        # Store the key function
        self.key_fn = key_fn

    def __getitem__(self, grid):
        """
        Index with grid to get a launcher function.

        Args:
            grid: Tuple of at most three launch dimensions. Shorter tuples are
                padded with 1 up to three entries.

        Returns:
            A ``launcher(*args, **kwargs)`` callable that launches the kernel,
            using the cache when a key function is configured, and returns the
            object produced by the wrapped ``JITFunction`` launch.
        """
        assert (
            isinstance(grid, tuple) and len(grid) <= 3
        ), "Grid must be a tuple with at most 3 dimensions."

        # Normalize grid once so the same 3-tuple is used on both paths.
        grid = grid + (1,) * (3 - len(grid))

        def launcher(*args, **kwargs):
            if self.key_fn is None:
                # No key function configured: nothing to key the cache on, so
                # defer entirely to the wrapped JITFunction.
                return self.fn[grid](*args, **kwargs)

            cache_key = self.key_fn(args, kwargs)
            cached_kernel = self.kernel_cache.get(cache_key)

            if cached_kernel is None:
                # First time for this key: launch through the JITFunction and
                # remember what it returned.
                cached_kernel = self.fn[grid](*args, **kwargs)
                self.kernel_cache[cache_key] = cached_kernel
                return cached_kernel

            # Cache hit: launch the stored object with positional arguments.
            all_args = self._build_args(args, kwargs)
            cached_kernel[grid](*all_args)
            return cached_kernel

        return launcher

    def _build_args(self, args, kwargs):
        """
        Build the complete argument list for kernel invocation.

        Positional arguments are kept as given; the remaining kernel
        parameters are looked up by name in ``kwargs``, in signature order.

        Raises:
            ValueError: If a remaining kernel parameter is absent from ``kwargs``.
        """
        complete_args = list(args)

        for name in self.param_names[len(args):]:
            if name not in kwargs:
                raise ValueError(f"Missing argument: {name}")
            complete_args.append(kwargs[name])

        return complete_args

    def _clear_cache(self):
        """
        Clear the kernel cache for testing purposes.
        """
        self.kernel_cache.clear()
|
||||
|
||||
|
||||
def cached_triton_kernel(key_fn=None):
    """
    Build a decorator that wraps a Triton kernel in a ``CachedKernel``.

    The wrapper sidesteps Triton's built-in caching mechanism and instead
    looks kernels up by a user-defined key derived from the call arguments.
    This reduces the heavy overhead of a Triton kernel launch when the
    kernel specialization dispatch is simple.

    Usage:
        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
        @triton.jit
        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
            ...

        # Invoke normally
        my_kernel[grid](x, y, BLOCK_SIZE=1024)

    Args:
        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
                The key can be a single value or a tuple of values.

    Returns:
        A decorator that takes the jitted kernel and returns a ``CachedKernel``.

    Note: Kernels with default parameter values are not supported and will raise an assertion error.
    """

    def wrap_kernel(jit_fn):
        # key_fn is captured from the enclosing call and shared by the wrapper.
        wrapped = CachedKernel(jit_fn, key_fn)
        return wrapped

    return wrap_kernel
|
||||
|
||||
Reference in New Issue
Block a user