diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 91c3454a1..07d906440 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -621,6 +621,12 @@ class CachedKernel:
 
         return complete_args
 
+    def _clear_cache(self):
+        """
+        Clear the kernel cache for testing purposes.
+        """
+        self.kernel_cache.clear()
+
 
 def cached_triton_kernel(key_fn=None):
     """
diff --git a/test/srt/lora/test_chunked_sgmv_backend.py b/test/srt/lora/test_chunked_sgmv_backend.py
index 6df369f81..2cfde12db 100644
--- a/test/srt/lora/test_chunked_sgmv_backend.py
+++ b/test/srt/lora/test_chunked_sgmv_backend.py
@@ -10,11 +10,18 @@ from sglang.srt.lora.triton_ops import (
     chunked_sgmv_lora_expand_forward,
     chunked_sgmv_lora_shrink_forward,
 )
+from sglang.srt.lora.triton_ops.chunked_sgmv_expand import _chunked_lora_expand_kernel
+from sglang.srt.lora.triton_ops.chunked_sgmv_shrink import _chunked_lora_shrink_kernel
 from sglang.srt.lora.utils import LoRABatchInfo
 
 CHUNK_SIZE = 16
 
 
+def reset_kernel_cache():
+    _chunked_lora_shrink_kernel._clear_cache()
+    _chunked_lora_expand_kernel._clear_cache()
+
+
 def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     """Matrix multiplication with mixed precision handling for float16"""
     result = torch.matmul(a.float(), b.float())
@@ -436,6 +443,10 @@ class TestChunkedSGMV(unittest.TestCase):
         List[str],
     ]:
         """Create test batch with specified composition and mode"""
+
+        # Reset kernel cache to avoid cross-test contamination
+        reset_kernel_cache()
+
         seq_lengths = self.generate_sequence_lengths(
             batch_size, batch_mode, 1, self.max_seq_len
         )