Fix CI TestChunkedSGMV (#10737)
This commit is contained in:
@@ -621,6 +621,12 @@ class CachedKernel:
|
|||||||
|
|
||||||
return complete_args
|
return complete_args
|
||||||
|
|
||||||
|
def _clear_cache(self):
|
||||||
|
"""
|
||||||
|
Clear the kernel cache for testing purposes.
|
||||||
|
"""
|
||||||
|
self.kernel_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
def cached_triton_kernel(key_fn=None):
|
def cached_triton_kernel(key_fn=None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -10,11 +10,18 @@ from sglang.srt.lora.triton_ops import (
|
|||||||
chunked_sgmv_lora_expand_forward,
|
chunked_sgmv_lora_expand_forward,
|
||||||
chunked_sgmv_lora_shrink_forward,
|
chunked_sgmv_lora_shrink_forward,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.lora.triton_ops.chunked_sgmv_expand import _chunked_lora_expand_kernel
|
||||||
|
from sglang.srt.lora.triton_ops.chunked_sgmv_shrink import _chunked_lora_shrink_kernel
|
||||||
from sglang.srt.lora.utils import LoRABatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
CHUNK_SIZE = 16
|
CHUNK_SIZE = 16
|
||||||
|
|
||||||
|
|
||||||
|
def reset_kernel_cache():
|
||||||
|
_chunked_lora_shrink_kernel._clear_cache()
|
||||||
|
_chunked_lora_expand_kernel._clear_cache()
|
||||||
|
|
||||||
|
|
||||||
def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||||
"""Matrix multiplication with mixed precision handling for float16"""
|
"""Matrix multiplication with mixed precision handling for float16"""
|
||||||
result = torch.matmul(a.float(), b.float())
|
result = torch.matmul(a.float(), b.float())
|
||||||
@@ -436,6 +443,10 @@ class TestChunkedSGMV(unittest.TestCase):
|
|||||||
List[str],
|
List[str],
|
||||||
]:
|
]:
|
||||||
"""Create test batch with specified composition and mode"""
|
"""Create test batch with specified composition and mode"""
|
||||||
|
|
||||||
|
# Reset kernel cache to avoid cross-test contamination
|
||||||
|
reset_kernel_cache()
|
||||||
|
|
||||||
seq_lengths = self.generate_sequence_lengths(
|
seq_lengths = self.generate_sequence_lengths(
|
||||||
batch_size, batch_mode, 1, self.max_seq_len
|
batch_size, batch_mode, 1, self.max_seq_len
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user