Fix division-by-zero bug in LoRA triton kernels. (#7785)
This commit is contained in:
@@ -31,28 +31,44 @@ def _gate_up_lora_b_kernel(
|
||||
BLOCK_S: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
# For fused output scaling and adding
|
||||
fuse_scaling_add,
|
||||
# For fused output scaling
|
||||
scalings,
|
||||
):
|
||||
# This kernel packs 2 sgemms (gate/up) into a single kernel.
|
||||
"""
|
||||
This kernel packs 2 sgemms (gate/up) into a single kernel. The multiplication
|
||||
results are accumulated into the output tensor.
|
||||
|
||||
# x: (s, 2 * K), where s is the sum of sequence lengths and K equals the LoRA rank
|
||||
# weights: (num_lora, 2 * output_dim, K)
|
||||
# output: (s, 2 * output_dim)
|
||||
When a sequence's rank is 0, the kernel is essentially a no-op, following
|
||||
the convention in PyTorch where the product of two matrices of shape (m, 0)
|
||||
and (0, n) is an all-zero matrix of shape (m, n).
|
||||
|
||||
Args:
|
||||
x (Tensor): The input tensor, which is the result of the LoRA A projection.
|
||||
Shape: (s, 2 * K), where s is the sum of all sequence lengths in the
|
||||
batch and K is the maximum LoRA rank.
|
||||
weights (Tensor): The LoRA B weights for all adapters.
|
||||
Shape: (num_lora, 2 * output_dim, K).
|
||||
output (Tensor): The output tensor where the result is stored.
|
||||
Shape: (s, 2 * output_dim).
|
||||
"""
|
||||
# output_dim >> K
|
||||
|
||||
# Current block computes sequence with batch_id,
|
||||
# which starts from row seg_start of x with length seg_len.
|
||||
# gate_up_id decides which of gate or up (0: gate, 1: up)
|
||||
batch_id = tl.program_id(axis=2)
|
||||
w_index = tl.load(weight_indices + batch_id)
|
||||
rank = tl.load(lora_ranks + w_index)
|
||||
|
||||
# If rank is 0, this kernel is a no-op.
|
||||
if rank == 0:
|
||||
return
|
||||
|
||||
gate_up_id = tl.program_id(axis=1)
|
||||
pid = tl.program_id(axis=0)
|
||||
seg_len = tl.load(seg_lens + batch_id)
|
||||
w_index = tl.load(weight_indices + batch_id)
|
||||
seg_start = tl.load(seg_indptr + batch_id)
|
||||
n_start = gate_up_id * output_dim # offset on output dim
|
||||
rank = tl.load(lora_ranks + w_index)
|
||||
scaling = tl.load(scalings + w_index)
|
||||
|
||||
# Adjust K (rank) according to the specific LoRA adapter
|
||||
@@ -82,14 +98,13 @@ def _gate_up_lora_b_kernel(
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
x_tile = tl.load(
|
||||
x_ptrs,
|
||||
mask=(s_offset[:, None] < seg_len)
|
||||
and (k_offset[None, :] < K - k * BLOCK_K),
|
||||
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
|
||||
other=0.0,
|
||||
)
|
||||
w_tile = tl.load(
|
||||
w_ptrs,
|
||||
mask=(k_offset[:, None] < K - k * BLOCK_K)
|
||||
and (n_offset[None, :] < output_dim),
|
||||
& (n_offset[None, :] < output_dim),
|
||||
other=0.0,
|
||||
)
|
||||
partial_sum += tl.dot(x_tile, w_tile)
|
||||
@@ -103,9 +118,8 @@ def _gate_up_lora_b_kernel(
|
||||
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
|
||||
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
||||
)
|
||||
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
|
||||
if fuse_scaling_add:
|
||||
partial_sum += tl.load(output_ptr, mask=output_mask)
|
||||
output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < output_dim)
|
||||
partial_sum += tl.load(output_ptr, mask=output_mask)
|
||||
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||
|
||||
|
||||
@@ -143,11 +157,9 @@ def gate_up_lora_b_fwd(
|
||||
)
|
||||
|
||||
if base_output is None:
|
||||
output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
|
||||
fuse_scaling_add = False
|
||||
output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
output = base_output
|
||||
fuse_scaling_add = True
|
||||
|
||||
_gate_up_lora_b_kernel[grid_b](
|
||||
x,
|
||||
@@ -169,7 +181,6 @@ def gate_up_lora_b_fwd(
|
||||
BLOCK_S,
|
||||
BLOCK_OUT,
|
||||
BLOCK_R,
|
||||
fuse_scaling_add,
|
||||
batch_info.scalings,
|
||||
)
|
||||
|
||||
|
||||
@@ -33,29 +33,45 @@ def _qkv_lora_b_kernel(
|
||||
BLOCK_S: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
# For fused output scaling and adding
|
||||
fuse_scaling_add,
|
||||
# For fused output scaling
|
||||
scalings,
|
||||
):
|
||||
# This kernel packs 3 sgemms (q/k/v) into a single kernel.
|
||||
"""
|
||||
This kernel packs 3 sgemms (q/k/v) into a single kernel. The multiplication
|
||||
results are accumulated into the output tensor.
|
||||
|
||||
# x: (s, 3 * K), where s is the sum of sequence lengths and K equals the LoRA rank
|
||||
# weights: (num_lora, N_Q + 2 * N_KV, K)
|
||||
# output: (s, N_Q + 2 * N_KV)
|
||||
# N_Q >> K, N_KV >> K
|
||||
When a sequence's rank is 0, the kernel is essentially a no-op, following
|
||||
the convention in PyTorch where the product of two matrices of shape (m, 0)
|
||||
and (0, n) is an all-zero matrix of shape (m, n).
|
||||
|
||||
Args:
|
||||
x (Tensor): The input tensor, which is the result of the LoRA A projection.
|
||||
Shape: (s, 3 * K), where s is the sum of all sequence lengths in the
|
||||
batch and K is the maximum LoRA rank. The second dimension is partitioned
|
||||
for Q, K, and V.
|
||||
weights (Tensor): The LoRA B weights for all adapters.
|
||||
Shape: (num_lora, N_Q + 2 * N_KV, K).
|
||||
output (Tensor): The output tensor where the result is stored.
|
||||
Shape: (s, N_Q + 2 * N_KV).
|
||||
"""
|
||||
|
||||
# Current block computes sequence with batch_id,
|
||||
# which starts from row seg_start of x with length seg_len.
|
||||
# qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
|
||||
batch_id = tl.program_id(axis=2)
|
||||
w_index = tl.load(weight_indices + batch_id)
|
||||
rank = tl.load(lora_ranks + w_index)
|
||||
|
||||
# If rank is 0, this kernel is a no-op.
|
||||
if rank == 0:
|
||||
return
|
||||
|
||||
qkv_id = tl.program_id(axis=1)
|
||||
pid = tl.program_id(axis=0)
|
||||
seg_len = tl.load(seg_lens + batch_id)
|
||||
w_index = tl.load(weight_indices + batch_id)
|
||||
seg_start = tl.load(seg_indptr + batch_id)
|
||||
n_start = tl.load(n_offs + qkv_id)
|
||||
n_size = tl.load(n_offs + qkv_id + 1) - n_start
|
||||
rank = tl.load(lora_ranks + w_index)
|
||||
scaling = tl.load(scalings + w_index)
|
||||
# Adjust K (rank) according to the specific LoRA adapter
|
||||
K = tl.minimum(K, rank)
|
||||
@@ -84,13 +100,12 @@ def _qkv_lora_b_kernel(
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
x_tile = tl.load(
|
||||
x_ptrs,
|
||||
mask=(s_offset[:, None] < seg_len)
|
||||
and (k_offset[None, :] < K - k * BLOCK_K),
|
||||
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
|
||||
other=0.0,
|
||||
)
|
||||
w_tile = tl.load(
|
||||
w_ptrs,
|
||||
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
|
||||
mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < n_size),
|
||||
other=0.0,
|
||||
)
|
||||
partial_sum += tl.dot(x_tile, w_tile)
|
||||
@@ -105,8 +120,7 @@ def _qkv_lora_b_kernel(
|
||||
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
||||
)
|
||||
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
|
||||
if fuse_scaling_add:
|
||||
partial_sum += tl.load(output_ptr, mask=output_mask)
|
||||
partial_sum += tl.load(output_ptr, mask=output_mask)
|
||||
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||
|
||||
|
||||
@@ -153,11 +167,9 @@ def qkv_lora_b_fwd(
|
||||
)
|
||||
|
||||
if base_output is None:
|
||||
output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
|
||||
fuse_scaling_add = False
|
||||
output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
output = base_output
|
||||
fuse_scaling_add = True
|
||||
|
||||
_qkv_lora_b_kernel[grid_b](
|
||||
x,
|
||||
@@ -180,7 +192,6 @@ def qkv_lora_b_fwd(
|
||||
BLOCK_S,
|
||||
BLOCK_OUT,
|
||||
BLOCK_R,
|
||||
fuse_scaling_add,
|
||||
batch_info.scalings,
|
||||
)
|
||||
|
||||
|
||||
@@ -33,19 +33,36 @@ def _sgemm_lora_a_kernel(
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Computes a segmented batched matrix multiplication for the LoRA A matrix.
|
||||
|
||||
# x: (s, K), s is the sum of sequence lengths
|
||||
# weights: (num_lora, N, K)
|
||||
# output: (s, N)
|
||||
The kernel ensures that output[seg_start:seg_start + seg_len, :rank * stack_num]
|
||||
stores the product of the input `x` and the LoRA weights for the corresponding
|
||||
sequence. This implies that when rank is 0, the kernel is essentially a no-op,
|
||||
as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
|
||||
is the sum of all sequence lengths in the batch.
|
||||
weights (torch.Tensor): The LoRA 'A' weights for all available adapters,
|
||||
with shape `(num_lora, N, K)`.
|
||||
output (torch.Tensor): The output tensor of shape `(s, N)`.
|
||||
"""
|
||||
|
||||
# Current block computes sequence with batch_id,
|
||||
# which starts from row seg_start of x with length seg_len
|
||||
batch_id = tl.program_id(axis=1)
|
||||
pid = tl.program_id(axis=0)
|
||||
seg_len = tl.load(seg_lens + batch_id)
|
||||
w_index = tl.load(weight_indices + batch_id)
|
||||
seg_start = tl.load(seg_indptr + batch_id)
|
||||
rank = tl.load(lora_ranks + w_index)
|
||||
|
||||
# If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
|
||||
if rank == 0:
|
||||
return
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
seg_start = tl.load(seg_indptr + batch_id)
|
||||
seg_len = tl.load(seg_lens + batch_id)
|
||||
|
||||
# Adjust N (stack_num * max_rank) according to the specific LoRA adapter
|
||||
N = tl.minimum(N, rank * stack_num)
|
||||
|
||||
@@ -72,13 +89,12 @@ def _sgemm_lora_a_kernel(
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
x_tile = tl.load(
|
||||
x_ptrs,
|
||||
mask=(s_offset[:, None] < seg_len)
|
||||
and (k_offset[None, :] < K - k * BLOCK_K),
|
||||
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
|
||||
other=0.0,
|
||||
)
|
||||
w_tile = tl.load(
|
||||
w_ptrs,
|
||||
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
|
||||
mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N),
|
||||
other=0.0,
|
||||
)
|
||||
partial_sum += tl.dot(x_tile, w_tile)
|
||||
@@ -91,7 +107,7 @@ def _sgemm_lora_a_kernel(
|
||||
output_ptr = (output + seg_start * output_stride_0) + (
|
||||
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
||||
)
|
||||
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
|
||||
output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N)
|
||||
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||
|
||||
|
||||
|
||||
@@ -31,22 +31,39 @@ def _sgemm_lora_b_kernel(
|
||||
BLOCK_S: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
# For fused output scaling and adding
|
||||
fuse_scaling_add,
|
||||
# For fused output scaling
|
||||
scalings,
|
||||
):
|
||||
# x: (s, K), s is the sum of sequence lengths
|
||||
# weights: (num_lora, N, K)
|
||||
# output: (s, N)
|
||||
"""
|
||||
Computes a segmented batched matrix multiplication for the LoRA B matrix
|
||||
and adds the result to the output in-place.
|
||||
|
||||
When a sequence's rank is 0, the kernel is essentially a no-op, following
|
||||
the convention in PyTorch where the product of two matrices of shape (m, 0)
|
||||
and (0, n) is an all-zero matrix of shape (m, n).
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication,
|
||||
of shape `(s, K)`, where `s` is the total number of tokens.
|
||||
weights (torch.Tensor): The LoRA 'B' weights for all available adapters,
|
||||
with shape `(num_lora, N, K)`.
|
||||
output (torch.Tensor): The output tensor of shape `(s, N)`. This can be
|
||||
the base model's output for a fused add operation.
|
||||
"""
|
||||
|
||||
# Current block computes sequence with batch_id,
|
||||
# which starts from row seg_start of x with length seg_len
|
||||
batch_id = tl.program_id(axis=1)
|
||||
w_index = tl.load(weight_indices + batch_id)
|
||||
rank = tl.load(lora_ranks + w_index)
|
||||
|
||||
# If rank is 0, this kernel is a no-op.
|
||||
if rank == 0:
|
||||
return
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
seg_len = tl.load(seg_lens + batch_id)
|
||||
w_index = tl.load(weight_indices + batch_id)
|
||||
seg_start = tl.load(seg_indptr + batch_id)
|
||||
rank = tl.load(lora_ranks + w_index)
|
||||
scaling = tl.load(scalings + w_index)
|
||||
# Adjust K (rank) according to the specific LoRA adapter
|
||||
K = tl.minimum(K, rank)
|
||||
@@ -74,8 +91,7 @@ def _sgemm_lora_b_kernel(
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
x_tile = tl.load(
|
||||
x_ptrs,
|
||||
mask=(s_offset[:, None] < seg_len)
|
||||
and (k_offset[None, :] < K - k * BLOCK_K),
|
||||
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
|
||||
other=0.0,
|
||||
)
|
||||
w_tile = tl.load(
|
||||
@@ -95,8 +111,7 @@ def _sgemm_lora_b_kernel(
|
||||
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
||||
)
|
||||
output_mask = s_offset[:, None] < seg_len
|
||||
if fuse_scaling_add:
|
||||
partial_sum += tl.load(output_ptr, mask=output_mask)
|
||||
partial_sum += tl.load(output_ptr, mask=output_mask)
|
||||
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||
|
||||
|
||||
@@ -132,11 +147,9 @@ def sgemm_lora_b_fwd(
|
||||
)
|
||||
|
||||
if base_output is None:
|
||||
output = torch.empty((S, N), device=x.device, dtype=x.dtype)
|
||||
fuse_scaling_add = False
|
||||
output = torch.zeros((S, N), device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
output = base_output
|
||||
fuse_scaling_add = True
|
||||
|
||||
_sgemm_lora_b_kernel[grid](
|
||||
x,
|
||||
@@ -158,7 +171,6 @@ def sgemm_lora_b_fwd(
|
||||
BLOCK_S,
|
||||
BLOCK_N,
|
||||
BLOCK_R,
|
||||
fuse_scaling_add,
|
||||
batch_info.scalings,
|
||||
)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user