[3/4] Speed up CSGMV backend perf by 10% through dynamic chunking + kernel optimization (#10592)

This commit is contained in:
Lifu Huang
2025-09-20 22:47:48 -07:00
committed by GitHub
parent 720c1c8ca3
commit 08ecd0aa2a
10 changed files with 158 additions and 84 deletions

View File

@@ -12,6 +12,8 @@ from sglang.srt.lora.triton_ops import (
)
from sglang.srt.lora.utils import LoRABatchInfo
CHUNK_SIZE = 16
def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Matrix multiplication with mixed precision handling for float16"""
@@ -343,9 +345,15 @@ class TestChunkedSGMV(unittest.TestCase):
)
# Create a minimal backend instance to access _get_segments_info
mock_backend = ChunkedSgmvLoRABackend(max_loras_per_batch=8, device=self.device)
mock_server_args = type(
"ServerArgs", (object,), {"max_lora_chunk_size": "MOCK_NEVER_USED"}
)
mock_backend = ChunkedSgmvLoRABackend(
max_loras_per_batch=8, device=self.device, server_args=mock_server_args
)
weight_indices_list, seg_indptr = mock_backend._get_segments_info(
weights_reordered
weights_reordered,
chunk_size=CHUNK_SIZE,
)
scalings = [1.0] * len(unique_loras)
@@ -377,7 +385,7 @@ class TestChunkedSGMV(unittest.TestCase):
lora_ranks=lora_ranks_tensor,
scalings=scalings_tensor,
seg_lens=seq_lens_tensor, # Original sequence lengths for reference
max_len=max(seq_lengths) if seq_lengths else 0,
max_len=CHUNK_SIZE,
permutation=permutation_tensor, # Token reordering permutation
)
@@ -515,6 +523,7 @@ class TestChunkedSGMV(unittest.TestCase):
batch_info,
self.slice_offsets,
self.max_slice_size,
base_output=None,
)
reference_expand = reference_sgmv_expand(
reference_shrink,
@@ -594,6 +603,7 @@ class TestChunkedSGMV(unittest.TestCase):
batch_info,
self.slice_offsets,
self.max_slice_size,
base_output=None,
)
reference_expand = reference_sgmv_expand(
intermediate,