diff --git a/docs/advanced_features/lora.ipynb b/docs/advanced_features/lora.ipynb index 1925baffc..30582a418 100644 --- a/docs/advanced_features/lora.ipynb +++ b/docs/advanced_features/lora.ipynb @@ -35,7 +35,7 @@ "\n", "* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n", "\n", - "* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we only support Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n", + "* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we support Triton LoRA backend (`triton`) and Chunked SGMV backend (`csgmv`). In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n", "\n", "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. 
This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n", "\n", @@ -79,7 +79,7 @@ "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", " --enable-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", - " --max-loras-per-batch 1 --lora-backend triton \\\n", + " --max-loras-per-batch 1 \\\n", " --log-level warning \\\n", "\"\"\"\n", ")\n", @@ -139,7 +139,7 @@ " --enable-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n", - " --max-loras-per-batch 2 --lora-backend triton \\\n", + " --max-loras-per-batch 2 \\\n", " --log-level warning \\\n", "\"\"\"\n", ")\n", @@ -214,7 +214,7 @@ " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", " --enable-lora \\\n", " --cuda-graph-max-bs 2 \\\n", - " --max-loras-per-batch 2 --lora-backend triton \\\n", + " --max-loras-per-batch 2 \\\n", " --max-lora-rank 256\n", " --lora-target-modules all\n", " --log-level warning\n", @@ -413,7 +413,7 @@ " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", " --enable-lora \\\n", " --cuda-graph-max-bs 8 \\\n", - " --max-loras-per-batch 3 --lora-backend triton \\\n", + " --max-loras-per-batch 3 \\\n", " --max-lora-rank 256 \\\n", " --lora-target-modules all \\\n", " --lora-paths \\\n", @@ -501,6 +501,48 @@ "terminate_process(server_process)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Choosing LoRA Backend\n", + "\n", + "SGLang supports two LoRA backends that you can choose from using the `--lora-backend` argument:\n", + "\n", + "- `triton`: Default basic Triton-based backend.\n", + "- `csgmv`: Chunked SGMV backend optimized for high concurrency scenarios.\n", + "\n", + "The `csgmv` backend was recently introduced to improve performance especially at 
high-concurrency scenarios. Our benchmark shows that it achieves 20% to 80% latency improvements over the basic triton backend.\n", + "Currently it is in the preview phase; we expect to make it the default LoRA backend in a future release. Before that, you can adopt it by manually setting the `--lora-backend` server config." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "server_process, port = launch_server_cmd(\n", + " \"\"\"\n", + " python3 -m sglang.launch_server \\\n", + " --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", + " --enable-lora \\\n", + " --lora-backend csgmv \\\n", + " --max-loras-per-batch 16 \\\n", + " --lora-paths lora1=path/to/lora1 lora2=path/to/lora2\n", + " \"\"\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index 7c2c232d5..4d241f931 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -143,10 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend: from sglang.srt.lora.backend.triton_backend import TritonLoRABackend return TritonLoRABackend - # elif name == "csgmv": - # from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend + elif name == "csgmv": + from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend - # return ChunkedSgmvLoRABackend + return ChunkedSgmvLoRABackend elif name == "flashinfer": raise ValueError( "FlashInfer LoRA backend has been deprecated, please use `triton` instead." 
diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py new file mode 100644 index 000000000..bec21d601 --- /dev/null +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -0,0 +1,306 @@ +from typing import Optional + +import torch + +from sglang.srt.lora.backend.base_backend import BaseLoRABackend +from sglang.srt.lora.triton_ops import ( + chunked_sgmv_lora_expand_forward, + chunked_sgmv_lora_shrink_forward, +) +from sglang.srt.lora.utils import LoRABatchInfo +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + +class ChunkedSgmvLoRABackend(BaseLoRABackend): + """ + Chunked LoRA backend using segmented matrix-vector multiplication. + + This backend is largely based on the SGMV (Segmented Gather Matrix-Vector multiplication) algorithm + introduced in the Punica paper (https://arxiv.org/pdf/2310.18547). One main variation made here is to + segment the input sequences into fixed-size chunks, which reduces excessive kernel launches especially + when the LoRA distribution is skewed. + """ + + name = "csgmv" + + def __init__(self, max_loras_per_batch: int, device: torch.device): + super().__init__(max_loras_per_batch, device) + self.segment_size = 16 # TODO (lifuhuang): make it configurable? 
+ + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + return chunked_sgmv_lora_shrink_forward( + x, + weights, + self.batch_info, + ) + + def run_lora_b_sgemm( + self, + x: torch.Tensor, + weights: torch.Tensor, + output_offset: torch.Tensor, + base_output: torch.Tensor = None, + *args, + **kwargs + ) -> torch.Tensor: + # For simple lora B, we use slice offsets [0, output_dim] + output_dim = weights.shape[-2] + max_slice_size = output_dim + return chunked_sgmv_lora_expand_forward( + x=x, + lora_weight_b=weights, + batch_info=self.batch_info, + slice_offsets=output_offset, + max_slice_size=max_slice_size, + base_output=base_output, + ) + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: torch.Tensor, + output_offset: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor = None, + *args, + **kwargs + ) -> torch.Tensor: + + # x: (s, input_dim) + # qkv_lora_a: (num_lora, 3 * r, input_dim) + # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r) + assert isinstance(qkv_lora_b, torch.Tensor) + + lora_a_output = chunked_sgmv_lora_shrink_forward( + x, + qkv_lora_a, + self.batch_info, + num_slices=3, + ) + lora_output = chunked_sgmv_lora_expand_forward( + x=lora_a_output, + lora_weight_b=qkv_lora_b, + batch_info=self.batch_info, + slice_offsets=output_offset, + max_slice_size=max_qkv_out_dim, + base_output=base_output, + ) + return lora_output + + def run_gate_up_lora( + self, + x: torch.Tensor, + gate_up_lora_a: torch.Tensor, + gate_up_lora_b: torch.Tensor, + output_offset: torch.Tensor, + base_output: torch.Tensor = None, + *args, + **kwargs + ) -> torch.Tensor: + + # x: (s, input_dim) + # gate_up_lora_a: (num_lora, 2 * r, input_dim) + # gate_up_lora_b: (num_lora, 2 * output_dim, r) + assert isinstance(gate_up_lora_b, torch.Tensor) + output_dim = gate_up_lora_b.shape[-2] // 2 + + # lora_a_output: (s, 2 * r) + lora_a_output = 
chunked_sgmv_lora_shrink_forward( + x, + gate_up_lora_a, + self.batch_info, + num_slices=2, + ) + lora_output = chunked_sgmv_lora_expand_forward( + x=lora_a_output, + lora_weight_b=gate_up_lora_b, + batch_info=self.batch_info, + slice_offsets=output_offset, + max_slice_size=output_dim, + base_output=base_output, + ) + return lora_output + + def prepare_lora_batch( + self, + forward_batch: ForwardBatch, + weight_indices: list[int], + lora_ranks: list[int], + scalings: list[float], + batch_info: Optional[LoRABatchInfo] = None, + ): + permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation( + weight_indices, forward_batch + ) + + seg_weight_indices, seg_indptr = self._get_segments_info( + weight_indices_reordered + ) + num_segments = len(seg_weight_indices) + + lora_ranks_tensor = torch.tensor( + lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu" + ) + scalings_tensor = torch.tensor( + scalings, dtype=torch.float, pin_memory=True, device="cpu" + ) + + if batch_info is None: + batch_info = LoRABatchInfo( + bs=forward_batch.batch_size, + num_segments=num_segments, + use_cuda_graph=False, + seg_indptr=torch.empty( + (num_segments + 1,), dtype=torch.int32, device=self.device + ), + weight_indices=torch.empty( + (num_segments,), dtype=torch.int32, device=self.device + ), + lora_ranks=torch.empty( + (self.max_loras_per_batch,), dtype=torch.int32, device=self.device + ), + scalings=torch.empty( + (self.max_loras_per_batch,), dtype=torch.float, device=self.device + ), + permutation=torch.empty( + (len(permutation),), dtype=torch.int32, device=self.device + ), + # Not used in chunked kernels + max_len=None, + seg_lens=None, + ) + else: + batch_info.bs = forward_batch.batch_size + batch_info.num_segments = num_segments + + # Copy to device asynchronously + batch_info.lora_ranks[: self.max_loras_per_batch].copy_( + lora_ranks_tensor, non_blocking=True + ) + batch_info.scalings[: self.max_loras_per_batch].copy_( + scalings_tensor, 
non_blocking=True + ) + batch_info.weight_indices[:num_segments].copy_( + seg_weight_indices, non_blocking=True + ) + batch_info.seg_indptr[: num_segments + 1].copy_(seg_indptr, non_blocking=True) + batch_info.permutation[: len(permutation)].copy_(permutation, non_blocking=True) + + self.batch_info = batch_info + + @staticmethod + def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch): + """ + Computes permutation indices for reordering tokens by their LoRA adapter assignments. + + This function implements the "gather" step in Chunked Segmented Gather Matrix Vector + multiplication by creating a permutation that groups tokens by their LoRA adapter. + Tokens using the same LoRA adapter are placed together to enable efficient batched + computation. + + Example: + seq_weight_indices = [0, 1, 0] # 3 sequences using adapters [0, 1, 0] + extend_seq_lens = [2, 1, 3] # sequence lengths [2, 1, 3 tokens] + + # Creates row_weight_indices: [0, 0, 1, 0, 0, 0] (6 tokens total) + # Returns permutation: [0, 1, 3, 4, 5, 2] (groups adapter 0 tokens together) + # weights_reordered: [0, 0, 0, 0, 0, 1] (sorted by adapter) + + Args: + seq_weight_indices: List of LoRA adapter indices for each sequence + forward_batch (ForwardBatch): Batch information containing sequence lengths + + Returns: + tuple: (permutation, weights_reordered) where: + - permutation: Token reordering indices to group by adapter + - weights_reordered: Sorted adapter indices for each token + """ + with torch.device("cpu"): + seq_weight_indices = torch.tensor(seq_weight_indices, dtype=torch.int32) + + seg_lens_cpu = ( + torch.tensor( + forward_batch.extend_seq_lens_cpu, + dtype=torch.int32, + ) + if forward_batch.forward_mode.is_extend() + else torch.ones(forward_batch.batch_size, dtype=torch.int32) + ) + + row_weight_indices = torch.repeat_interleave( + seq_weight_indices, seg_lens_cpu + ) + permutation = torch.empty( + (len(row_weight_indices),), dtype=torch.long, pin_memory=True + ) + 
torch.argsort(row_weight_indices, stable=True, out=permutation) + weights_reordered = row_weight_indices[permutation] + + return permutation, weights_reordered + + def _get_segments_info(self, weights_reordered: torch.Tensor): + """ + Computes segment information for chunked SGMV operations. + + This function takes the reordered weight indices and creates segments of fixed size + (self.segment_size) for efficient kernel execution. Each segment contains tokens + that use the same LoRA adapter, enabling vectorized computation. + + The segmentation is necessary because: + 1. GPU kernels work efficiently on fixed-size blocks + 2. Large groups of tokens using the same adapter are split into manageable chunks + 3. Each segment can be processed independently in parallel + + Example: + weights_reordered = [0, 0, 0, 0, 0, 1] # 5 tokens with adapter 0, 1 with adapter 1 + segment_size = 3 + + # Creates segments: + # Segment 0: tokens 0-2 (adapter 0), length=3 + # Segment 1: tokens 3-4 (adapter 0), length=2 + # Segment 2: token 5 (adapter 1), length=1 + + # Returns: + # weight_indices_list: [0, 0, 1] (adapter for each segment) + # seg_indptr: [0, 3, 5, 6] (cumulative segment boundaries) + + Args: + weights_reordered (torch.Tensor): Sorted adapter indices for each token + + Returns: + tuple: (weight_indices_list, seg_indptr) where: + - weight_indices_list: LoRA adapter index for each segment + - seg_indptr: Cumulative segment boundaries (CSR-style indptr) + """ + with torch.device("cpu"): + unique_weights, counts = torch.unique_consecutive( + weights_reordered, return_counts=True + ) + + weight_indices_list = [] + seg_lens_list = [] + + for weight_idx, group_len in zip(unique_weights, counts): + group_len = group_len.item() + num_segs = (group_len + self.segment_size - 1) // self.segment_size + + weight_indices_list.extend([weight_idx.item()] * num_segs) + seg_lens_list.extend([self.segment_size] * (num_segs - 1)) + seg_lens_list.append(group_len - (num_segs - 1) * 
self.segment_size) + + seg_lens = torch.tensor(seg_lens_list, dtype=torch.int32) + + weight_indices_list = torch.tensor( + weight_indices_list, dtype=torch.int32, pin_memory=True + ) + + seg_indptr = torch.empty( + (len(seg_lens) + 1,), dtype=torch.int32, pin_memory=True + ) + seg_indptr[0] = 0 + seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) + + return weight_indices_list, seg_indptr diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index e7569624c..08d4c296f 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -28,14 +28,15 @@ from torch import nn from sglang.srt.configs.load_config import LoadConfig from sglang.srt.hf_transformers_utils import AutoConfig from sglang.srt.lora.backend.base_backend import BaseLoRABackend - -# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend +from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend from sglang.srt.lora.backend.triton_backend import TritonLoRABackend from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.model_loader.loader import DefaultModelLoader logger = logging.getLogger(__name__) +SUPPORTED_BACKENDS = (TritonLoRABackend, ChunkedSgmvLoRABackend) + class LoRALayer(nn.Module): def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig): @@ -48,6 +49,7 @@ class LoRALayer(nn.Module): class LoRAAdapter(nn.Module): + def __init__( self, uid: str, @@ -159,8 +161,8 @@ class LoRAAdapter(nn.Module): gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") if up_name not in weights: weights[up_name] = torch.zeros_like(weights[weight_name]) - assert isinstance(self.lora_backend, TritonLoRABackend), ( - f"LoRA weight initialization currently only supported for 'triton' backend. 
" + assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), ( + f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}" f"Received backend: {self.lora_backend.name}. Please verify your backend configuration " f"or consider implementing custom initialization logic for other backends." ) diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index da55e8fd5..74a2e84a2 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -1,3 +1,5 @@ +from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward +from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward from .gate_up_lora_b import gate_up_lora_b_fwd from .qkv_lora_b import qkv_lora_b_fwd from .sgemm_lora_a import sgemm_lora_a_fwd @@ -8,4 +10,6 @@ __all__ = [ "qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd", + "chunked_sgmv_lora_shrink_forward", + "chunked_sgmv_lora_expand_forward", ] diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py new file mode 100644 index 000000000..7ea00e568 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py @@ -0,0 +1,211 @@ +from typing import Optional + +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.utils import LoRABatchInfo + + +@triton.jit +def _chunked_lora_expand_kernel( + # Pointers to matrices + x, + weights, + output, + # Parameters of size + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_indptr, + weight_indices, + lora_ranks, + permutation, + num_segs, + # For fused output scaling + scalings, + # Offsets of q/k/v slice on output dimension + slice_offsets, + # Meta parameters + NUM_SLICES: tl.constexpr, + MAX_RANK: 
tl.constexpr, # K = R + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """ + Computes a chunked SGMV for LoRA expand operations. + + When a sequence's rank is 0, the kernel is essentially a no-op, following + the convention in pytorch where the product of two matrices of shape (m, 0) + and (0, n) is an all-zero matrix of shape (m, n). + + Args: + x (Tensor): The input tensor, which is the result of the LoRA A projection. + Shape: (s, num_slices * K), where s is the sum of all sequence lengths in the + batch and K is the maximum LoRA rank. + weights (Tensor): The LoRA B weights for all adapters. + Shape: (num_lora, output_dim, K). + output (Tensor): The output tensor where the result is stored. + Shape: (s, output_dim). + """ + tl.static_assert(NUM_SLICES <= 3) + + pid_s = tl.program_id(axis=2) + if pid_s >= num_segs: + return + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len. + # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v) + w_index = tl.load(weight_indices + pid_s) + cur_rank = tl.load(lora_ranks + w_index) + + # If rank is 0, this kernel is a no-op. 
+ if cur_rank == 0: + return + + seg_start = tl.load(seg_indptr + pid_s) + seg_end = tl.load(seg_indptr + pid_s + 1) + + slice_id = tl.program_id(axis=1) + slice_start = tl.load(slice_offsets + slice_id) + slice_end = tl.load(slice_offsets + slice_id + 1) + + scaling = tl.load(scalings + w_index) + # Adjust K (rank) according to the specific LoRA adapter + cur_rank = tl.minimum(MAX_RANK, cur_rank) + + # Map logical sequence index to physical index + s_offset_logical = tl.arange(0, BLOCK_S) + seg_start + s_offset_physical = tl.load( + permutation + s_offset_logical, mask=s_offset_logical < seg_end + ) + + # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:] + # The pointers will be advanced as we move in the K direction + # and accumulate + pid_n = tl.program_id(axis=0) + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_start + k_offset = tl.arange(0, BLOCK_K) + + x_ptrs = ( + x + + slice_id * cur_rank * x_stride_1 + + (s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1) + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iterate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(cur_rank, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset_logical[:, None] < seg_end) + & (k_offset[None, :] < cur_rank - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < cur_rank - k * BLOCK_K) + & (n_offset[None, :] < slice_end), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = output + ( + s_offset_physical[:, None] * output_stride_0 + + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset_logical[:, None] < 
seg_end) & ( + n_offset[None, :] < slice_end + ) + partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def chunked_sgmv_lora_expand_forward( + x: torch.Tensor, + lora_weight_b: torch.Tensor, + batch_info: LoRABatchInfo, + slice_offsets: torch.Tensor, + max_slice_size: int, + base_output: torch.Tensor = None, +) -> torch.Tensor: + + # x: (s, slice_num * r) + # lora_weight_b: (num_lora, output_dim, r) + # slice_offsets: boundaries for different slices in the output dimension + # output: (s, output_dim) + + # Compute lora_output with shape (s, output_dim) as follows: + # For each slice i, accumulates: + # lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], lora_weight_b[:, slice_offsets[i]:slice_offsets[i+1], :]) + + # Get dims + s = x.shape[0] + input_dim = x.shape[1] + max_lora_rank = lora_weight_b.shape[-1] + output_dim = lora_weight_b.shape[-2] + num_slices = len(slice_offsets) - 1 + assert input_dim == num_slices * max_lora_rank + + # TODO (lifuhuang): fine-tune per operation + BLOCK_M = 16 + BLOCK_K = 16 + BLOCK_N = 64 + + num_segments = batch_info.num_segments + + grid = ( + triton.cdiv(max_slice_size, BLOCK_N), + num_slices, # number of slices in the input/output + batch_info.bs if batch_info.use_cuda_graph else num_segments, + ) + + if base_output is None: + output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype) + else: + output = base_output + + _chunked_lora_expand_kernel[grid]( + x=x, + weights=lora_weight_b, + output=output, + x_stride_0=x.stride(0), + x_stride_1=x.stride(1), + w_stride_0=lora_weight_b.stride(0), + w_stride_1=lora_weight_b.stride(1), + w_stride_2=lora_weight_b.stride(2), + output_stride_0=output.stride(0), + output_stride_1=output.stride(1), + seg_indptr=batch_info.seg_indptr, + weight_indices=batch_info.weight_indices, + lora_ranks=batch_info.lora_ranks, + permutation=batch_info.permutation, + 
num_segs=num_segments, + scalings=batch_info.scalings, + slice_offsets=slice_offsets, + # constants + NUM_SLICES=num_slices, + MAX_RANK=max_lora_rank, + BLOCK_S=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ) + + return output diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py new file mode 100644 index 000000000..90687775b --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py @@ -0,0 +1,177 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.utils import LoRABatchInfo + + +@triton.jit +def _chunked_lora_shrink_kernel( + # Pointers to matrices + x, + weights, + output, + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths,ranks and weight id + seg_indptr, + weight_indices, + lora_ranks, + permutation, + num_segs, + # Meta parameters + N: tl.constexpr, # num_slices * r + K: tl.constexpr, # input_dim + NUM_SLICES: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """ + Computes a chunked SGMV for LoRA shrink operations. + + The kernel ensures that output[seg_start:seg_start + seg_len, :rank * num_slices] + stores the product of the input `x` and the LoRA weights for the corresponding + sequence. This implies that when rank is 0, the kernel is essentially a no-op, + as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty). + + Args: + x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s` + is the sum of all sequence lengths in the batch. + weights (torch.Tensor): The LoRA A weights for all available adapters, + with shape `(num_lora, N, K)` where N = num_slices * r. + output (torch.Tensor): The output tensor of shape `(s, N)`. 
+ """ + pid_s = tl.program_id(1) + if pid_s >= num_segs: + return + + pid_n = tl.program_id(0) + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len + w_index = tl.load(weight_indices + pid_s) + rank = tl.load(lora_ranks + w_index) + + # If rank is 0, this kernel becomes a no-op as the output is always trivially correct. + if rank == 0: + return + + seg_start = tl.load(seg_indptr + pid_s) + seg_end = tl.load(seg_indptr + pid_s + 1) + + # Adjust N dim according to the specific LoRA adapter + cur_n = tl.minimum(N, rank * NUM_SLICES) + + # Map logical sequence index to physical index + s_offset_logical = tl.arange(0, BLOCK_S) + seg_start + s_offset_physical = tl.load( + permutation + s_offset_logical, mask=s_offset_logical < seg_end + ) + + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = x + ( + s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iterate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset_logical[:, None] < seg_end) + & (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < cur_n), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = output + ( + s_offset_physical[:, None] * output_stride_0 + + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset_logical[:, None] < seg_end) & (n_offset[None, :] < cur_n) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def 
chunked_sgmv_lora_shrink_forward( + x: torch.Tensor, + weights: torch.Tensor, + batch_info: LoRABatchInfo, + num_slices: int = 1, +) -> torch.Tensor: + # x: (s, input_dim) + # weights: (num_lora, num_slices * r, input_dim) + # output: (s, num_slices * r) + # num_slices: qkv=3, gate_up=2, others=1 + # when called with multiple slices, the weights.shape[-2] will be num_slices * r + # input_dim is much larger than r + + assert x.is_contiguous() + assert weights.is_contiguous() + assert len(x.shape) == 2 + assert len(weights.shape) == 3 + + # Block shapes + # TODO (lifuhuang): experiment with split-k + BLOCK_S = 16 + BLOCK_N = 16 + BLOCK_K = 256 + + S = x.shape[0] + N = weights.shape[1] + K = weights.shape[2] + assert x.shape[-1] == K + + num_segments = batch_info.num_segments + grid = ( + triton.cdiv(N, BLOCK_N), + batch_info.bs if batch_info.use_cuda_graph else num_segments, + ) + + output = torch.empty((S, N), device=x.device, dtype=x.dtype) + _chunked_lora_shrink_kernel[grid]( + x=x, + weights=weights, + output=output, + x_stride_0=x.stride(0), + x_stride_1=x.stride(1), + w_stride_0=weights.stride(0), + w_stride_1=weights.stride(1), + w_stride_2=weights.stride(2), + output_stride_0=output.stride(0), + output_stride_1=output.stride(1), + seg_indptr=batch_info.seg_indptr, + weight_indices=batch_info.weight_indices, + lora_ranks=batch_info.lora_ranks, + permutation=batch_info.permutation, + num_segs=num_segments, + # constants + N=N, + K=K, + NUM_SLICES=num_slices, + BLOCK_S=BLOCK_S, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ) + + return output diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 511b95b01..0272caba1 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -110,6 +110,8 @@ ATTENTION_BACKEND_CHOICES = [ "ascend", ] +LORA_BACKEND_CHOICES = ["triton", "csgmv"] + DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"] GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", 
"llguidance", "none"] @@ -1601,7 +1603,8 @@ class ServerArgs: parser.add_argument( "--lora-backend", type=str, - default="triton", + choices=LORA_BACKEND_CHOICES, + default=ServerArgs.lora_backend, help="Choose the kernel backend for multi-LoRA serving.", ) diff --git a/test/srt/lora/test_chunked_sgmv_backend.py b/test/srt/lora/test_chunked_sgmv_backend.py new file mode 100644 index 000000000..051f8e08d --- /dev/null +++ b/test/srt/lora/test_chunked_sgmv_backend.py @@ -0,0 +1,740 @@ +import random +import unittest +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import torch + +from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend +from sglang.srt.lora.triton_ops import ( + chunked_sgmv_lora_expand_forward, + chunked_sgmv_lora_shrink_forward, +) +from sglang.srt.lora.utils import LoRABatchInfo + + +def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Matrix multiplication with mixed precision handling for float16""" + result = torch.matmul(a.float(), b.float()) + return result.to(a.dtype) + + +class BatchComposition(Enum): + UNIFORM = "uniform" + MIXED = "mixed" + SKEWED = "skewed" + NONE = "_NO_LORA_" + + +class BatchMode(Enum): + PREFILL = "prefill" + DECODE = "decode" + + +def reference_sgmv_shrink( + x: torch.Tensor, + weights: torch.Tensor, + batch_info: LoRABatchInfo, + seq_lengths: List[int], + lora_assignments: List[str], + num_slices: int = 1, +) -> torch.Tensor: + """ + Simple sequence-level reference implementation of SGMV shrink operation. 
+ + Args: + x: (total_seq_len, input_dim) - Input activations + weights: (num_loras, num_slices * max_rank, input_dim) - LoRA A weights + batch_info: Batch information (only used for lora_ranks) + seq_lengths: Length of each sequence + lora_assignments: LoRA name for each sequence + num_slices: Number of slices (3 for QKV, 2 for gate_up, 1 for others) + + Returns: + output: (total_seq_len, num_slices * max_rank) - Intermediate activations + """ + if weights.numel() == 0: + total_seq_len = x.shape[0] + return torch.zeros(total_seq_len, 0, dtype=x.dtype, device=x.device) + + total_seq_len, input_dim = x.shape + num_loras, weight_out_dim, _ = weights.shape + max_rank = weight_out_dim // num_slices + + output = torch.zeros( + total_seq_len, num_slices * max_rank, dtype=x.dtype, device=x.device + ) + + unique_loras = sorted(set(lora_assignments)) + lora_name_to_idx = {name: idx for idx, name in enumerate(unique_loras)} + lora_ranks = batch_info.lora_ranks.cpu().numpy() + + token_offset = 0 + for seq_len, lora_name in zip(seq_lengths, lora_assignments): + if seq_len == 0: + continue + + lora_idx = lora_name_to_idx[lora_name] + rank = lora_ranks[lora_idx] + + if rank > 0: + x_seq = x[token_offset : token_offset + seq_len, :] + w_seq = weights[lora_idx, : num_slices * rank, :] + + result = safe_matmul(x_seq, w_seq.t()) + output[token_offset : token_offset + seq_len, : num_slices * rank] = result + + token_offset += seq_len + + return output + + +def reference_sgmv_expand( + x: torch.Tensor, + weights: torch.Tensor, + batch_info: LoRABatchInfo, + seq_lengths: List[int], + lora_assignments: List[str], + slice_offsets: torch.Tensor, + max_slice_size: int, + base_output: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Simple sequence-level reference implementation of SGMV expand operation. 
+ + Args: + x: (total_seq_len, num_slices * max_rank) - Intermediate activations + weights: (num_loras, output_dim, max_rank) - LoRA B weights + batch_info: Batch information (only used for lora_ranks) + seq_lengths: Length of each sequence + lora_assignments: LoRA name for each sequence + slice_offsets: Tensor defining slice boundaries + max_slice_size: Maximum slice size for chunking + base_output: Optional base output to accumulate into + + Returns: + output: (total_seq_len, total_output_dim) - Final output + """ + if weights.numel() == 0: + total_seq_len = x.shape[0] + total_output_dim = slice_offsets[-1].item() if len(slice_offsets) > 0 else 0 + return torch.zeros( + total_seq_len, total_output_dim, dtype=x.dtype, device=x.device + ) + + total_seq_len, _ = x.shape + + num_slices = len(slice_offsets) - 1 + + if base_output is not None: + output = base_output.clone() + else: + total_output_dim = slice_offsets[-1].item() + output = torch.zeros( + total_seq_len, total_output_dim, dtype=x.dtype, device=x.device + ) + + unique_loras = sorted(set(lora_assignments)) + lora_name_to_idx = {name: idx for idx, name in enumerate(unique_loras)} + lora_ranks = batch_info.lora_ranks.cpu().numpy() + + token_offset = 0 + for seq_len, lora_name in zip(seq_lengths, lora_assignments): + if seq_len == 0: + continue + + lora_idx = lora_name_to_idx[lora_name] + lora_rank = lora_ranks[lora_idx] + + if lora_rank > 0: + # Extract sequence intermediate activations + x_seq = x[ + token_offset : token_offset + seq_len, : num_slices * lora_rank + ] # (seq_len, num_slices * rank) + + for slice_idx in range(num_slices): + slice_start_input = slice_idx * lora_rank + slice_end_input = (slice_idx + 1) * lora_rank + + slice_start_output = slice_offsets[slice_idx].item() + slice_end_output = slice_offsets[slice_idx + 1].item() + + x_slice = x_seq[:, slice_start_input:slice_end_input] # (seq_len, rank) + w_slice = weights[ + lora_idx, slice_start_output:slice_end_output, :lora_rank + ] # 
(slice_dim, rank) + + result = safe_matmul(x_slice, w_slice.t()) # (seq_len, slice_dim) + output[ + token_offset : token_offset + seq_len, + slice_start_output:slice_end_output, + ] += result + + token_offset += seq_len + + return output + + +class TestChunkedSGMV(unittest.TestCase): + + # Test configuration constants + RTOL = 1e-3 + ATOL = 1e-3 + DEFAULT_BATCH_SIZE = 8 + + def _compare_shrink_outputs( + self, + chunked_output: torch.Tensor, + reference_output: torch.Tensor, + seq_lengths: List[int], + lora_assignments: List[str], + batch_info: LoRABatchInfo, + num_slices: int, + test_name: str, + ): + """ + Compare only the valid portions of shrink outputs. + + The chunked SGMV shrink kernel only guarantees correctness for + output[seq_start:seq_end, :rank * num_slices] for each sequence. + """ + # Create mapping from LoRA names to indices and ranks + unique_loras = sorted(set(lora_assignments)) + lora_name_to_idx = {name: idx for idx, name in enumerate(unique_loras)} + lora_ranks = batch_info.lora_ranks.cpu().numpy() + + token_offset = 0 + for seq_idx, (seq_len, lora_name) in enumerate( + zip(seq_lengths, lora_assignments) + ): + if seq_len == 0: + continue + + lora_idx = lora_name_to_idx[lora_name] + rank = lora_ranks[lora_idx] + + if rank > 0: + # Only compare the valid columns for this sequence + valid_cols = num_slices * rank + + chunked_seq = chunked_output[ + token_offset : token_offset + seq_len, :valid_cols + ] + reference_seq = reference_output[ + token_offset : token_offset + seq_len, :valid_cols + ] + + torch.testing.assert_close( + chunked_seq, + reference_seq, + rtol=self.RTOL, + atol=self.ATOL, + msg=f"Shrink operation failed for {test_name}, sequence {seq_idx} ({lora_name})", + ) + + token_offset += seq_len + + def setUp(self): + """Set up common test parameters""" + torch.manual_seed(42) + random.seed(42) + + self.device = torch.device("cuda") + self.dtype = torch.float16 + self.input_dim = 2560 # Hidden dimension + self.max_seq_len = 1024 + + # 
LoRA configurations: name -> (rank, output_q, output_k, output_v) + self.lora_configs = { + "lora_A": (8, 4096, 1024, 1024), + "lora_B": (16, 4096, 1024, 1024), + "lora_C": (32, 4096, 1024, 1024), + "_NO_LORA_": (0, 4096, 1024, 1024), + } + + # QKV slice offsets: 4096 (Q) + 1024 (K) + 1024 (V) = 6144 total + self.slice_offsets = torch.tensor( + [0, 4096, 5120, 6144], dtype=torch.int32, device=self.device + ) + self.max_slice_size = 4096 + + def generate_sequence_lengths( + self, + batch_size: int, + batch_mode: BatchMode = BatchMode.PREFILL, + min_len: int = 1, + max_len: int = None, + ) -> List[int]: + """Generate sequence lengths for a batch based on mode""" + if batch_mode == BatchMode.DECODE: + return [1] * batch_size + else: + if max_len is None: + max_len = self.max_seq_len + return [random.randint(min_len, max_len) for _ in range(batch_size)] + + def create_lora_weights( + self, lora_name: str, include_missing_k: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Create LoRA A and B weights for given configuration""" + rank, out_q, out_k, out_v = self.lora_configs[lora_name] + + if rank == 0: + lora_a = torch.empty( + 0, self.input_dim, dtype=self.dtype, device=self.device + ) + lora_b = torch.empty( + out_q + out_k + out_v, 0, dtype=self.dtype, device=self.device + ) + return lora_a, lora_b + + # Create LoRA A weights (3 slices for QKV) + lora_a = torch.randn( + 3 * rank, self.input_dim, dtype=self.dtype, device=self.device + ) + + if include_missing_k: + lora_a[rank : 2 * rank, :] = 0.0 + + # Create LoRA B weights (stacked Q, K, V) + total_output_dim = out_q + out_k + out_v + lora_b = torch.randn( + total_output_dim, rank, dtype=self.dtype, device=self.device + ) + + if include_missing_k: + lora_b[out_q : out_q + out_k, :] = 0.0 + + return lora_a, lora_b + + def create_batch_info( + self, + seq_lengths: List[int], + lora_assignments: List[Optional[str]], + batch_mode: BatchMode = BatchMode.PREFILL, + ) -> LoRABatchInfo: + """Create LoRABatchInfo 
using the same logic as chunked backend""" + unique_loras = sorted(set(lora_assignments)) + lora_name_to_idx = {name: idx for idx, name in enumerate(unique_loras)} + + seq_weight_indices = [lora_name_to_idx[name] for name in lora_assignments] + + lora_ranks = [self.lora_configs[name][0] for name in unique_loras] + + def create_mock_batch(): + # Create a minimal mock ForwardBatch for the test + class MockForwardBatch: + def __init__(self, batch_size, seq_lengths): + self.batch_size = batch_size + self.extend_seq_lens_cpu = seq_lengths + self.forward_mode = MockForwardMode() + + class MockForwardMode: + def is_extend(self): + return batch_mode == BatchMode.PREFILL + + return MockForwardBatch(len(seq_lengths), seq_lengths) + + mock_batch = create_mock_batch() + + # Use the same functions as chunked backend + permutation, weights_reordered = ChunkedSgmvLoRABackend._get_permutation( + seq_weight_indices, mock_batch + ) + + # Create a minimal backend instance to access _get_segments_info + mock_backend = ChunkedSgmvLoRABackend(max_loras_per_batch=8, device=self.device) + weight_indices_list, seg_indptr = mock_backend._get_segments_info( + weights_reordered + ) + + scalings = [1.0] * len(unique_loras) + seg_indptr_tensor = seg_indptr.to(self.device) + weight_indices_tensor = weight_indices_list.to(self.device) + lora_ranks_tensor = ( + torch.tensor(lora_ranks, dtype=torch.int32, device=self.device) + if lora_ranks + else torch.empty(0, dtype=torch.int32, device=self.device) + ) + scalings_tensor = ( + torch.tensor(scalings, dtype=torch.float32, device=self.device) + if scalings + else torch.empty(0, dtype=torch.float32, device=self.device) + ) + permutation_tensor = permutation.to( + self.device, dtype=torch.int32 + ) # Convert to int32 for LoRABatchInfo + seq_lens_tensor = torch.tensor( + seq_lengths, dtype=torch.int32, device=self.device + ) + + return LoRABatchInfo( + use_cuda_graph=False, + bs=len(seq_lengths), + num_segments=len(weight_indices_list), # Number of 
segments, not sequences! + seg_indptr=seg_indptr_tensor, + weight_indices=weight_indices_tensor, + lora_ranks=lora_ranks_tensor, + scalings=scalings_tensor, + seg_lens=seq_lens_tensor, # Original sequence lengths for reference + max_len=max(seq_lengths) if seq_lengths else 0, + permutation=permutation_tensor, # Token reordering permutation + ) + + def stack_lora_weights( + self, weight_list: List[torch.Tensor], is_lora_a: bool + ) -> torch.Tensor: + """Stack LoRA weights from different adapters into a single tensor""" + if not weight_list: + return torch.empty(0, 0, 0, dtype=self.dtype, device=self.device) + + first_non_empty = next((w for w in weight_list if w.numel() > 0), None) + if first_non_empty is None: + return torch.empty( + len(weight_list), 0, 0, dtype=self.dtype, device=self.device + ) + if is_lora_a: + # LoRA A: (slice_num * rank, input_dim) -> (num_loras, slice_num * max_rank, input_dim) + max_rank = max(w.shape[0] // 3 if w.numel() > 0 else 0 for w in weight_list) + final_shape = (len(weight_list), 3 * max_rank, self.input_dim) + else: + # LoRA B: (output_dim, rank) -> (num_loras, output_dim, max_rank) + max_rank = max(w.shape[1] if w.numel() > 0 else 0 for w in weight_list) + output_dim = first_non_empty.shape[0] + final_shape = (len(weight_list), output_dim, max_rank) + + stacked = torch.zeros(final_shape, dtype=self.dtype, device=self.device) + + for i, weight in enumerate(weight_list): + if weight.numel() > 0: + if is_lora_a: + stacked[i, : weight.shape[0], :] = weight + else: + stacked[i, :, : weight.shape[1]] = weight + + return stacked + + def create_test_batch( + self, + batch_composition: BatchComposition, + batch_size: int, + batch_mode: BatchMode = BatchMode.PREFILL, + include_missing_k: bool = False, + ) -> Tuple[ + torch.Tensor, + Dict[str, Tuple[torch.Tensor, torch.Tensor]], + LoRABatchInfo, + List[int], + List[str], + ]: + """Create test batch with specified composition and mode""" + seq_lengths = self.generate_sequence_lengths( + 
batch_size, batch_mode, 1, self.max_seq_len + ) + if batch_composition == BatchComposition.UNIFORM: + lora_assignments = ["lora_A"] * batch_size + elif batch_composition == BatchComposition.MIXED: + lora_names = ["lora_A", "lora_B", "lora_C", None] + lora_assignments = [ + lora_names[i % len(lora_names)] for i in range(batch_size) + ] + elif batch_composition == BatchComposition.SKEWED: + num_minority = max(1, batch_size // 8) + lora_assignments = ["lora_A"] * num_minority + ["lora_B"] * ( + batch_size - num_minority + ) + random.shuffle(lora_assignments) + elif batch_composition == BatchComposition.NONE: + lora_assignments = [None] * batch_size + else: + raise ValueError(f"Unknown batch composition: {batch_composition}") + + total_seq_len = sum(seq_lengths) + x = torch.randn( + total_seq_len, self.input_dim, dtype=self.dtype, device=self.device + ) + + normalized_assignments = [ + name if name is not None else "_NO_LORA_" for name in lora_assignments + ] + unique_loras = set(normalized_assignments) + weights = {} + for lora_name in unique_loras: + weights[lora_name] = self.create_lora_weights(lora_name, include_missing_k) + + batch_info = self.create_batch_info( + seq_lengths, normalized_assignments, batch_mode + ) + + return x, weights, batch_info, seq_lengths, normalized_assignments + + def run_test_comparison( + self, + x: torch.Tensor, + weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]], + batch_info: LoRABatchInfo, + seq_lengths: List[int], + lora_assignments: List[str], + test_name: str, + ): + """Run comparison between chunked and reference implementations""" + if not weights: # Handle case with no LoRA weights + return + + # Stack LoRA A weights + lora_a_weights = [weights[name][0] for name in sorted(weights.keys())] + stacked_lora_a = self.stack_lora_weights(lora_a_weights, is_lora_a=True) + + # Stack LoRA B weights + lora_b_weights = [weights[name][1] for name in sorted(weights.keys())] + stacked_lora_b = self.stack_lora_weights(lora_b_weights, 
is_lora_a=False) + + # Test shrink operation + chunked_shrink = chunked_sgmv_lora_shrink_forward( + x, stacked_lora_a, batch_info, num_slices=3 + ) + reference_shrink = reference_sgmv_shrink( + x, stacked_lora_a, batch_info, seq_lengths, lora_assignments, num_slices=3 + ) + + # Only compare valid portions of shrink output (first rank * num_slices columns per sequence) + self._compare_shrink_outputs( + chunked_shrink, + reference_shrink, + seq_lengths, + lora_assignments, + batch_info, + num_slices=3, + test_name=test_name, + ) + + # Test expand operation + chunked_expand = chunked_sgmv_lora_expand_forward( + reference_shrink, + stacked_lora_b, + batch_info, + self.slice_offsets, + self.max_slice_size, + ) + reference_expand = reference_sgmv_expand( + reference_shrink, + stacked_lora_b, + batch_info, + seq_lengths, + lora_assignments, + self.slice_offsets, + self.max_slice_size, + ) + + torch.testing.assert_close( + chunked_expand, + reference_expand, + rtol=self.RTOL, + atol=self.ATOL, + msg=f"Expand operation failed for {test_name}", + ) + + # === Basic Operations Tests === + + def test_shrink_basic(self): + """Test basic shrink operation against PyTorch reference""" + for batch_size in [1, 2, 16, 64]: + with self.subTest(batch_size=batch_size): + x, weights, batch_info, seq_lengths, lora_assignments = ( + self.create_test_batch(BatchComposition.UNIFORM, batch_size) + ) + + lora_a_weights = [weights[name][0] for name in sorted(weights.keys())] + stacked_lora_a = self.stack_lora_weights(lora_a_weights, is_lora_a=True) + + chunked_shrink = chunked_sgmv_lora_shrink_forward( + x, stacked_lora_a, batch_info, num_slices=3 + ) + reference_shrink = reference_sgmv_shrink( + x, + stacked_lora_a, + batch_info, + seq_lengths, + lora_assignments, + num_slices=3, + ) + + torch.testing.assert_close( + chunked_shrink, reference_shrink, rtol=self.RTOL, atol=self.ATOL + ) + + def test_expand_basic(self): + """Test basic expand operation against PyTorch reference""" + for batch_size 
in [1, 2, 16, 64]: + with self.subTest(batch_size=batch_size): + x, weights, batch_info, seq_lengths, lora_assignments = ( + self.create_test_batch(BatchComposition.UNIFORM, batch_size) + ) + + lora_a_weights = [weights[name][0] for name in sorted(weights.keys())] + stacked_lora_a = self.stack_lora_weights(lora_a_weights, is_lora_a=True) + + intermediate = reference_sgmv_shrink( + x, + stacked_lora_a, + batch_info, + seq_lengths, + lora_assignments, + num_slices=3, + ) + + lora_b_weights = [weights[name][1] for name in sorted(weights.keys())] + stacked_lora_b = self.stack_lora_weights( + lora_b_weights, is_lora_a=False + ) + + chunked_expand = chunked_sgmv_lora_expand_forward( + intermediate, + stacked_lora_b, + batch_info, + self.slice_offsets, + self.max_slice_size, + ) + reference_expand = reference_sgmv_expand( + intermediate, + stacked_lora_b, + batch_info, + seq_lengths, + lora_assignments, + self.slice_offsets, + self.max_slice_size, + ) + + torch.testing.assert_close( + chunked_expand, reference_expand, rtol=self.RTOL, atol=self.ATOL + ) + + # === QKV Operations Test === + + def test_qkv_missing_projections(self): + """Test QKV operations with missing k_proj (Qwen3 scenario)""" + for batch_size in [1, 2, 16, 64]: + with self.subTest(batch_size=batch_size): + x, weights, batch_info, seq_lengths, lora_assignments = ( + self.create_test_batch( + BatchComposition.MIXED, batch_size, include_missing_k=True + ) + ) + self.run_test_comparison( + x, + weights, + batch_info, + seq_lengths, + lora_assignments, + f"QKV missing k_proj batch_size={batch_size}", + ) + + # === Batch Composition Tests === + + def test_uniform_lora_batch(self): + """All sequences use same LoRA, random sequence lengths""" + for batch_size in [1, 2, 16, 64]: + with self.subTest(batch_size=batch_size): + x, weights, batch_info, seq_lengths, lora_assignments = ( + self.create_test_batch(BatchComposition.UNIFORM, batch_size) + ) + self.run_test_comparison( + x, + weights, + batch_info, + 
seq_lengths, + lora_assignments, + f"uniform batch_size={batch_size}", + ) + + def test_evenly_mixed_lora_batch(self): + """Sequences evenly distributed across LoRAs, random lengths""" + for batch_size in [1, 2, 16, 64]: + with self.subTest(batch_size=batch_size): + x, weights, batch_info, seq_lengths, lora_assignments = ( + self.create_test_batch(BatchComposition.MIXED, batch_size) + ) + self.run_test_comparison( + x, + weights, + batch_info, + seq_lengths, + lora_assignments, + f"mixed batch_size={batch_size}", + ) + + def test_highly_skewed_lora_batch(self): + """Highly uneven LoRA distribution, random lengths""" + for batch_size in [1, 2, 16, 64]: + with self.subTest(batch_size=batch_size): + x, weights, batch_info, seq_lengths, lora_assignments = ( + self.create_test_batch(BatchComposition.SKEWED, batch_size) + ) + self.run_test_comparison( + x, + weights, + batch_info, + seq_lengths, + lora_assignments, + f"skewed batch_size={batch_size}", + ) + + # === Decode Mode Tests === + + def test_decode_uniform_lora_batch(self): + """Decode mode: All sequences use same LoRA, all length 1""" + for batch_size in [1, 2, 16, 64]: + with self.subTest(batch_size=batch_size): + x, weights, batch_info, seq_lengths, lora_assignments = ( + self.create_test_batch( + BatchComposition.UNIFORM, batch_size, BatchMode.DECODE + ) + ) + self.run_test_comparison( + x, + weights, + batch_info, + seq_lengths, + lora_assignments, + f"decode uniform batch_size={batch_size}", + ) + + def test_decode_mixed_lora_batch(self): + """Decode mode: Sequences distributed across LoRAs, all length 1""" + for batch_size in [1, 2, 16, 64]: + with self.subTest(batch_size=batch_size): + x, weights, batch_info, seq_lengths, lora_assignments = ( + self.create_test_batch( + BatchComposition.MIXED, batch_size, BatchMode.DECODE + ) + ) + self.run_test_comparison( + x, + weights, + batch_info, + seq_lengths, + lora_assignments, + f"decode mixed batch_size={batch_size}", + ) + + def 
test_decode_skewed_lora_batch(self): + """Decode mode: Highly uneven LoRA distribution, all length 1""" + for batch_size in [1, 2, 16, 64]: + with self.subTest(batch_size=batch_size): + x, weights, batch_info, seq_lengths, lora_assignments = ( + self.create_test_batch( + BatchComposition.SKEWED, batch_size, BatchMode.DECODE + ) + ) + self.run_test_comparison( + x, + weights, + batch_info, + seq_lengths, + lora_assignments, + f"decode skewed batch_size={batch_size}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index b7ae98269..1712d5d42 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -24,6 +24,7 @@ suites = { TestFile("lora/test_lora_update.py", 400), TestFile("lora/test_lora_qwen3.py", 97), TestFile("lora/test_lora_radix_cache.py", 100), + TestFile("lora/test_chunked_sgmv_backend.py", 30), TestFile("models/test_embedding_models.py", 73), # TestFile("models/test_clip_models.py", 52), TestFile("models/test_encoder_embedding_models.py", 100),