Fix CI TestChunkedSGMV (#10737)

This commit is contained in:
Lifu Huang
2025-09-22 01:09:58 -07:00
committed by GitHub
parent 70e4b21853
commit 2101d93b4f
2 changed files with 17 additions and 0 deletions

View File

@@ -621,6 +621,12 @@ class CachedKernel:
return complete_args
def _clear_cache(self):
    """
    Drop every cached kernel entry, forcing later lookups to rebuild.

    Exposed so tests can reset state between cases and avoid
    cross-test contamination of the kernel cache.
    """
    # Clear in place so any external references to the cache dict stay valid.
    self.kernel_cache.clear()
def cached_triton_kernel(key_fn=None):
"""

View File

@@ -10,11 +10,18 @@ from sglang.srt.lora.triton_ops import (
chunked_sgmv_lora_expand_forward,
chunked_sgmv_lora_shrink_forward,
)
from sglang.srt.lora.triton_ops.chunked_sgmv_expand import _chunked_lora_expand_kernel
from sglang.srt.lora.triton_ops.chunked_sgmv_shrink import _chunked_lora_shrink_kernel
from sglang.srt.lora.utils import LoRABatchInfo
# Chunk size used by the chunked SGMV tests in this module.
# NOTE(review): presumably must match the chunking the Triton kernels
# expect — confirm against chunked_sgmv_shrink/expand before changing.
CHUNK_SIZE = 16
def reset_kernel_cache():
    """Clear both chunked-SGMV Triton kernel caches (shrink, then expand).

    Used between test cases so a kernel cached by one test cannot leak
    into the next.
    """
    for kernel in (_chunked_lora_shrink_kernel, _chunked_lora_expand_kernel):
        kernel._clear_cache()
def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Matrix multiplication with mixed precision handling for float16"""
result = torch.matmul(a.float(), b.float())
@@ -436,6 +443,10 @@ class TestChunkedSGMV(unittest.TestCase):
List[str],
]:
"""Create test batch with specified composition and mode"""
# Reset kernel cache to avoid cross-test contamination
reset_kernel_cache()
seq_lengths = self.generate_sequence_lengths(
batch_size, batch_mode, 1, self.max_seq_len
)