diff --git a/docs/advanced_features/lora.ipynb b/docs/advanced_features/lora.ipynb index 30582a418..3a5e93148 100644 --- a/docs/advanced_features/lora.ipynb +++ b/docs/advanced_features/lora.ipynb @@ -41,6 +41,8 @@ "\n", "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\n", "\n", + "* `--max-lora-chunk-size`: Maximum chunk size for the ChunkedSGMV LoRA backend. Only used when `--lora-backend` is `csgmv`. Choosing a larger value might improve performance. Please tune this value based on your hardware and workload as needed. Defaults to 16.\n", + "\n", "* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n", "\n", "From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to." 
diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py index bec21d601..2c460d7c1 100644 --- a/python/sglang/srt/lora/backend/chunked_backend.py +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -9,6 +9,9 @@ from sglang.srt.lora.triton_ops import ( ) from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import ServerArgs + +MIN_CHUNK_SIZE = 16 class ChunkedSgmvLoRABackend(BaseLoRABackend): @@ -23,17 +26,23 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): name = "csgmv" - def __init__(self, max_loras_per_batch: int, device: torch.device): + def __init__( + self, + max_loras_per_batch: int, + device: torch.device, + server_args: ServerArgs, + ): super().__init__(max_loras_per_batch, device) - self.segment_size = 16 # TODO (lifuhuang): make it configurable? + self.max_chunk_size = server_args.max_lora_chunk_size def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: return chunked_sgmv_lora_shrink_forward( - x, - weights, - self.batch_info, + x=x, + weights=weights, + batch_info=self.batch_info, + num_slices=1, ) def run_lora_b_sgemm( @@ -50,7 +59,7 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): max_slice_size = output_dim return chunked_sgmv_lora_expand_forward( x=x, - lora_weight_b=weights, + weights=weights, batch_info=self.batch_info, slice_offsets=output_offset, max_slice_size=max_slice_size, @@ -75,14 +84,14 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): assert isinstance(qkv_lora_b, torch.Tensor) lora_a_output = chunked_sgmv_lora_shrink_forward( - x, - qkv_lora_a, - self.batch_info, + x=x, + weights=qkv_lora_a, + batch_info=self.batch_info, num_slices=3, ) lora_output = chunked_sgmv_lora_expand_forward( x=lora_a_output, - lora_weight_b=qkv_lora_b, + weights=qkv_lora_b, batch_info=self.batch_info, slice_offsets=output_offset, 
max_slice_size=max_qkv_out_dim, @@ -109,14 +118,14 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): # lora_a_output: (s, 2 * r) lora_a_output = chunked_sgmv_lora_shrink_forward( - x, - gate_up_lora_a, - self.batch_info, + x=x, + weights=gate_up_lora_a, + batch_info=self.batch_info, num_slices=2, ) lora_output = chunked_sgmv_lora_expand_forward( x=lora_a_output, - lora_weight_b=gate_up_lora_b, + weights=gate_up_lora_b, batch_info=self.batch_info, slice_offsets=output_offset, max_slice_size=output_dim, @@ -124,6 +133,33 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): ) return lora_output + def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int: + """ + Heuristically determine the chunk size based on the number of tokens in a batch. + + Args: + forward_batch (ForwardBatch): The batch information containing sequence lengths. + + Returns: + The determined chunk size + """ + + if self.max_chunk_size <= MIN_CHUNK_SIZE: + return MIN_CHUNK_SIZE + + num_tokens = ( + forward_batch.extend_num_tokens + if forward_batch.forward_mode.is_extend() + else forward_batch.batch_size + ) + if num_tokens >= 256: + chunk_size = 128 + elif num_tokens >= 64: + chunk_size = 32 + else: # num_tokens < 64 + chunk_size = 16 + return min(self.max_chunk_size, chunk_size) + def prepare_lora_batch( self, forward_batch: ForwardBatch, @@ -132,12 +168,16 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): scalings: list[float], batch_info: Optional[LoRABatchInfo] = None, ): + chunk_size = self._determine_chunk_size(forward_batch) + permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation( - weight_indices, forward_batch + seq_weight_indices=weight_indices, + forward_batch=forward_batch, ) seg_weight_indices, seg_indptr = self._get_segments_info( - weight_indices_reordered + weights_reordered=weight_indices_reordered, + chunk_size=chunk_size, ) num_segments = len(seg_weight_indices) @@ -152,6 +192,7 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): batch_info = 
LoRABatchInfo( bs=forward_batch.batch_size, num_segments=num_segments, + max_len=chunk_size, use_cuda_graph=False, seg_indptr=torch.empty( (num_segments + 1,), dtype=torch.int32, device=self.device @@ -169,12 +210,12 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): (len(permutation),), dtype=torch.int32, device=self.device ), # Not used in chunked kernels - max_len=None, seg_lens=None, ) else: batch_info.bs = forward_batch.batch_size batch_info.num_segments = num_segments + batch_info.max_len = chunk_size # Copy to device asynchronously batch_info.lora_ranks[: self.max_loras_per_batch].copy_( @@ -241,7 +282,7 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): return permutation, weights_reordered - def _get_segments_info(self, weights_reordered: torch.Tensor): + def _get_segments_info(self, weights_reordered: torch.Tensor, chunk_size: int): """ Computes segment information for chunked SGMV operations. @@ -269,6 +310,7 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): Args: weights_reordered (torch.Tensor): Sorted adapter indices for each token + chunk_size (int): Maximum size of each segment; the last segment of a group may be shorter Returns: tuple: (weight_indices_list, seg_indptr) where: @@ -285,11 +327,11 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend): for weight_idx, group_len in zip(unique_weights, counts): group_len = group_len.item() - num_segs = (group_len + self.segment_size - 1) // self.segment_size + num_segs = (group_len + chunk_size - 1) // chunk_size weight_indices_list.extend([weight_idx.item()] * num_segs) - seg_lens_list.extend([self.segment_size] * (num_segs - 1)) - seg_lens_list.append(group_len - (num_segs - 1) * self.segment_size) + seg_lens_list.extend([chunk_size] * (num_segs - 1)) + seg_lens_list.append(group_len - (num_segs - 1) * chunk_size) seg_lens = torch.tensor(seg_lens_list, dtype=torch.int32) diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index 7abeef770..f99e2c006 100644 --- 
a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -11,12 +11,18 @@ from sglang.srt.lora.triton_ops import ( ) from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import ServerArgs class TritonLoRABackend(BaseLoRABackend): name = "triton" - def __init__(self, max_loras_per_batch: int, device: torch.device): + def __init__( + self, + max_loras_per_batch: int, + device: torch.device, + **kwargs, + ): super().__init__(max_loras_per_batch, device) def run_lora_a_sgemm( @@ -30,7 +36,7 @@ class TritonLoRABackend(BaseLoRABackend): weights: torch.Tensor, base_output: torch.Tensor = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output) @@ -43,7 +49,7 @@ class TritonLoRABackend(BaseLoRABackend): max_qkv_out_dim: int, base_output: torch.Tensor = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: # x: (s, input_dim) @@ -69,7 +75,7 @@ class TritonLoRABackend(BaseLoRABackend): gate_up_lora_b: torch.Tensor, base_output: torch.Tensor = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: # x: (s, input_dim) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index d288323bf..cabc8cb3b 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -37,6 +37,7 @@ from sglang.srt.lora.utils import ( ) from sglang.srt.managers.io_struct import LoRAUpdateResult from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import replace_submodule logger = logging.getLogger(__name__) @@ -56,6 +57,7 @@ class LoRAManager: max_lora_rank: Optional[int] = None, target_modules: Optional[Iterable[str]] = None, lora_paths: Optional[List[LoRARef]] = None, + server_args: Optional[ServerArgs] = None, ): self.base_model: 
torch.nn.Module = base_model self.base_hf_config: AutoConfig = base_hf_config @@ -72,6 +74,7 @@ class LoRAManager: self.lora_backend: BaseLoRABackend = backend_type( max_loras_per_batch=max_loras_per_batch, device=self.device, + server_args=server_args, ) # Initialize mutable internal state of the LoRAManager. diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py index 7ea00e568..52ace0dae 100644 --- a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py @@ -13,15 +13,6 @@ def _chunked_lora_expand_kernel( x, weights, output, - # Parameters of size - # Strides - x_stride_0, - x_stride_1, - w_stride_0, - w_stride_1, - w_stride_2, - output_stride_0, - output_stride_1, # Information on sequence lengths and weight id seg_indptr, weight_indices, @@ -34,8 +25,9 @@ def _chunked_lora_expand_kernel( slice_offsets, # Meta parameters NUM_SLICES: tl.constexpr, + OUTPUT_DIM: tl.constexpr, MAX_RANK: tl.constexpr, # K = R - BLOCK_S: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): @@ -57,6 +49,16 @@ def _chunked_lora_expand_kernel( """ tl.static_assert(NUM_SLICES <= 3) + x_stride_0: tl.constexpr = NUM_SLICES * MAX_RANK + x_stride_1: tl.constexpr = 1 + + w_stride_0: tl.constexpr = OUTPUT_DIM * MAX_RANK + w_stride_1: tl.constexpr = MAX_RANK + w_stride_2: tl.constexpr = 1 + + output_stride_0: tl.constexpr = OUTPUT_DIM + output_stride_1: tl.constexpr = 1 + pid_s = tl.program_id(axis=2) if pid_s >= num_segs: return @@ -83,7 +85,7 @@ def _chunked_lora_expand_kernel( cur_rank = tl.minimum(MAX_RANK, cur_rank) # Map logical sequence index to physical index - s_offset_logical = tl.arange(0, BLOCK_S) + seg_start + s_offset_logical = tl.arange(0, BLOCK_M) + seg_start s_offset_physical = tl.load( permutation + s_offset_logical, mask=s_offset_logical < seg_end ) @@ -105,7 +107,7 @@ def _chunked_lora_expand_kernel( ) 
# Iterate to compute the block in output matrix - partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + partial_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(cur_rank, BLOCK_K)): x_tile = tl.load( x_ptrs, @@ -140,32 +142,37 @@ def _chunked_lora_expand_kernel( def chunked_sgmv_lora_expand_forward( x: torch.Tensor, - lora_weight_b: torch.Tensor, + weights: torch.Tensor, batch_info: LoRABatchInfo, slice_offsets: torch.Tensor, max_slice_size: int, - base_output: torch.Tensor = None, + base_output: Optional[torch.Tensor], ) -> torch.Tensor: # x: (s, slice_num * r) - # lora_weight_b: (num_lora, output_dim, r) + # weights: (num_lora, output_dim, r) # slice_offsets: boundaries for different slices in the output dimension # output: (s, output_dim) # Compute lora_output with shape (s, output_dim) as follows: # For each slice i, accumulates: - # lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], lora_weight_b[:, slice_offsets[i]:slice_offsets[i+1], :]) + # lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], weights[:, slice_offsets[i]:slice_offsets[i+1], :]) + + assert x.is_contiguous() + assert weights.is_contiguous() + assert len(x.shape) == 2 + assert len(weights.shape) == 3 # Get dims - s = x.shape[0] + M = x.shape[0] input_dim = x.shape[1] - max_lora_rank = lora_weight_b.shape[-1] - output_dim = lora_weight_b.shape[-2] + OUTPUT_DIM = weights.shape[1] + MAX_RANK = weights.shape[2] num_slices = len(slice_offsets) - 1 - assert input_dim == num_slices * max_lora_rank + assert input_dim == num_slices * MAX_RANK # TODO (lifuhuang): fine-tune per operation - BLOCK_M = 16 + BLOCK_M = batch_info.max_len BLOCK_K = 16 BLOCK_N = 64 @@ -178,21 +185,14 @@ def chunked_sgmv_lora_expand_forward( ) if base_output is None: - output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype) + output = torch.zeros((M, OUTPUT_DIM), 
device=x.device, dtype=x.dtype) else: output = base_output _chunked_lora_expand_kernel[grid]( x=x, - weights=lora_weight_b, + weights=weights, output=output, - x_stride_0=x.stride(0), - x_stride_1=x.stride(1), - w_stride_0=lora_weight_b.stride(0), - w_stride_1=lora_weight_b.stride(1), - w_stride_2=lora_weight_b.stride(2), - output_stride_0=output.stride(0), - output_stride_1=output.stride(1), seg_indptr=batch_info.seg_indptr, weight_indices=batch_info.weight_indices, lora_ranks=batch_info.lora_ranks, @@ -202,8 +202,9 @@ def chunked_sgmv_lora_expand_forward( slice_offsets=slice_offsets, # constants NUM_SLICES=num_slices, - MAX_RANK=max_lora_rank, - BLOCK_S=BLOCK_M, + OUTPUT_DIM=OUTPUT_DIM, + MAX_RANK=MAX_RANK, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, ) diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py index 90687775b..5091ba09a 100644 --- a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py @@ -11,14 +11,6 @@ def _chunked_lora_shrink_kernel( x, weights, output, - # Strides - x_stride_0, - x_stride_1, - w_stride_0, - w_stride_1, - w_stride_2, - output_stride_0, - output_stride_1, # Information on sequence lengths,ranks and weight id seg_indptr, weight_indices, @@ -29,7 +21,7 @@ def _chunked_lora_shrink_kernel( N: tl.constexpr, # num_slices * r K: tl.constexpr, # input_dim NUM_SLICES: tl.constexpr, - BLOCK_S: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): @@ -48,6 +40,16 @@ def _chunked_lora_shrink_kernel( with shape `(num_lora, N, K)` where N = num_slices * r. output (torch.Tensor): The output tensor of shape `(s, N)`. 
""" + x_stride_1: tl.constexpr = 1 + x_stride_0: tl.constexpr = K + + w_stride_0: tl.constexpr = N * K + w_stride_1: tl.constexpr = K + w_stride_2: tl.constexpr = 1 + + output_stride_0: tl.constexpr = N + output_stride_1: tl.constexpr = 1 + pid_s = tl.program_id(1) if pid_s >= num_segs: return @@ -70,7 +72,7 @@ def _chunked_lora_shrink_kernel( cur_n = tl.minimum(N, rank * NUM_SLICES) # Map logical sequence index to physical index - s_offset_logical = tl.arange(0, BLOCK_S) + seg_start + s_offset_logical = tl.arange(0, BLOCK_M) + seg_start s_offset_physical = tl.load( permutation + s_offset_logical, mask=s_offset_logical < seg_end ) @@ -85,7 +87,7 @@ def _chunked_lora_shrink_kernel( ) # Iterate to compute the block in output matrix - partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + partial_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K)): x_tile = tl.load( x_ptrs, @@ -117,7 +119,7 @@ def chunked_sgmv_lora_shrink_forward( x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo, - num_slices: int = 1, + num_slices: int, ) -> torch.Tensor: # x: (s, input_dim) # weights: (num_lora, num_slices * r, input_dim) @@ -133,7 +135,7 @@ def chunked_sgmv_lora_shrink_forward( # Block shapes # TODO (lifuhuang): experiment with split-k - BLOCK_S = 16 + BLOCK_M = batch_info.max_len BLOCK_N = 16 BLOCK_K = 256 @@ -153,13 +155,6 @@ def chunked_sgmv_lora_shrink_forward( x=x, weights=weights, output=output, - x_stride_0=x.stride(0), - x_stride_1=x.stride(1), - w_stride_0=weights.stride(0), - w_stride_1=weights.stride(1), - w_stride_2=weights.stride(2), - output_stride_0=output.stride(0), - output_stride_1=output.stride(1), seg_indptr=batch_info.seg_indptr, weight_indices=batch_info.weight_indices, lora_ranks=batch_info.lora_ranks, @@ -169,7 +164,7 @@ def chunked_sgmv_lora_shrink_forward( N=N, K=K, NUM_SLICES=num_slices, - BLOCK_S=BLOCK_S, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, ) diff --git 
a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 459c943b7..486e9b918 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -19,6 +19,9 @@ class LoRABatchInfo: # Number of segments. For triton backend, it is equal to batch size. num_segments: int + # Maximum segment length of current batch + max_len: int + # Indice pointers of each segment in shape (num_segments + 1, ) seg_indptr: torch.Tensor @@ -34,9 +37,6 @@ class LoRABatchInfo: # Lengths of each segments in shape (num_segments,) seg_lens: Optional[torch.Tensor] - # Maximum segment length of current batch - max_len: Optional[int] - # The logical (re)ordering of input rows (tokens), in shape (num_tokens,) permutation: Optional[torch.Tensor] diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 210b21349..e517ddad4 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1195,6 +1195,7 @@ class ModelRunner: max_lora_rank=self.server_args.max_lora_rank, target_modules=self.server_args.lora_target_modules, lora_paths=self.server_args.lora_paths, + server_args=self.server_args, ) def load_lora_adapter(self, lora_ref: LoRARef): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 98c0348e4..f92e3e1a6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -268,6 +268,7 @@ class ServerArgs: max_loaded_loras: Optional[int] = None max_loras_per_batch: int = 8 lora_backend: str = "triton" + max_lora_chunk_size: Optional[int] = 16 # Kernel backend attention_backend: Optional[str] = None @@ -1779,6 +1780,13 @@ class ServerArgs: default=ServerArgs.lora_backend, help="Choose the kernel backend for multi-LoRA serving.", ) + parser.add_argument( + "--max-lora-chunk-size", + type=int, + default=ServerArgs.max_lora_chunk_size, + choices=[16, 32, 64, 128], + help="Maximum chunk size 
for the ChunkedSGMV LoRA backend. Only used when --lora-backend is 'csgmv'. Choosing a larger value might improve performance.", + ) # Kernel backend parser.add_argument( @@ -2779,6 +2787,12 @@ class ServerArgs: f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}" ) + if self.max_lora_chunk_size is not None: + assert ( + 16 <= self.max_lora_chunk_size <= 128 + and (self.max_lora_chunk_size & (self.max_lora_chunk_size - 1)) == 0 + ), "--max-lora-chunk-size must be a power of 2 between 16 and 128." + def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int): larger_tp = max(decode_tp, prefill_tp) smaller_tp = min(decode_tp, prefill_tp) diff --git a/test/srt/lora/test_chunked_sgmv_backend.py b/test/srt/lora/test_chunked_sgmv_backend.py index 051f8e08d..6df369f81 100644 --- a/test/srt/lora/test_chunked_sgmv_backend.py +++ b/test/srt/lora/test_chunked_sgmv_backend.py @@ -12,6 +12,8 @@ from sglang.srt.lora.triton_ops import ( ) from sglang.srt.lora.utils import LoRABatchInfo +CHUNK_SIZE = 16 + def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """Matrix multiplication with mixed precision handling for float16""" @@ -343,9 +345,15 @@ class TestChunkedSGMV(unittest.TestCase): ) # Create a minimal backend instance to access _get_segments_info - mock_backend = ChunkedSgmvLoRABackend(max_loras_per_batch=8, device=self.device) + mock_server_args = type( + "ServerArgs", (object,), {"max_lora_chunk_size": "MOCK_NEVER_USED"} + ) + mock_backend = ChunkedSgmvLoRABackend( + max_loras_per_batch=8, device=self.device, server_args=mock_server_args + ) weight_indices_list, seg_indptr = mock_backend._get_segments_info( - weights_reordered + weights_reordered, + chunk_size=CHUNK_SIZE, ) scalings = [1.0] * len(unique_loras) @@ -377,7 +385,7 @@ class TestChunkedSGMV(unittest.TestCase): lora_ranks=lora_ranks_tensor, scalings=scalings_tensor, seg_lens=seq_lens_tensor, # Original sequence lengths for reference - max_len=max(seq_lengths) if 
seq_lengths else 0, + max_len=CHUNK_SIZE, permutation=permutation_tensor, # Token reordering permutation ) @@ -515,6 +523,7 @@ class TestChunkedSGMV(unittest.TestCase): batch_info, self.slice_offsets, self.max_slice_size, + base_output=None, ) reference_expand = reference_sgmv_expand( reference_shrink, @@ -594,6 +603,7 @@ class TestChunkedSGMV(unittest.TestCase): batch_info, self.slice_offsets, self.max_slice_size, + base_output=None, ) reference_expand = reference_sgmv_expand( intermediate,