[4/4] Introduce CachedKernel to reduce CSGMV kernel launch overheads by 60% (#10709)
This commit is contained in:
@@ -5,8 +5,10 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.lora.utils import LoRABatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
from sglang.utils import cached_triton_kernel
|
||||||
|
|
||||||
|
|
||||||
|
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _chunked_lora_expand_kernel(
|
def _chunked_lora_expand_kernel(
|
||||||
# Pointers to matrices
|
# Pointers to matrices
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.lora.utils import LoRABatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
from sglang.utils import cached_triton_kernel
|
||||||
|
|
||||||
|
|
||||||
|
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _chunked_lora_shrink_kernel(
|
def _chunked_lora_shrink_kernel(
|
||||||
# Pointers to matrices
|
# Pointers to matrices
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Common utilities"""
|
"""Common utilities"""
|
||||||
|
|
||||||
|
import functools
|
||||||
import importlib
|
import importlib
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -21,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pybase64
|
import pybase64
|
||||||
import requests
|
import requests
|
||||||
|
import triton
|
||||||
from IPython.display import HTML, display
|
from IPython.display import HTML, display
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -540,3 +543,114 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
|
|||||||
module_name, obj_name = qualname.rsplit(".", 1)
|
module_name, obj_name = qualname.rsplit(".", 1)
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
return getattr(module, obj_name)
|
return getattr(module, obj_name)
|
||||||
|
|
||||||
|
|
||||||
|
class CachedKernel:
|
||||||
|
"""
|
||||||
|
Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
|
||||||
|
|
||||||
|
This wrapper caches compiled Triton kernels based on keys extracted by a
|
||||||
|
user-provided key function to avoid redundant compilations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, fn, key_fn=None):
|
||||||
|
self.fn = fn
|
||||||
|
assert isinstance(fn, triton.runtime.jit.JITFunction)
|
||||||
|
|
||||||
|
original_fn = fn.fn
|
||||||
|
self.signature = inspect.signature(original_fn)
|
||||||
|
self.param_names = tuple(self.signature.parameters.keys())
|
||||||
|
self.num_args = len(self.param_names)
|
||||||
|
|
||||||
|
# Check that no parameters have default values
|
||||||
|
for name, param in self.signature.parameters.items():
|
||||||
|
assert (
|
||||||
|
param.default is inspect.Parameter.empty
|
||||||
|
), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
|
||||||
|
|
||||||
|
functools.update_wrapper(self, original_fn)
|
||||||
|
self.kernel_cache = {}
|
||||||
|
|
||||||
|
# Store the key function
|
||||||
|
self.key_fn = key_fn
|
||||||
|
|
||||||
|
def __getitem__(self, grid):
|
||||||
|
"""
|
||||||
|
Index with grid to get a launcher function.
|
||||||
|
Returns a launcher that will handle caching based on the key function.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
isinstance(grid, tuple) and len(grid) <= 3
|
||||||
|
), "Grid must be a tuple with at most 3 dimensions."
|
||||||
|
|
||||||
|
# Normalize grid once
|
||||||
|
if len(grid) < 3:
|
||||||
|
grid = grid + (1,) * (3 - len(grid))
|
||||||
|
|
||||||
|
def launcher(*args, **kwargs):
|
||||||
|
cache_key = self.key_fn(args, kwargs)
|
||||||
|
|
||||||
|
cached_kernel = self.kernel_cache.get(cache_key)
|
||||||
|
|
||||||
|
if cached_kernel is None:
|
||||||
|
# First time: compile and cache the kernel
|
||||||
|
cached_kernel = self.fn[grid](*args, **kwargs)
|
||||||
|
self.kernel_cache[cache_key] = cached_kernel
|
||||||
|
return cached_kernel
|
||||||
|
else:
|
||||||
|
# Use cached kernel
|
||||||
|
all_args = self._build_args(args, kwargs)
|
||||||
|
cached_kernel[grid](*all_args)
|
||||||
|
return cached_kernel
|
||||||
|
|
||||||
|
return launcher
|
||||||
|
|
||||||
|
def _build_args(self, args, kwargs):
|
||||||
|
"""
|
||||||
|
Build the complete argument list for kernel invocation.
|
||||||
|
"""
|
||||||
|
complete_args = list(args)
|
||||||
|
|
||||||
|
for i in range(len(args), self.num_args):
|
||||||
|
name = self.param_names[i]
|
||||||
|
value = kwargs.get(name, inspect.Parameter.empty)
|
||||||
|
if value is not inspect.Parameter.empty:
|
||||||
|
complete_args.append(value)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Missing argument: {name}")
|
||||||
|
|
||||||
|
return complete_args
|
||||||
|
|
||||||
|
|
||||||
|
def cached_triton_kernel(key_fn=None):
|
||||||
|
"""
|
||||||
|
Decorator that enables key-based caching for Triton kernels using a key function.
|
||||||
|
|
||||||
|
It essentially bypasses Triton's built-in caching mechanism, allowing users to
|
||||||
|
define their own caching strategy based on kernel parameters. This helps reduce
|
||||||
|
the heavy overheads of Triton kernel launch when the kernel specialization dispatch
|
||||||
|
is simple.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
|
||||||
|
@triton.jit
|
||||||
|
def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
|
||||||
|
...
|
||||||
|
|
||||||
|
# Invoke normally
|
||||||
|
my_kernel[grid](x, y, BLOCK_SIZE=1024)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_fn: A function that takes (args, kwargs) and returns the cache key(s).
|
||||||
|
The key can be a single value or a tuple of values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A decorator that wraps the kernel with caching functionality.
|
||||||
|
|
||||||
|
Note: Kernels with default parameter values are not supported and will raise an assertion error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(fn):
|
||||||
|
return CachedKernel(fn, key_fn)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|||||||
Reference in New Issue
Block a user