From ea4bf12286fdba437230da5e8ad6dc00fc084c29 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Sun, 6 Jul 2025 00:45:29 -0700 Subject: [PATCH] Fix division-by-zero bug in LoRA triton kernels. (#7785) --- .../srt/lora/triton_ops/gate_up_lora_b.py | 47 ++++++++++++------- .../sglang/srt/lora/triton_ops/qkv_lora_b.py | 47 ++++++++++++------- .../srt/lora/triton_ops/sgemm_lora_a.py | 36 ++++++++++---- .../srt/lora/triton_ops/sgemm_lora_b.py | 42 +++++++++++------ 4 files changed, 111 insertions(+), 61 deletions(-) diff --git a/python/sglang/srt/lora/triton_ops/gate_up_lora_b.py b/python/sglang/srt/lora/triton_ops/gate_up_lora_b.py index ae242dc48..fc4574dd3 100644 --- a/python/sglang/srt/lora/triton_ops/gate_up_lora_b.py +++ b/python/sglang/srt/lora/triton_ops/gate_up_lora_b.py @@ -31,28 +31,44 @@ def _gate_up_lora_b_kernel( BLOCK_S: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - # For fused output scaling and adding - fuse_scaling_add, + # For fused output scaling scalings, ): - # This kernel packs 2 sgemms (gate/up) into a single kernel. + """ + This kernel packs 2 sgemms (gate/up) into a single kernel. The multiplication + results are accumulated into the output tensor. - # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank - # weights: (num_lora, 2 * output_dim, K) - # output: (s, 2 * output_dim) + When a sequence's rank is 0, the kernel is essentially a no-op, following + the convention in pytorch where the product of two matrices of shape (m, 0) + and (0, n) is an all-zero matrix of shape (m, n). + + Args: + x (Tensor): The input tensor, which is the result of the LoRA A projection. + Shape: (s, 2 * K), where s is the sum of all sequence lengths in the + batch and K is the maximum LoRA rank. + weights (Tensor): The LoRA B weights for all adapters. + Shape: (num_lora, 2 * output_dim, K). + output (Tensor): The output tensor where the result is stored. + Shape: (s, 2 * output_dim). + """ # output_dim >> K # Current block computes sequence with batch_id, # which starts from row seg_start of x with length seg_len. # gate_up_id decides which of gate or up (0: gate, 1: up) batch_id = tl.program_id(axis=2) + w_index = tl.load(weight_indices + batch_id) + rank = tl.load(lora_ranks + w_index) + + # If rank is 0, this kernel is a no-op. 
+ if rank == 0: + return + gate_up_id = tl.program_id(axis=1) pid = tl.program_id(axis=0) seg_len = tl.load(seg_lens + batch_id) - w_index = tl.load(weight_indices + batch_id) seg_start = tl.load(seg_indptr + batch_id) n_start = gate_up_id * output_dim # offset on output dim - rank = tl.load(lora_ranks + w_index) scaling = tl.load(scalings + w_index) # Adjust K (rank) according to the specific LoRA adapter @@ -82,14 +98,13 @@ def _gate_up_lora_b_kernel( for k in range(0, tl.cdiv(K, BLOCK_K)): x_tile = tl.load( x_ptrs, - mask=(s_offset[:, None] < seg_len) - and (k_offset[None, :] < K - k * BLOCK_K), + mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K), other=0.0, ) w_tile = tl.load( w_ptrs, mask=(k_offset[:, None] < K - k * BLOCK_K) - and (n_offset[None, :] < output_dim), + & (n_offset[None, :] < output_dim), other=0.0, ) partial_sum += tl.dot(x_tile, w_tile) @@ -103,9 +118,8 @@ def _gate_up_lora_b_kernel( output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 ) - output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim) - if fuse_scaling_add: - partial_sum += tl.load(output_ptr, mask=output_mask) + output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < output_dim) + partial_sum += tl.load(output_ptr, mask=output_mask) tl.store(output_ptr, partial_sum, mask=output_mask) @@ -143,11 +157,9 @@ def gate_up_lora_b_fwd( ) if base_output is None: - output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype) - fuse_scaling_add = False + output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype) else: output = base_output - fuse_scaling_add = True _gate_up_lora_b_kernel[grid_b]( x, @@ -169,7 +181,6 @@ def gate_up_lora_b_fwd( BLOCK_S, BLOCK_OUT, BLOCK_R, - fuse_scaling_add, batch_info.scalings, ) diff --git a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py index 76f3f8671..e685e526b 100644 --- a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py +++ b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py @@ -33,29 +33,45 @@ def _qkv_lora_b_kernel( BLOCK_S: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - # For fused output scaling and adding - fuse_scaling_add, + # For fused output scaling scalings, ): - # This kernel packs 3 sgemms (q/k/v) into a single kernel. + """ + This kernel packs 3 sgemms (q/k/v) into a single kernel. The multiplication + results are accumulated into the output tensor. - # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank - # weights: (num_lora, N_Q + 2 * N_KV, K) - # output: (s, N_Q + 2 * N_KV) - # N_Q >> K, N_KV >> K + When a sequence's rank is 0, the kernel is essentially a no-op, following + the convention in pytorch where the product of two matrices of shape (m, 0) + and (0, n) is an all-zero matrix of shape (m, n). + + Args: + x (Tensor): The input tensor, which is the result of the LoRA A projection. + Shape: (s, 3 * K), where s is the sum of all sequence lengths in the + batch and K is the maximum LoRA rank. The second dimension is partitioned + for Q, K, and V. + weights (Tensor): The LoRA B weights for all adapters. + Shape: (num_lora, N_Q + 2 * N_KV, K). + output (Tensor): The output tensor where the result is stored. + Shape: (s, N_Q + 2 * N_KV). + """ # Current block computes sequence with batch_id, # which starts from row seg_start of x with length seg_len. 
# qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v) batch_id = tl.program_id(axis=2) + w_index = tl.load(weight_indices + batch_id) + rank = tl.load(lora_ranks + w_index) + + # If rank is 0, this kernel is a no-op. + if rank == 0: + return + qkv_id = tl.program_id(axis=1) pid = tl.program_id(axis=0) seg_len = tl.load(seg_lens + batch_id) - w_index = tl.load(weight_indices + batch_id) seg_start = tl.load(seg_indptr + batch_id) n_start = tl.load(n_offs + qkv_id) n_size = tl.load(n_offs + qkv_id + 1) - n_start - rank = tl.load(lora_ranks + w_index) scaling = tl.load(scalings + w_index) # Adjust K (rank) according to the specific LoRA adapter K = tl.minimum(K, rank) @@ -84,13 +100,12 @@ def _qkv_lora_b_kernel( for k in range(0, tl.cdiv(K, BLOCK_K)): x_tile = tl.load( x_ptrs, - mask=(s_offset[:, None] < seg_len) - and (k_offset[None, :] < K - k * BLOCK_K), + mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K), other=0.0, ) w_tile = tl.load( w_ptrs, - mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size), + mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < n_size), other=0.0, ) partial_sum += tl.dot(x_tile, w_tile) @@ -105,8 +120,7 @@ def _qkv_lora_b_kernel( s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 ) output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size) - if fuse_scaling_add: - partial_sum += tl.load(output_ptr, mask=output_mask) + partial_sum += tl.load(output_ptr, mask=output_mask) tl.store(output_ptr, partial_sum, mask=output_mask) @@ -153,11 +167,9 @@ def qkv_lora_b_fwd( ) if base_output is None: - output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype) - fuse_scaling_add = False + output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype) else: output = base_output - fuse_scaling_add = True _qkv_lora_b_kernel[grid_b]( x, @@ -180,7 +192,6 @@ def qkv_lora_b_fwd( BLOCK_S, BLOCK_OUT, BLOCK_R, - fuse_scaling_add, batch_info.scalings, ) diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py index 201f75269..dded64bcf 100644 --- a/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py @@ -33,19 +33,36 @@ def _sgemm_lora_a_kernel( BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): + """ + Computes a segmented batched matrix multiplication for the LoRA A matrix. - # x: (s, K), s is the sum of sequence lengths - # weights: (num_lora, N, K) - # output: (s, N) + The kernel ensures that output[seg_start:seg_start + seg_len, :rank * stack_num] + stores the product of the input `x` and the LoRA weights for the corresponding + sequence. This implies that when rank is 0, the kernel is essentially a no-op, + as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty). + + Args: + x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s` + is the sum of all sequence lengths in the batch. + weights (torch.Tensor): The LoRA 'A' weights for all available adapters, + with shape `(num_lora, N, K)`. + output (torch.Tensor): The output tensor of shape `(s, N)`. 
+ """ # Current block computes sequence with batch_id, # which starts from row seg_start of x with length seg_len batch_id = tl.program_id(axis=1) - pid = tl.program_id(axis=0) - seg_len = tl.load(seg_lens + batch_id) w_index = tl.load(weight_indices + batch_id) - seg_start = tl.load(seg_indptr + batch_id) rank = tl.load(lora_ranks + w_index) + + # If rank is 0, this kernel becomes a no-op as the output is always trivially correct. + if rank == 0: + return + + pid = tl.program_id(axis=0) + seg_start = tl.load(seg_indptr + batch_id) + seg_len = tl.load(seg_lens + batch_id) + # Adjust N (stack_num * max_rank) according to the specific LoRA adapter N = tl.minimum(N, rank * stack_num) @@ -72,13 +89,12 @@ def _sgemm_lora_a_kernel( for k in range(0, tl.cdiv(K, BLOCK_K)): x_tile = tl.load( x_ptrs, - mask=(s_offset[:, None] < seg_len) - and (k_offset[None, :] < K - k * BLOCK_K), + mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K), other=0.0, ) w_tile = tl.load( w_ptrs, - mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N), + mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N), other=0.0, ) partial_sum += tl.dot(x_tile, w_tile) @@ -91,7 +107,7 @@ def _sgemm_lora_a_kernel( output_ptr = (output + seg_start * output_stride_0) + ( s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 ) - output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N) + output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N) tl.store(output_ptr, partial_sum, mask=output_mask) diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py index 89fe2591f..357d32805 100644 --- a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py @@ -31,22 +31,39 @@ def _sgemm_lora_b_kernel( BLOCK_S: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - # For fused output scaling and adding - fuse_scaling_add, + # For fused output scaling scalings, ): - # x: (s, K), s is the sum of sequence lengths - # weights: (num_lora, N, K) - # output: (s, N) + """ + Computes a segmented batched matrix multiplication for the LoRA B matrix + and adds the result to the output in-place. + + When a sequence's rank is 0, the kernel is essentially a no-op, following + the convention in pytorch where the product of two matrices of shape (m, 0) + and (0, n) is an all-zero matrix of shape (m, n). + + Args: + x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication, + of shape `(s, K)`, where `s` is the total number of tokens. + weights (torch.Tensor): The LoRA 'B' weights for all available adapters, + with shape `(num_lora, N, K)`. + output (torch.Tensor): The output tensor of shape `(s, N)`. This can be + the base model's output for a fused add operation. + """ # Current block computes sequence with batch_id, # which starts from row seg_start of x with length seg_len batch_id = tl.program_id(axis=1) + w_index = tl.load(weight_indices + batch_id) + rank = tl.load(lora_ranks + w_index) + + # If rank is 0, this kernel is a no-op. 
+ if rank == 0: + return + pid = tl.program_id(axis=0) seg_len = tl.load(seg_lens + batch_id) - w_index = tl.load(weight_indices + batch_id) seg_start = tl.load(seg_indptr + batch_id) - rank = tl.load(lora_ranks + w_index) scaling = tl.load(scalings + w_index) # Adjust K (rank) according to the specific LoRA adapter K = tl.minimum(K, rank) @@ -74,8 +91,7 @@ def _sgemm_lora_b_kernel( for k in range(0, tl.cdiv(K, BLOCK_K)): x_tile = tl.load( x_ptrs, - mask=(s_offset[:, None] < seg_len) - and (k_offset[None, :] < K - k * BLOCK_K), + mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K), other=0.0, ) w_tile = tl.load( @@ -95,8 +111,7 @@ def _sgemm_lora_b_kernel( s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 ) output_mask = s_offset[:, None] < seg_len - if fuse_scaling_add: - partial_sum += tl.load(output_ptr, mask=output_mask) + partial_sum += tl.load(output_ptr, mask=output_mask) tl.store(output_ptr, partial_sum, mask=output_mask) @@ -132,11 +147,9 @@ def sgemm_lora_b_fwd( ) if base_output is None: - output = torch.empty((S, N), device=x.device, dtype=x.dtype) - fuse_scaling_add = False + output = torch.zeros((S, N), device=x.device, dtype=x.dtype) else: output = base_output - fuse_scaling_add = True _sgemm_lora_b_kernel[grid]( x, @@ -158,7 +171,6 @@ def sgemm_lora_b_fwd( BLOCK_S, BLOCK_N, BLOCK_R, - fuse_scaling_add, batch_info.scalings, ) return output
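
Note on the behavioral contract (an illustrative sketch, not part of the patch): the zero-rank convention that the new kernel docstrings describe can be reproduced in plain PyTorch. ref_lora_delta and its arguments below are hypothetical names used only for this sketch, not sglang APIs.

import torch

def ref_lora_delta(x_seg: torch.Tensor, lora_a: torch.Tensor,
                   lora_b: torch.Tensor, scaling: float) -> torch.Tensor:
    # x_seg:  (seg_len, hidden)  activations of one sequence segment
    # lora_a: (rank, hidden)     LoRA A weights; rank may be 0
    # lora_b: (out_dim, rank)    LoRA B weights
    #
    # With rank == 0 this reduces to (seg_len, 0) @ (0, out_dim), which
    # PyTorch defines as an all-zero (seg_len, out_dim) matrix -- the same
    # convention the kernels' early return relies on.
    return (x_seg @ lora_a.T) @ lora_b.T * scaling

# A rank-0 adapter leaves the base output unchanged.
x = torch.randn(4, 16)
base = torch.randn(4, 8)
a0 = torch.empty(0, 16)   # rank-0 LoRA A
b0 = torch.empty(8, 0)    # rank-0 LoRA B
assert torch.equal(base + ref_lora_delta(x, a0, b0, scaling=2.0), base)

Because the kernels now accumulate into the output unconditionally (partial_sum += tl.load(output_ptr, ...)), the Python wrappers allocate the output with torch.zeros instead of torch.empty when no base_output is supplied, which is what allows the separate fuse_scaling_add flag to be dropped.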