Fix CI TestChunkedSGMV (#10737)
This commit is contained in:
@@ -621,6 +621,12 @@ class CachedKernel:
|
||||
|
||||
return complete_args
|
||||
|
||||
def _clear_cache(self):
|
||||
"""
|
||||
Clear the kernel cache for testing purposes.
|
||||
"""
|
||||
self.kernel_cache.clear()
|
||||
|
||||
|
||||
def cached_triton_kernel(key_fn=None):
|
||||
"""
|
||||
|
||||
@@ -10,11 +10,18 @@ from sglang.srt.lora.triton_ops import (
|
||||
chunked_sgmv_lora_expand_forward,
|
||||
chunked_sgmv_lora_shrink_forward,
|
||||
)
|
||||
from sglang.srt.lora.triton_ops.chunked_sgmv_expand import _chunked_lora_expand_kernel
|
||||
from sglang.srt.lora.triton_ops.chunked_sgmv_shrink import _chunked_lora_shrink_kernel
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
|
||||
CHUNK_SIZE = 16
|
||||
|
||||
|
||||
def reset_kernel_cache():
|
||||
_chunked_lora_shrink_kernel._clear_cache()
|
||||
_chunked_lora_expand_kernel._clear_cache()
|
||||
|
||||
|
||||
def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
"""Matrix multiplication with mixed precision handling for float16"""
|
||||
result = torch.matmul(a.float(), b.float())
|
||||
@@ -436,6 +443,10 @@ class TestChunkedSGMV(unittest.TestCase):
|
||||
List[str],
|
||||
]:
|
||||
"""Create test batch with specified composition and mode"""
|
||||
|
||||
# Reset kernel cache to avoid cross-test contamination
|
||||
reset_kernel_cache()
|
||||
|
||||
seq_lengths = self.generate_sequence_lengths(
|
||||
batch_size, batch_mode, 1, self.max_seq_len
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user