Fix division-by-zero bug in LoRA triton kernels. (#7785)
This commit is contained in:
@@ -31,28 +31,44 @@ def _gate_up_lora_b_kernel(
|
|||||||
BLOCK_S: tl.constexpr,
|
BLOCK_S: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
BLOCK_K: tl.constexpr,
|
BLOCK_K: tl.constexpr,
|
||||||
# For fused output scaling and adding
|
# For fused output scaling
|
||||||
fuse_scaling_add,
|
|
||||||
scalings,
|
scalings,
|
||||||
):
|
):
|
||||||
# This kernel packs 2 sgemms (gate/up) into a single kernel.
|
"""
|
||||||
|
This kernel packs 2 sgemms (gate/up) into a single kernel. The multiplication
|
||||||
|
results are accumulated into the output tensor.
|
||||||
|
|
||||||
# x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
|
When a sequence's rank is 0, the kernel is essentially a no-op, following
|
||||||
# weights: (num_lora, 2 * output_dim, K)
|
the convention in pytorch where the product of two matrices of shape (m, 0)
|
||||||
# output: (s, 2 * output_dim)
|
and (0, n) is an all-zero matrix of shape (m, n).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): The input tensor, which is the result of the LoRA A projection.
|
||||||
|
Shape: (s, 2 * K), where s is the sum of all sequence lengths in the
|
||||||
|
batch and K is the maximum LoRA rank.
|
||||||
|
weights (Tensor): The LoRA B weights for all adapters.
|
||||||
|
Shape: (num_lora, 2 * output_dim, K).
|
||||||
|
output (Tensor): The output tensor where the result is stored.
|
||||||
|
Shape: (s, 2 * output_dim).
|
||||||
|
"""
|
||||||
# output_dim >> K
|
# output_dim >> K
|
||||||
|
|
||||||
# Current block computes sequence with batch_id,
|
# Current block computes sequence with batch_id,
|
||||||
# which starts from row seg_start of x with length seg_len.
|
# which starts from row seg_start of x with length seg_len.
|
||||||
# gate_up_id decides which of gate or up (0: gate, 1: up)
|
# gate_up_id decides which of gate or up (0: gate, 1: up)
|
||||||
batch_id = tl.program_id(axis=2)
|
batch_id = tl.program_id(axis=2)
|
||||||
|
w_index = tl.load(weight_indices + batch_id)
|
||||||
|
rank = tl.load(lora_ranks + w_index)
|
||||||
|
|
||||||
|
# If rank is 0, this kernel is a no-op.
|
||||||
|
if rank == 0:
|
||||||
|
return
|
||||||
|
|
||||||
gate_up_id = tl.program_id(axis=1)
|
gate_up_id = tl.program_id(axis=1)
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
seg_len = tl.load(seg_lens + batch_id)
|
seg_len = tl.load(seg_lens + batch_id)
|
||||||
w_index = tl.load(weight_indices + batch_id)
|
|
||||||
seg_start = tl.load(seg_indptr + batch_id)
|
seg_start = tl.load(seg_indptr + batch_id)
|
||||||
n_start = gate_up_id * output_dim # offset on output dim
|
n_start = gate_up_id * output_dim # offset on output dim
|
||||||
rank = tl.load(lora_ranks + w_index)
|
|
||||||
scaling = tl.load(scalings + w_index)
|
scaling = tl.load(scalings + w_index)
|
||||||
|
|
||||||
# Adjust K (rank) according to the specific LoRA adapter
|
# Adjust K (rank) according to the specific LoRA adapter
|
||||||
@@ -82,14 +98,13 @@ def _gate_up_lora_b_kernel(
|
|||||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||||
x_tile = tl.load(
|
x_tile = tl.load(
|
||||||
x_ptrs,
|
x_ptrs,
|
||||||
mask=(s_offset[:, None] < seg_len)
|
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
|
||||||
and (k_offset[None, :] < K - k * BLOCK_K),
|
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
w_tile = tl.load(
|
w_tile = tl.load(
|
||||||
w_ptrs,
|
w_ptrs,
|
||||||
mask=(k_offset[:, None] < K - k * BLOCK_K)
|
mask=(k_offset[:, None] < K - k * BLOCK_K)
|
||||||
and (n_offset[None, :] < output_dim),
|
& (n_offset[None, :] < output_dim),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
partial_sum += tl.dot(x_tile, w_tile)
|
partial_sum += tl.dot(x_tile, w_tile)
|
||||||
@@ -103,8 +118,7 @@ def _gate_up_lora_b_kernel(
|
|||||||
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
|
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
|
||||||
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
||||||
)
|
)
|
||||||
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
|
output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < output_dim)
|
||||||
if fuse_scaling_add:
|
|
||||||
partial_sum += tl.load(output_ptr, mask=output_mask)
|
partial_sum += tl.load(output_ptr, mask=output_mask)
|
||||||
tl.store(output_ptr, partial_sum, mask=output_mask)
|
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||||
|
|
||||||
@@ -143,11 +157,9 @@ def gate_up_lora_b_fwd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if base_output is None:
|
if base_output is None:
|
||||||
output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
|
output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype)
|
||||||
fuse_scaling_add = False
|
|
||||||
else:
|
else:
|
||||||
output = base_output
|
output = base_output
|
||||||
fuse_scaling_add = True
|
|
||||||
|
|
||||||
_gate_up_lora_b_kernel[grid_b](
|
_gate_up_lora_b_kernel[grid_b](
|
||||||
x,
|
x,
|
||||||
@@ -169,7 +181,6 @@ def gate_up_lora_b_fwd(
|
|||||||
BLOCK_S,
|
BLOCK_S,
|
||||||
BLOCK_OUT,
|
BLOCK_OUT,
|
||||||
BLOCK_R,
|
BLOCK_R,
|
||||||
fuse_scaling_add,
|
|
||||||
batch_info.scalings,
|
batch_info.scalings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -33,29 +33,45 @@ def _qkv_lora_b_kernel(
|
|||||||
BLOCK_S: tl.constexpr,
|
BLOCK_S: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
BLOCK_K: tl.constexpr,
|
BLOCK_K: tl.constexpr,
|
||||||
# For fused output scaling and adding
|
# For fused output scaling
|
||||||
fuse_scaling_add,
|
|
||||||
scalings,
|
scalings,
|
||||||
):
|
):
|
||||||
# This kernel packs 3 sgemms (q/k/v) into a single kernel.
|
"""
|
||||||
|
This kernel packs 3 sgemms (q/k/v) into a single kernel. The multiplication
|
||||||
|
results are accumulated into the output tensor.
|
||||||
|
|
||||||
# x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
|
When a sequence's rank is 0, the kernel is essentially a no-op, following
|
||||||
# weights: (num_lora, N_Q + 2 * N_KV, K)
|
the convention in pytorch where the product of two matrices of shape (m, 0)
|
||||||
# output: (s, N_Q + 2 * N_KV)
|
and (0, n) is an all-zero matrix of shape (m, n).
|
||||||
# N_Q >> K, N_KV >> K
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): The input tensor, which is the result of the LoRA A projection.
|
||||||
|
Shape: (s, 3 * K), where s is the sum of all sequence lengths in the
|
||||||
|
batch and K is the maximum LoRA rank. The second dimension is partitioned
|
||||||
|
for Q, K, and V.
|
||||||
|
weights (Tensor): The LoRA B weights for all adapters.
|
||||||
|
Shape: (num_lora, N_Q + 2 * N_KV, K).
|
||||||
|
output (Tensor): The output tensor where the result is stored.
|
||||||
|
Shape: (s, N_Q + 2 * N_KV).
|
||||||
|
"""
|
||||||
|
|
||||||
# Current block computes sequence with batch_id,
|
# Current block computes sequence with batch_id,
|
||||||
# which starts from row seg_start of x with length seg_len.
|
# which starts from row seg_start of x with length seg_len.
|
||||||
# qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
|
# qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
|
||||||
batch_id = tl.program_id(axis=2)
|
batch_id = tl.program_id(axis=2)
|
||||||
|
w_index = tl.load(weight_indices + batch_id)
|
||||||
|
rank = tl.load(lora_ranks + w_index)
|
||||||
|
|
||||||
|
# If rank is 0, this kernel is a no-op.
|
||||||
|
if rank == 0:
|
||||||
|
return
|
||||||
|
|
||||||
qkv_id = tl.program_id(axis=1)
|
qkv_id = tl.program_id(axis=1)
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
seg_len = tl.load(seg_lens + batch_id)
|
seg_len = tl.load(seg_lens + batch_id)
|
||||||
w_index = tl.load(weight_indices + batch_id)
|
|
||||||
seg_start = tl.load(seg_indptr + batch_id)
|
seg_start = tl.load(seg_indptr + batch_id)
|
||||||
n_start = tl.load(n_offs + qkv_id)
|
n_start = tl.load(n_offs + qkv_id)
|
||||||
n_size = tl.load(n_offs + qkv_id + 1) - n_start
|
n_size = tl.load(n_offs + qkv_id + 1) - n_start
|
||||||
rank = tl.load(lora_ranks + w_index)
|
|
||||||
scaling = tl.load(scalings + w_index)
|
scaling = tl.load(scalings + w_index)
|
||||||
# Adjust K (rank) according to the specific LoRA adapter
|
# Adjust K (rank) according to the specific LoRA adapter
|
||||||
K = tl.minimum(K, rank)
|
K = tl.minimum(K, rank)
|
||||||
@@ -84,13 +100,12 @@ def _qkv_lora_b_kernel(
|
|||||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||||
x_tile = tl.load(
|
x_tile = tl.load(
|
||||||
x_ptrs,
|
x_ptrs,
|
||||||
mask=(s_offset[:, None] < seg_len)
|
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
|
||||||
and (k_offset[None, :] < K - k * BLOCK_K),
|
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
w_tile = tl.load(
|
w_tile = tl.load(
|
||||||
w_ptrs,
|
w_ptrs,
|
||||||
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
|
mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < n_size),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
partial_sum += tl.dot(x_tile, w_tile)
|
partial_sum += tl.dot(x_tile, w_tile)
|
||||||
@@ -105,7 +120,6 @@ def _qkv_lora_b_kernel(
|
|||||||
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
||||||
)
|
)
|
||||||
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
|
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
|
||||||
if fuse_scaling_add:
|
|
||||||
partial_sum += tl.load(output_ptr, mask=output_mask)
|
partial_sum += tl.load(output_ptr, mask=output_mask)
|
||||||
tl.store(output_ptr, partial_sum, mask=output_mask)
|
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||||
|
|
||||||
@@ -153,11 +167,9 @@ def qkv_lora_b_fwd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if base_output is None:
|
if base_output is None:
|
||||||
output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
|
output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
|
||||||
fuse_scaling_add = False
|
|
||||||
else:
|
else:
|
||||||
output = base_output
|
output = base_output
|
||||||
fuse_scaling_add = True
|
|
||||||
|
|
||||||
_qkv_lora_b_kernel[grid_b](
|
_qkv_lora_b_kernel[grid_b](
|
||||||
x,
|
x,
|
||||||
@@ -180,7 +192,6 @@ def qkv_lora_b_fwd(
|
|||||||
BLOCK_S,
|
BLOCK_S,
|
||||||
BLOCK_OUT,
|
BLOCK_OUT,
|
||||||
BLOCK_R,
|
BLOCK_R,
|
||||||
fuse_scaling_add,
|
|
||||||
batch_info.scalings,
|
batch_info.scalings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -33,19 +33,36 @@ def _sgemm_lora_a_kernel(
|
|||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
BLOCK_K: tl.constexpr,
|
BLOCK_K: tl.constexpr,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Computes a segmented batched matrix multiplication for the LoRA A matrix.
|
||||||
|
|
||||||
# x: (s, K), s is the sum of sequence lengths
|
The kernel ensures that output[seg_start:seg_start + seg_len, :rank * stack_num]
|
||||||
# weights: (num_lora, N, K)
|
stores the product of the input `x` and the LoRA weights for the corresponding
|
||||||
# output: (s, N)
|
sequence. This implies that when rank is 0, the kernel is essentially a no-op,
|
||||||
|
as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
|
||||||
|
is the sum of all sequence lengths in the batch.
|
||||||
|
weights (torch.Tensor): The LoRA 'A' weights for all available adapters,
|
||||||
|
with shape `(num_lora, N, K)`.
|
||||||
|
output (torch.Tensor): The output tensor of shape `(s, N)`.
|
||||||
|
"""
|
||||||
|
|
||||||
# Current block computes sequence with batch_id,
|
# Current block computes sequence with batch_id,
|
||||||
# which starts from row seg_start of x with length seg_len
|
# which starts from row seg_start of x with length seg_len
|
||||||
batch_id = tl.program_id(axis=1)
|
batch_id = tl.program_id(axis=1)
|
||||||
pid = tl.program_id(axis=0)
|
|
||||||
seg_len = tl.load(seg_lens + batch_id)
|
|
||||||
w_index = tl.load(weight_indices + batch_id)
|
w_index = tl.load(weight_indices + batch_id)
|
||||||
seg_start = tl.load(seg_indptr + batch_id)
|
|
||||||
rank = tl.load(lora_ranks + w_index)
|
rank = tl.load(lora_ranks + w_index)
|
||||||
|
|
||||||
|
# If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
|
||||||
|
if rank == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
seg_start = tl.load(seg_indptr + batch_id)
|
||||||
|
seg_len = tl.load(seg_lens + batch_id)
|
||||||
|
|
||||||
# Adjust N (stack_num * max_rank) according to the specific LoRA adapter
|
# Adjust N (stack_num * max_rank) according to the specific LoRA adapter
|
||||||
N = tl.minimum(N, rank * stack_num)
|
N = tl.minimum(N, rank * stack_num)
|
||||||
|
|
||||||
@@ -72,13 +89,12 @@ def _sgemm_lora_a_kernel(
|
|||||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||||
x_tile = tl.load(
|
x_tile = tl.load(
|
||||||
x_ptrs,
|
x_ptrs,
|
||||||
mask=(s_offset[:, None] < seg_len)
|
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
|
||||||
and (k_offset[None, :] < K - k * BLOCK_K),
|
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
w_tile = tl.load(
|
w_tile = tl.load(
|
||||||
w_ptrs,
|
w_ptrs,
|
||||||
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
|
mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
partial_sum += tl.dot(x_tile, w_tile)
|
partial_sum += tl.dot(x_tile, w_tile)
|
||||||
@@ -91,7 +107,7 @@ def _sgemm_lora_a_kernel(
|
|||||||
output_ptr = (output + seg_start * output_stride_0) + (
|
output_ptr = (output + seg_start * output_stride_0) + (
|
||||||
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
||||||
)
|
)
|
||||||
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
|
output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N)
|
||||||
tl.store(output_ptr, partial_sum, mask=output_mask)
|
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,22 +31,39 @@ def _sgemm_lora_b_kernel(
|
|||||||
BLOCK_S: tl.constexpr,
|
BLOCK_S: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
BLOCK_K: tl.constexpr,
|
BLOCK_K: tl.constexpr,
|
||||||
# For fused output scaling and adding
|
# For fused output scaling
|
||||||
fuse_scaling_add,
|
|
||||||
scalings,
|
scalings,
|
||||||
):
|
):
|
||||||
# x: (s, K), s is the sum of sequence lengths
|
"""
|
||||||
# weights: (num_lora, N, K)
|
Computes a segmented batched matrix multiplication for the LoRA B matrix
|
||||||
# output: (s, N)
|
and adds the result to the output in-place.
|
||||||
|
|
||||||
|
When a sequence's rank is 0, the kernel is essentially a no-op, following
|
||||||
|
the convention in pytorch where the product of two matrices of shape (m, 0)
|
||||||
|
and (0, n) is an all-zero matrix of shape (m, n).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication,
|
||||||
|
of shape `(s, K)`, where `s` is the total number of tokens.
|
||||||
|
weights (torch.Tensor): The LoRA 'B' weights for all available adapters,
|
||||||
|
with shape `(num_lora, N, K)`.
|
||||||
|
output (torch.Tensor): The output tensor of shape `(s, N)`. This can be
|
||||||
|
the base model's output for a fused add operation.
|
||||||
|
"""
|
||||||
|
|
||||||
# Current block computes sequence with batch_id,
|
# Current block computes sequence with batch_id,
|
||||||
# which starts from row seg_start of x with length seg_len
|
# which starts from row seg_start of x with length seg_len
|
||||||
batch_id = tl.program_id(axis=1)
|
batch_id = tl.program_id(axis=1)
|
||||||
|
w_index = tl.load(weight_indices + batch_id)
|
||||||
|
rank = tl.load(lora_ranks + w_index)
|
||||||
|
|
||||||
|
# If rank is 0, this kernel is a no-op.
|
||||||
|
if rank == 0:
|
||||||
|
return
|
||||||
|
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
seg_len = tl.load(seg_lens + batch_id)
|
seg_len = tl.load(seg_lens + batch_id)
|
||||||
w_index = tl.load(weight_indices + batch_id)
|
|
||||||
seg_start = tl.load(seg_indptr + batch_id)
|
seg_start = tl.load(seg_indptr + batch_id)
|
||||||
rank = tl.load(lora_ranks + w_index)
|
|
||||||
scaling = tl.load(scalings + w_index)
|
scaling = tl.load(scalings + w_index)
|
||||||
# Adjust K (rank) according to the specific LoRA adapter
|
# Adjust K (rank) according to the specific LoRA adapter
|
||||||
K = tl.minimum(K, rank)
|
K = tl.minimum(K, rank)
|
||||||
@@ -74,8 +91,7 @@ def _sgemm_lora_b_kernel(
|
|||||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||||
x_tile = tl.load(
|
x_tile = tl.load(
|
||||||
x_ptrs,
|
x_ptrs,
|
||||||
mask=(s_offset[:, None] < seg_len)
|
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
|
||||||
and (k_offset[None, :] < K - k * BLOCK_K),
|
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
w_tile = tl.load(
|
w_tile = tl.load(
|
||||||
@@ -95,7 +111,6 @@ def _sgemm_lora_b_kernel(
|
|||||||
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
||||||
)
|
)
|
||||||
output_mask = s_offset[:, None] < seg_len
|
output_mask = s_offset[:, None] < seg_len
|
||||||
if fuse_scaling_add:
|
|
||||||
partial_sum += tl.load(output_ptr, mask=output_mask)
|
partial_sum += tl.load(output_ptr, mask=output_mask)
|
||||||
tl.store(output_ptr, partial_sum, mask=output_mask)
|
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||||
|
|
||||||
@@ -132,11 +147,9 @@ def sgemm_lora_b_fwd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if base_output is None:
|
if base_output is None:
|
||||||
output = torch.empty((S, N), device=x.device, dtype=x.dtype)
|
output = torch.zeros((S, N), device=x.device, dtype=x.dtype)
|
||||||
fuse_scaling_add = False
|
|
||||||
else:
|
else:
|
||||||
output = base_output
|
output = base_output
|
||||||
fuse_scaling_add = True
|
|
||||||
|
|
||||||
_sgemm_lora_b_kernel[grid](
|
_sgemm_lora_b_kernel[grid](
|
||||||
x,
|
x,
|
||||||
@@ -158,7 +171,6 @@ def sgemm_lora_b_fwd(
|
|||||||
BLOCK_S,
|
BLOCK_S,
|
||||||
BLOCK_N,
|
BLOCK_N,
|
||||||
BLOCK_R,
|
BLOCK_R,
|
||||||
fuse_scaling_add,
|
|
||||||
batch_info.scalings,
|
batch_info.scalings,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|||||||
Reference in New Issue
Block a user