[2/2] Introduce Chunked-SGMV kernels and corresponding LoRA backend for improved performance (#10286)

Lifu Huang
2025-09-15 16:04:03 -07:00
committed by GitHub
parent 2689f0bf02
commit 3f41b48c40
10 changed files with 1499 additions and 13 deletions

View File

@@ -143,10 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
return TritonLoRABackend
# elif name == "csgmv":
# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
elif name == "csgmv":
from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
# return ChunkedSgmvLoRABackend
return ChunkedSgmvLoRABackend
elif name == "flashinfer":
raise ValueError(
"FlashInfer LoRA backend has been deprecated, please use `triton` instead."

View File

@@ -0,0 +1,306 @@
from typing import Optional
import torch
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.triton_ops import (
chunked_sgmv_lora_expand_forward,
chunked_sgmv_lora_shrink_forward,
)
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class ChunkedSgmvLoRABackend(BaseLoRABackend):
"""
Chunked LoRA backend using segmented matrix-vector multiplication.
This backend is largely based on the SGMV (Segmented Gather Matrix-Vector multiplication) algorithm
introduced in the Punica paper (https://arxiv.org/pdf/2310.18547). The main variation made here is to
segment the input sequences into fixed-size chunks, which reduces excessive kernel launches, especially
when the LoRA adapter distribution across the batch is skewed.
"""
name = "csgmv"
def __init__(self, max_loras_per_batch: int, device: torch.device):
super().__init__(max_loras_per_batch, device)
self.segment_size = 16 # TODO (lifuhuang): make it configurable?
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return chunked_sgmv_lora_shrink_forward(
x,
weights,
self.batch_info,
)
def run_lora_b_sgemm(
self,
x: torch.Tensor,
weights: torch.Tensor,
output_offset: torch.Tensor,
base_output: torch.Tensor = None,
*args,
**kwargs
) -> torch.Tensor:
# For simple LoRA B (single slice), we use slice offsets [0, output_dim]
output_dim = weights.shape[-2]
max_slice_size = output_dim
return chunked_sgmv_lora_expand_forward(
x=x,
lora_weight_b=weights,
batch_info=self.batch_info,
slice_offsets=output_offset,
max_slice_size=max_slice_size,
base_output=base_output,
)
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: torch.Tensor,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
*args,
**kwargs
) -> torch.Tensor:
# x: (s, input_dim)
# qkv_lora_a: (num_lora, 3 * r, input_dim)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
assert isinstance(qkv_lora_b, torch.Tensor)
lora_a_output = chunked_sgmv_lora_shrink_forward(
x,
qkv_lora_a,
self.batch_info,
num_slices=3,
)
lora_output = chunked_sgmv_lora_expand_forward(
x=lora_a_output,
lora_weight_b=qkv_lora_b,
batch_info=self.batch_info,
slice_offsets=output_offset,
max_slice_size=max_qkv_out_dim,
base_output=base_output,
)
return lora_output
def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: torch.Tensor,
output_offset: torch.Tensor,
base_output: torch.Tensor = None,
*args,
**kwargs
) -> torch.Tensor:
# x: (s, input_dim)
# gate_up_lora_a: (num_lora, 2 * r, input_dim)
# gate_up_lora_b: (num_lora, 2 * output_dim, r)
assert isinstance(gate_up_lora_b, torch.Tensor)
output_dim = gate_up_lora_b.shape[-2] // 2
# lora_a_output: (s, 2 * r)
lora_a_output = chunked_sgmv_lora_shrink_forward(
x,
gate_up_lora_a,
self.batch_info,
num_slices=2,
)
lora_output = chunked_sgmv_lora_expand_forward(
x=lora_a_output,
lora_weight_b=gate_up_lora_b,
batch_info=self.batch_info,
slice_offsets=output_offset,
max_slice_size=output_dim,
base_output=base_output,
)
return lora_output
def prepare_lora_batch(
self,
forward_batch: ForwardBatch,
weight_indices: list[int],
lora_ranks: list[int],
scalings: list[float],
batch_info: Optional[LoRABatchInfo] = None,
):
permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation(
weight_indices, forward_batch
)
seg_weight_indices, seg_indptr = self._get_segments_info(
weight_indices_reordered
)
num_segments = len(seg_weight_indices)
lora_ranks_tensor = torch.tensor(
lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
)
scalings_tensor = torch.tensor(
scalings, dtype=torch.float, pin_memory=True, device="cpu"
)
if batch_info is None:
batch_info = LoRABatchInfo(
bs=forward_batch.batch_size,
num_segments=num_segments,
use_cuda_graph=False,
seg_indptr=torch.empty(
(num_segments + 1,), dtype=torch.int32, device=self.device
),
weight_indices=torch.empty(
(num_segments,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int32, device=self.device
),
scalings=torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
),
permutation=torch.empty(
(len(permutation),), dtype=torch.int32, device=self.device
),
# Not used in chunked kernels
max_len=None,
seg_lens=None,
)
else:
batch_info.bs = forward_batch.batch_size
batch_info.num_segments = num_segments
# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
lora_ranks_tensor, non_blocking=True
)
batch_info.scalings[: self.max_loras_per_batch].copy_(
scalings_tensor, non_blocking=True
)
batch_info.weight_indices[:num_segments].copy_(
seg_weight_indices, non_blocking=True
)
batch_info.seg_indptr[: num_segments + 1].copy_(seg_indptr, non_blocking=True)
batch_info.permutation[: len(permutation)].copy_(permutation, non_blocking=True)
self.batch_info = batch_info
@staticmethod
def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
"""
Computes permutation indices for reordering tokens by their LoRA adapter assignments.
This function implements the "gather" step in Chunked Segmented Gather Matrix Vector
multiplication by creating a permutation that groups tokens by their LoRA adapter.
Tokens using the same LoRA adapter are placed together to enable efficient batched
computation.
Example:
seq_weight_indices = [0, 1, 0] # 3 sequences using adapters [0, 1, 0]
extend_seq_lens = [2, 1, 3] # sequence lengths [2, 1, 3 tokens]
# Creates row_weight_indices: [0, 0, 1, 0, 0, 0] (6 tokens total)
# Returns permutation: [0, 1, 3, 4, 5, 2] (groups adapter 0 tokens together)
# weights_reordered: [0, 0, 0, 0, 0, 1] (sorted by adapter)
Args:
seq_weight_indices: List of LoRA adapter indices for each sequence
forward_batch (ForwardBatch): Batch information containing sequence lengths
Returns:
tuple: (permutation, weights_reordered) where:
- permutation: Token reordering indices to group by adapter
- weights_reordered: Sorted adapter indices for each token
"""
with torch.device("cpu"):
seq_weight_indices = torch.tensor(seq_weight_indices, dtype=torch.int32)
seg_lens_cpu = (
torch.tensor(
forward_batch.extend_seq_lens_cpu,
dtype=torch.int32,
)
if forward_batch.forward_mode.is_extend()
else torch.ones(forward_batch.batch_size, dtype=torch.int32)
)
row_weight_indices = torch.repeat_interleave(
seq_weight_indices, seg_lens_cpu
)
permutation = torch.empty(
(len(row_weight_indices),), dtype=torch.long, pin_memory=True
)
torch.argsort(row_weight_indices, stable=True, out=permutation)
weights_reordered = row_weight_indices[permutation]
return permutation, weights_reordered
def _get_segments_info(self, weights_reordered: torch.Tensor):
"""
Computes segment information for chunked SGMV operations.
This function takes the reordered weight indices and creates segments of fixed size
(self.segment_size) for efficient kernel execution. Each segment contains tokens
that use the same LoRA adapter, enabling vectorized computation.
The segmentation is necessary because:
1. GPU kernels work efficiently on fixed-size blocks
2. Large groups of tokens using the same adapter are split into manageable chunks
3. Each segment can be processed independently in parallel
Example:
weights_reordered = [0, 0, 0, 0, 0, 1] # 5 tokens with adapter 0, 1 with adapter 1
segment_size = 3
# Creates segments:
# Segment 0: tokens 0-2 (adapter 0), length=3
# Segment 1: tokens 3-4 (adapter 0), length=2
# Segment 2: token 5 (adapter 1), length=1
# Returns:
# weight_indices_list: [0, 0, 1] (adapter for each segment)
# seg_indptr: [0, 3, 5, 6] (cumulative segment boundaries)
Args:
weights_reordered (torch.Tensor): Sorted adapter indices for each token
Returns:
tuple: (weight_indices_list, seg_indptr) where:
- weight_indices_list: LoRA adapter index for each segment
- seg_indptr: Cumulative segment boundaries (CSR-style indptr)
"""
with torch.device("cpu"):
unique_weights, counts = torch.unique_consecutive(
weights_reordered, return_counts=True
)
weight_indices_list = []
seg_lens_list = []
for weight_idx, group_len in zip(unique_weights, counts):
group_len = group_len.item()
num_segs = (group_len + self.segment_size - 1) // self.segment_size
weight_indices_list.extend([weight_idx.item()] * num_segs)
seg_lens_list.extend([self.segment_size] * (num_segs - 1))
seg_lens_list.append(group_len - (num_segs - 1) * self.segment_size)
seg_lens = torch.tensor(seg_lens_list, dtype=torch.int32)
weight_indices_list = torch.tensor(
weight_indices_list, dtype=torch.int32, pin_memory=True
)
seg_indptr = torch.empty(
(len(seg_lens) + 1,), dtype=torch.int32, pin_memory=True
)
seg_indptr[0] = 0
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
return weight_indices_list, seg_indptr
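
To make the gather and segmentation bookkeeping above concrete, here is a minimal self-contained sketch (plain PyTorch, no sglang imports) that reproduces the numbers from the two docstring examples; segment_size=3 is chosen to match the example rather than the backend's default of 16:

import torch

# Example inputs from the _get_permutation docstring.
seq_weight_indices = torch.tensor([0, 1, 0], dtype=torch.int32)  # adapter per sequence
seg_lens = torch.tensor([2, 1, 3], dtype=torch.int32)            # tokens per sequence

# Gather step: expand per-sequence adapter ids to per-token ids, then
# stably sort so tokens sharing an adapter become contiguous.
row_weight_indices = torch.repeat_interleave(seq_weight_indices, seg_lens)  # [0, 0, 1, 0, 0, 0]
permutation = torch.argsort(row_weight_indices, stable=True)                # [0, 1, 3, 4, 5, 2]
weights_reordered = row_weight_indices[permutation]                         # [0, 0, 0, 0, 0, 1]

# Segmentation step: split each run of same-adapter tokens into chunks
# of at most segment_size tokens.
segment_size = 3
unique_weights, counts = torch.unique_consecutive(weights_reordered, return_counts=True)
weight_indices, seg_lens_out = [], []
for w, n in zip(unique_weights.tolist(), counts.tolist()):
    num_segs = (n + segment_size - 1) // segment_size
    weight_indices += [w] * num_segs
    seg_lens_out += [segment_size] * (num_segs - 1) + [n - (num_segs - 1) * segment_size]

assert weight_indices == [0, 0, 1]
assert seg_lens_out == [3, 2, 1]  # cumulative boundaries (seg_indptr): [0, 3, 5, 6]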

View File

@@ -28,14 +28,15 @@ from torch import nn
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader
logger = logging.getLogger(__name__)
SUPPORTED_BACKENDS = (TritonLoRABackend, ChunkedSgmvLoRABackend)
class LoRALayer(nn.Module):
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
@@ -48,6 +49,7 @@ class LoRALayer(nn.Module):
class LoRAAdapter(nn.Module):
def __init__(
self,
uid: str,
@@ -159,8 +161,8 @@ class LoRAAdapter(nn.Module):
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
if up_name not in weights:
weights[up_name] = torch.zeros_like(weights[weight_name])
assert isinstance(self.lora_backend, TritonLoRABackend), (
f"LoRA weight initialization currently only supported for 'triton' backend. "
assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), (
f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}"
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
f"or consider implementing custom initialization logic for other backends."
)

View File

@@ -1,3 +1,5 @@
from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward
from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward
from .gate_up_lora_b import gate_up_lora_b_fwd
from .qkv_lora_b import qkv_lora_b_fwd
from .sgemm_lora_a import sgemm_lora_a_fwd
@@ -8,4 +10,6 @@ __all__ = [
"qkv_lora_b_fwd",
"sgemm_lora_a_fwd",
"sgemm_lora_b_fwd",
"chunked_sgmv_lora_shrink_forward",
"chunked_sgmv_lora_expand_forward",
]

View File

@@ -0,0 +1,211 @@
from typing import Optional
import torch
import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit
def _chunked_lora_expand_kernel(
# Pointers to matrices
x,
weights,
output,
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_indptr,
weight_indices,
lora_ranks,
permutation,
num_segs,
# For fused output scaling
scalings,
# Offsets of q/k/v slice on output dimension
slice_offsets,
# Meta parameters
NUM_SLICES: tl.constexpr,
MAX_RANK: tl.constexpr, # K = R
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""
Computes a chunked SGMV for LoRA expand operations.
When a sequence's rank is 0, the kernel is essentially a no-op, following
the convention in PyTorch where the product of two matrices of shape (m, 0)
and (0, n) is an all-zero matrix of shape (m, n).
Args:
x (Tensor): The input tensor, which is the result of the LoRA A projection.
Shape: (s, num_slices * K), where s is the sum of all sequence lengths in the
batch and K is the maximum LoRA rank.
weights (Tensor): The LoRA B weights for all adapters.
Shape: (num_lora, output_dim, K).
output (Tensor): The output tensor where the result is stored.
Shape: (s, output_dim).
"""
tl.static_assert(NUM_SLICES <= 3)
pid_s = tl.program_id(axis=2)
if pid_s >= num_segs:
return
# Each program computes one segment of tokens that share a LoRA adapter,
# spanning rows seg_start:seg_end of the (permuted) input.
# slice_id selects which output slice to compute (e.g., 0: q, 1: k, 2: v).
w_index = tl.load(weight_indices + pid_s)
cur_rank = tl.load(lora_ranks + w_index)
# If rank is 0, this kernel is a no-op.
if cur_rank == 0:
return
seg_start = tl.load(seg_indptr + pid_s)
seg_end = tl.load(seg_indptr + pid_s + 1)
slice_id = tl.program_id(axis=1)
slice_start = tl.load(slice_offsets + slice_id)
slice_end = tl.load(slice_offsets + slice_id + 1)
scaling = tl.load(scalings + w_index)
# Adjust K (rank) according to the specific LoRA adapter
cur_rank = tl.minimum(MAX_RANK, cur_rank)
# Map logical sequence index to physical index
s_offset_logical = tl.arange(0, BLOCK_S) + seg_start
s_offset_physical = tl.load(
permutation + s_offset_logical, mask=s_offset_logical < seg_end
)
# Create pointers for the first block of x and weights[w_index][n_start:n_end, :]
# The pointers will be advanced as we move in the K direction
# and accumulate.
pid_n = tl.program_id(axis=0)
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_start
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (
x
+ slice_id * cur_rank * x_stride_1
+ (s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1)
)
w_ptrs = (weights + w_index * w_stride_0) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(cur_rank, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset_logical[:, None] < seg_end)
& (k_offset[None, :] < cur_rank - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < cur_rank - k * BLOCK_K)
& (n_offset[None, :] < slice_end),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum *= scaling
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = output + (
s_offset_physical[:, None] * output_stride_0
+ n_offset[None, :] * output_stride_1
)
output_mask = (s_offset_logical[:, None] < seg_end) & (
n_offset[None, :] < slice_end
)
partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0)
tl.store(output_ptr, partial_sum, mask=output_mask)
def chunked_sgmv_lora_expand_forward(
x: torch.Tensor,
lora_weight_b: torch.Tensor,
batch_info: LoRABatchInfo,
slice_offsets: torch.Tensor,
max_slice_size: int,
base_output: torch.Tensor = None,
) -> torch.Tensor:
# x: (s, slice_num * r)
# lora_weight_b: (num_lora, output_dim, r)
# slice_offsets: boundaries for different slices in the output dimension
# output: (s, output_dim)
# Compute lora_output with shape (s, output_dim) as follows:
# For each slice i, accumulate:
# lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], lora_weight_b[:, slice_offsets[i]:slice_offsets[i+1], :])
# Get dims
s = x.shape[0]
input_dim = x.shape[1]
max_lora_rank = lora_weight_b.shape[-1]
output_dim = lora_weight_b.shape[-2]
num_slices = len(slice_offsets) - 1
assert input_dim == num_slices * max_lora_rank
# TODO (lifuhuang): fine-tune per operation
BLOCK_S = 16
BLOCK_K = 16
BLOCK_N = 64
num_segments = batch_info.num_segments
grid = (
triton.cdiv(max_slice_size, BLOCK_N),
num_slices, # number of slices in the input/output
batch_info.bs if batch_info.use_cuda_graph else num_segments,
)
if base_output is None:
output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
else:
output = base_output
_chunked_lora_expand_kernel[grid](
x=x,
weights=lora_weight_b,
output=output,
x_stride_0=x.stride(0),
x_stride_1=x.stride(1),
w_stride_0=lora_weight_b.stride(0),
w_stride_1=lora_weight_b.stride(1),
w_stride_2=lora_weight_b.stride(2),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
seg_indptr=batch_info.seg_indptr,
weight_indices=batch_info.weight_indices,
lora_ranks=batch_info.lora_ranks,
permutation=batch_info.permutation,
num_segs=num_segments,
scalings=batch_info.scalings,
slice_offsets=slice_offsets,
# constants
NUM_SLICES=num_slices,
MAX_RANK=max_lora_rank,
BLOCK_S=BLOCK_S,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
return output
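
As a sanity reference for what the kernel above computes, the following dense per-token PyTorch sketch produces the same result (including the fused scaling and the rank-0 no-op). It indexes adapters directly per token instead of going through the permutation/segment machinery, which is mathematically equivalent; the function and argument names are illustrative, not sglang API:

import torch

def ref_lora_expand(x, lora_weight_b, token_weight_indices, lora_ranks,
                    scalings, slice_offsets, base_output=None):
    # x: (s, num_slices * max_rank), packed shrink output; each row stores its
    #    slices contiguously with stride r, where r is that token's adapter rank.
    # lora_weight_b: (num_lora, output_dim, max_rank)
    s, output_dim = x.shape[0], lora_weight_b.shape[1]
    num_slices = len(slice_offsets) - 1
    out = base_output if base_output is not None else torch.zeros(
        (s, output_dim), dtype=x.dtype, device=x.device)
    for i in range(s):
        w = int(token_weight_indices[i])
        r = int(lora_ranks[w])
        if r == 0:
            continue  # rank-0 adapter: no-op, matching the kernel's early return
        for sl in range(num_slices):
            lo, hi = int(slice_offsets[sl]), int(slice_offsets[sl + 1])
            xs = x[i, sl * r: sl * r + r]  # this token's slice-sl activations
            out[i, lo:hi] += float(scalings[w]) * (lora_weight_b[w, lo:hi, :r] @ xs)
    return out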

View File

@@ -0,0 +1,177 @@
import torch
import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit
def _chunked_lora_shrink_kernel(
# Pointers to matrices
x,
weights,
output,
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths, ranks, and weight ids
seg_indptr,
weight_indices,
lora_ranks,
permutation,
num_segs,
# Meta parameters
N: tl.constexpr, # num_slices * r
K: tl.constexpr, # input_dim
NUM_SLICES: tl.constexpr,
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""
Computes a chunked SGMV for LoRA shrink operations.
The kernel ensures that output[seg_start:seg_start + seg_len, :rank * num_slices]
stores the product of the input `x` and the LoRA weights for the corresponding
sequence. This implies that when rank is 0, the kernel is essentially a no-op,
as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
Args:
x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
is the sum of all sequence lengths in the batch.
weights (torch.Tensor): The LoRA A weights for all available adapters,
with shape `(num_lora, N, K)` where N = num_slices * r.
output (torch.Tensor): The output tensor of shape `(s, N)`.
"""
pid_s = tl.program_id(1)
if pid_s >= num_segs:
return
pid_n = tl.program_id(0)
# Each program computes one segment of tokens that share a LoRA adapter,
# spanning rows seg_start:seg_end of the (permuted) input.
w_index = tl.load(weight_indices + pid_s)
rank = tl.load(lora_ranks + w_index)
# If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
if rank == 0:
return
seg_start = tl.load(seg_indptr + pid_s)
seg_end = tl.load(seg_indptr + pid_s + 1)
# Adjust N dim according to the specific LoRA adapter
cur_n = tl.minimum(N, rank * NUM_SLICES)
# Map logical sequence index to physical index
s_offset_logical = tl.arange(0, BLOCK_S) + seg_start
s_offset_physical = tl.load(
permutation + s_offset_logical, mask=s_offset_logical < seg_end
)
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = x + (
s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset_logical[:, None] < seg_end)
& (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < cur_n),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = output + (
s_offset_physical[:, None] * output_stride_0
+ n_offset[None, :] * output_stride_1
)
output_mask = (s_offset_logical[:, None] < seg_end) & (n_offset[None, :] < cur_n)
tl.store(output_ptr, partial_sum, mask=output_mask)
def chunked_sgmv_lora_shrink_forward(
x: torch.Tensor,
weights: torch.Tensor,
batch_info: LoRABatchInfo,
num_slices: int = 1,
) -> torch.Tensor:
# x: (s, input_dim)
# weights: (num_lora, num_slices * r, input_dim)
# output: (s, num_slices * r)
# num_slices: qkv=3, gate_up=2, others=1
# when called with multiple slices, the weights.shape[-2] will be num_slices * r
# input_dim is much larger than r
assert x.is_contiguous()
assert weights.is_contiguous()
assert len(x.shape) == 2
assert len(weights.shape) == 3
# Block shapes
# TODO (lifuhuang): experiment with split-k
BLOCK_S = 16
BLOCK_N = 16
BLOCK_K = 256
S = x.shape[0]
N = weights.shape[1]
K = weights.shape[2]
assert x.shape[-1] == K
num_segments = batch_info.num_segments
grid = (
triton.cdiv(N, BLOCK_N),
batch_info.bs if batch_info.use_cuda_graph else num_segments,
)
output = torch.empty((S, N), device=x.device, dtype=x.dtype)
_chunked_lora_shrink_kernel[grid](
x=x,
weights=weights,
output=output,
x_stride_0=x.stride(0),
x_stride_1=x.stride(1),
w_stride_0=weights.stride(0),
w_stride_1=weights.stride(1),
w_stride_2=weights.stride(2),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
seg_indptr=batch_info.seg_indptr,
weight_indices=batch_info.weight_indices,
lora_ranks=batch_info.lora_ranks,
permutation=batch_info.permutation,
num_segs=num_segments,
# constants
N=N,
K=K,
NUM_SLICES=num_slices,
BLOCK_S=BLOCK_S,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
return output
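
And the matching dense reference for the shrink side; note that no scaling is applied here (the scaling is fused into the expand kernel), and rank-0 adapters simply leave their output columns untouched. Again a sketch with illustrative names:

import torch

def ref_lora_shrink(x, lora_weight_a, token_weight_indices, lora_ranks, num_slices=1):
    # x: (s, input_dim); lora_weight_a: (num_lora, num_slices * max_rank, input_dim)
    s = x.shape[0]
    n = lora_weight_a.shape[1]  # num_slices * max_rank
    out = torch.zeros((s, n), dtype=x.dtype, device=x.device)
    for i in range(s):
        w = int(token_weight_indices[i])
        cur_n = min(n, int(lora_ranks[w]) * num_slices)
        if cur_n == 0:
            continue  # rank-0 adapter: columns stay zero (the kernel leaves them unwritten)
        out[i, :cur_n] = lora_weight_a[w, :cur_n, :] @ x[i]
    return out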

View File

@@ -110,6 +110,8 @@ ATTENTION_BACKEND_CHOICES = [
"ascend",
]
LORA_BACKEND_CHOICES = ["triton", "csgmv"]
DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
@@ -1601,7 +1603,8 @@ class ServerArgs:
parser.add_argument(
"--lora-backend",
type=str,
default="triton",
choices=LORA_BACKEND_CHOICES,
default=ServerArgs.lora_backend,
help="Choose the kernel backend for multi-LoRA serving.",
)
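
With the new "csgmv" choice registered, the chunked backend can be selected at launch time, e.g.: python -m sglang.launch_server --model-path <model> --lora-paths lora0=<adapter_dir> --lora-backend csgmv (model and adapter paths are placeholders).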