[Feature] Define backends and add Triton backend for Lora (#3161)

Author: Baizhou Zhang
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Date: 2025-02-03 22:09:13 -08:00 (committed by GitHub)
commit 70817a7eae (parent 7b5a374114)
18 changed files with 1129 additions and 135 deletions

View File

@@ -0,0 +1,8 @@
from .base_backend import BaseLoraBackend
from .flashinfer_backend import FlashInferLoraBackend
from .triton_backend import TritonLoraBackend
__all__ = [
"FlashInferLoraBackend",
"TritonLoraBackend",
]

View File

@@ -0,0 +1,95 @@
from typing import Tuple, Union
import torch
from sglang.srt.lora.lora import LoraBatchInfo
def get_fuse_output_scaling_add_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)
def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)
class BaseLoraBackend:
"""Base class for different Lora backends.
Each backend has its own implementation of Lora kernels.
Args:
name: name of backend
batch_info: information of current batch for use
fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
and the operation of scaling and adding will be fused into kernel
"""
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
self.name = name
self.batch_info = batch_info
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of lora a modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
usually input_dim is much larger than r
Returns:
result with shape (s, r)
"""
pass
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of lora b modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
weights: a set of lora weights with shape (num_lora, output_dim, r)
usually output_dim is much larger than r
Returns:
result with shape (s, output_dim)
"""
pass
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
) -> torch.Tensor:
"""Run the lora pass for QKV Layer.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
qkv_lora_b: lora_b module for qkv.
If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
If passed in as a tuple of two tensors containing:
a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
Returns:
result with shape (s, output_dim_q + 2 * output_dim_kv)
"""
pass
def set_batch_info(self, batch_info: LoraBatchInfo):
self.batch_info = batch_info
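
To make the contract above concrete, here is a minimal sketch (not part of the commit) of how a LoRA layer is expected to drive any BaseLoraBackend subclass. The names apply_lora_sketch, A_buffer, and B_buffer are hypothetical stand-ins; the buffer shapes are the ones given in the docstrings above.

def apply_lora_sketch(backend, x, A_buffer, B_buffer, base_output, scaling):
    # A_buffer: (num_lora, r, input_dim); B_buffer: (num_lora, output_dim, r)
    lora_a_out = backend.run_lora_a_sgemm(x, A_buffer)  # (s, r)
    lora_out = backend.run_lora_b_sgemm(
        lora_a_out, B_buffer, base_output=base_output, scaling=scaling
    )  # (s, output_dim)
    if backend.fuse_output_scaling_add:
        # The kernel already applied the scaling and added base_output.
        return lora_out
    return base_output + lora_out * scaling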

View File

@@ -0,0 +1,88 @@
from typing import Tuple
import torch
from flashinfer import SegmentGEMMWrapper
from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
class FlashInferLoraBackend(BaseLoraBackend):
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
super().__init__(name, batch_info)
# Set up SGemm Wrapper from flashinfer
# FIXME wait for flashinfer segment gemm update
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:
# Shape of lora_a_output: (s, 3 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
q_lora_b, kv_lora_b = qkv_lora_b
lora_rank = kv_lora_b.shape[-1]
output_dim_q = q_lora_b.shape[-2]
output_dim_kv = kv_lora_b.shape[-2]
lora_output = torch.empty(
(x.shape[0], output_dim_q + 2 * output_dim_kv),
device=x.device,
dtype=x.dtype,
)
# q
lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
)
# kv
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
weights=kv_lora_b[0],
)
)
lora_output[
:, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
] = self.run_lora_b_sgemm(
x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
weights=kv_lora_b[1],
)
return lora_output
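
For readers unfamiliar with segmented GEMM, the following plain-PyTorch loop sketches the semantics of segment_gemm.run as used above. It is a reference for clarity under the shapes stated in the base class, not the kernel that actually executes.

import torch

def segment_gemm_reference(x, weights, seg_indptr, weight_indices):
    # x: (s, input_dim); weights: (num_lora, r, input_dim), column-major layout.
    # Rows seg_indptr[i]:seg_indptr[i+1] of x belong to sequence i and are
    # multiplied by the adapter selected by weight_indices[i].
    out = torch.empty((x.shape[0], weights.shape[1]), device=x.device, dtype=x.dtype)
    for i in range(seg_indptr.numel() - 1):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        out[start:end] = x[start:end] @ weights[weight_indices[i]].T
    return out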

View File

@@ -0,0 +1,61 @@
import torch
from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops import (
qkv_lora_b_fwd,
sgemm_lora_a_fwd,
sgemm_lora_b_fwd,
)
class TritonLoraBackend(BaseLoraBackend):
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
super().__init__(name, batch_info)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return sgemm_lora_a_fwd(x, weights, self.batch_info)
def run_lora_b_sgemm(
self,
x: torch.Tensor,
weights: torch.Tensor,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: torch.Tensor,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
# x: (s, input_dim)
# qkv_lora_a: (num_lora, 3 * r, input_dim)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
assert isinstance(qkv_lora_b, torch.Tensor)
lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
lora_output = qkv_lora_b_fwd(
lora_a_output,
qkv_lora_b,
self.batch_info,
output_offset,
max_qkv_out_dim,
base_output,
scaling,
)
return lora_output
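
The auxiliary inputs output_offset and max_qkv_out_dim are prepared by the QKV LoRA layer (see set_lora_info later in this commit). A sketch with illustrative dimensions:

import torch

output_dim_q, output_dim_kv = 4096, 512  # illustrative sizes, not from the commit
output_offset = torch.tensor(
    [0, output_dim_q, output_dim_q + output_dim_kv, output_dim_q + 2 * output_dim_kv],
    dtype=torch.int32,
    device="cuda",
)
max_qkv_out_dim = max(output_dim_q, output_dim_kv)  # bounds the kernel grid along N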

View File

@@ -18,8 +18,8 @@
# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
import re
from dataclasses import dataclass
import torch
from torch import nn
@@ -34,14 +34,32 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_loader.loader import DefaultModelLoader
@dataclass
class LoraBatchInfo:
# Batch size
bs: int
# Lengths of each sequence in shape (bs,)
seg_lens: torch.Tensor
    # Index pointers (cumulative sequence lengths) in shape (bs + 1,)
seg_indptr: torch.Tensor
# Maximum sequence length of current batch
max_len: int
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices: torch.Tensor
class BaseLayerWithLoRA(nn.Module):
def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
def __init__(self, base_layer, lora_rank, scaling, lora_backend):
super().__init__()
self.base_layer = base_layer
self.segment_gemm = segment_gemm
self.lora_rank = lora_rank
self.scaling = scaling
self.set_lora = False
self.lora_backend = lora_backend
def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
@@ -52,17 +70,17 @@ class BaseLayerWithLoRA(nn.Module):
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
self.weight = base_layer.weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
# TODO
@@ -88,136 +106,127 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
def set_lora_info(
self,
A_buffer,
B_buffer,
):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# FIXME
lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
output_dim = base_output.shape[-1]
lora_output = torch.empty_like(base_output)
output_dim = lora_output.shape[-1] // 2
for i in range(2):
left = output_dim * i
right = left + output_dim
lora_output[:, left:right] = self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * i : self.lora_rank * (i + 1)
].contiguous(),
weights=self.B_buffer[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
weights=self.B_buffer[0],
)
lora_output[:, output_dim : 2 * output_dim] = (
self.lora_backend.run_lora_b_sgemm(
x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
weights=self.B_buffer[1],
)
)
return base_output + lora_output * self.scaling
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
    def __init__(
self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
self,
A_buffer_qkv,
B_buffer_q,
B_buffer_kv,
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
self.B_buffer_q = B_buffer_q
self.B_buffer_kv = B_buffer_kv
self.bs = bs
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
if self.lora_backend.fuse_qkv_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
self.output_offset = None
self.max_qkv_out_dim = None
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer_qkv,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
lora_output = self.lora_backend.run_qkv_lora(
x,
self.A_buffer_qkv,
self.B_buffer_qkv,
output_offset=self.output_offset,
max_qkv_out_dim=self.max_qkv_out_dim,
base_output=base_output,
scaling=self.scaling,
)
# FIXME parallelize qkv
lora_output = torch.empty_like(base_output)
# q
output_dim_q = self.B_buffer_q.shape[-2]
lora_output[:, :output_dim_q] = self.segment_gemm.run(
x=lora_a_output[:, : self.lora_rank].contiguous(),
weights=self.B_buffer_q,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
# kv
output_dim_kv = self.B_buffer_kv.shape[-2] // 2
for i in range(2):
left = output_dim_kv * i
right = left + output_dim_kv
lora_output[:, output_dim_q + left : output_dim_q + right] = (
self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
].contiguous(),
weights=self.B_buffer_kv[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
)
return base_output + lora_output * self.scaling
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
def set_lora_info(self, A_buffer, B_buffer):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
base_output=base_output,
scaling=self.scaling,
)
lora_output = self.segment_gemm.run(
x=lora_output,
weights=self.B_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
return base_output + lora_output * self.scaling
def forward(self, input_):
# duplicate the logic in RowParallelLinear
@@ -255,7 +264,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def get_lora_layer(
layer: nn.Module, segment_gemm, lora_rank, scaling
layer: nn.Module, lora_rank, scaling, lora_backend
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
@@ -267,7 +276,7 @@ def get_lora_layer(
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
return ret
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
@@ -297,13 +306,14 @@ class LoRALayer(nn.Module):
class LoRAAdapter(nn.Module):
def __init__(self, uid, config, base_hf_config, load_config):
def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
super().__init__()
self.uid = uid
self.config = config
assert self.config.hf_config["peft_type"].lower() == "lora"
self.base_hf_config = base_hf_config
self.load_config = load_config
self.lora_backend = lora_backend
self.scaling = self.config.lora_alpha / self.config.r
self.layers = nn.ModuleList(
@@ -376,20 +386,25 @@ class LoRAAdapter(nn.Module):
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
else:
layer.weights[kv_name] = torch.cat(
(
layer.weights[kv_name] = torch.stack(
[
layer.weights[weight_name],
layer.weights[v_name],
),
0,
],
dim=0,
)
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
elif "gate_proj" in weight_name:
up_name = weight_name.replace("gate_proj", "up_proj")
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
layer.weights[gate_up_name] = torch.cat(
(layer.weights[weight_name], layer.weights[up_name]), 0
)
if "lora_A" in weight_name:
layer.weights[gate_up_name] = torch.cat(
(layer.weights[weight_name], layer.weights[up_name]), 0
)
else:
layer.weights[gate_up_name] = torch.stack(
[layer.weights[weight_name], layer.weights[up_name]], dim=0
)
layer.weights.pop(weight_name)
layer.weights.pop(up_name)
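
The switch from torch.cat to torch.stack above changes the lora_b weight layout from a fused output dimension to a separate leading axis, matching the new stacked B_buffer layout (c, num_lora, output_dim, r). A toy illustration with made-up sizes:

import torch

w_k = torch.randn(8, 4)  # (output_dim_kv, r), toy sizes
w_v = torch.randn(8, 4)
torch.cat((w_k, w_v), 0).shape        # torch.Size([16, 4]): old fused layout
torch.stack([w_k, w_v], dim=0).shape  # torch.Size([2, 8, 4]): new stacked layout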

View File

@@ -20,16 +20,14 @@ import re
import torch
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available, replace_submodule
logger = logging.getLogger(__name__)
if is_flashinfer_available():
from flashinfer import SegmentGEMMWrapper
def get_module_name(name):
# Fallback solution of mapping from config module name to module name in model class.
@@ -77,6 +75,20 @@ def get_stacked_name(name):
return params_mapping.get(name, (name, name))
def get_backend_from_name(name):
backend_mapping = {
"triton": TritonLoraBackend,
"flashinfer": FlashInferLoraBackend,
}
if name in backend_mapping:
return backend_mapping[name]
raise Exception(
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
)
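
For example, a sketch of resolving and instantiating a backend by name:

backend_cls = get_backend_from_name("triton")  # -> TritonLoraBackend
lora_backend = backend_cls("triton")           # batch_info is attached later, per batch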
def get_layer_id(name):
match = re.search(r"layers\.(\d+)\.", name)
if match is None:
@@ -93,6 +105,7 @@ class LoRAManager:
max_loras_per_batch,
load_config,
dtype,
lora_backend,
):
self.base_model = base_model
self.lora_paths = lora_paths
@@ -101,8 +114,9 @@ class LoRAManager:
self.load_config = load_config
self.dtype = dtype
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
logger.info(f"Using {lora_backend} as backend of Lora kernels.")
backend_type = get_backend_from_name(lora_backend)
self.lora_backend = backend_type(lora_backend)
self.init_loras()
self.init_lora_memory_pool()
@@ -123,7 +137,7 @@ class LoRAManager:
def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(
module, self.segment_gemm, self.max_lora_dim, self.scaling
module, self.max_lora_dim, self.scaling, self.lora_backend
)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
@@ -162,7 +176,11 @@ class LoRAManager:
self.lora_id[name] = len(self.loras)
self.loras.append(
LoRAAdapter(
name, self.configs[name], self.base_hf_config, self.load_config
name,
self.configs[name],
self.base_hf_config,
self.load_config,
self.lora_backend,
)
)
self.loras[-1].initialize_weights()
@@ -226,8 +244,9 @@ class LoRAManager:
self.B_buffer[module_B] = [
torch.empty(
(
c,
self.max_loras_per_batch,
hidden_dim_B * c,
hidden_dim_B,
self.max_lora_dim,
),
dtype=self.dtype,
@@ -263,7 +282,16 @@ class LoRAManager:
else:
lora_weight_name = self.get_weight_name(name, 1)
if lora_weight_name:
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
c = self.loras[-1].get_stacked_multiply(lora_weight_name)
if c > 1:
for j in range(c):
self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
weights[j]
)
else:
self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
weights
)
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
@@ -292,20 +320,30 @@ class LoRAManager:
if cur_uids == set([None]):
return
# setup lora in forward modules
        # Set up batch info shared by all LoRA modules
bs = forward_batch.batch_size
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device="cuda")
)
# FIXME: reuse the data rather than recompute
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.buffer_id[lora_path]
batch_info = LoraBatchInfo(
bs=bs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
max_len=max_len,
weight_indices=weight_indices,
)
self.lora_backend.set_batch_info(batch_info)
# call set_lora_info for each lora modules
for module_name, module in self.lora_modules:
layer_id = get_layer_id(module_name)
@@ -314,16 +352,10 @@ class LoRAManager:
module.set_lora_info(
self.A_buffer[weight_name][layer_id],
self.B_buffer[weight_name][layer_id],
bs,
seg_indptr,
weight_indices,
)
else:
module.set_lora_info(
self.A_buffer["qkv_proj"][layer_id],
self.B_buffer["q_proj"][layer_id],
self.B_buffer["kv_proj"][layer_id],
bs,
seg_indptr,
weight_indices,
)
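
A concrete example of the batch info built in prepare_lora_batch above (a sketch with made-up lengths): for a batch of three extend requests with lengths [3, 1, 2],

import torch

seg_lens = torch.tensor([3, 1, 2], device="cuda")
seg_indptr = torch.zeros(4, dtype=torch.int32, device="cuda")
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)  # -> [0, 3, 4, 6]
max_len = int(torch.max(seg_lens))              # -> 3
# weight_indices holds each request's adapter slot in the memory pool, e.g. [0, 1, 0].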

View File

@@ -0,0 +1,5 @@
from .qkv_lora_b import qkv_lora_b_fwd
from .sgemm_lora_a import sgemm_lora_a_fwd
from .sgemm_lora_b import sgemm_lora_b_fwd
__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]

View File

@@ -0,0 +1,182 @@
import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
@triton.jit
def _qkv_lora_b_kernel(
# Pointers to matrices
x,
weights,
output,
# Parameters of size
K, # K = R
max_qkv_out_dim, # max(output_q_dim, output_kv_dim)
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Offsets of q/k/v slice on output dimension
n_offs,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
):
# This kernel packs 3 sgemms (q/k/v) into a single kernel.
    # x: (s, 3 * K), where s is the sum of sequence lengths and K equals the LoRA rank
# weights: (num_lora, N_Q + 2 * N_KV, K)
# output: (s, N_Q + 2 * N_KV)
# N_Q >> K, N_KV >> K
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len.
# qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
batch_id = tl.program_id(axis=2)
qkv_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
n_start = tl.load(n_offs + qkv_id)
n_size = tl.load(n_offs + qkv_id + 1) - n_start
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
    # Iterate to compute the block in the output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum *= scaling
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
def qkv_lora_b_fwd(
x: torch.Tensor,
qkv_lora_b: torch.Tensor,
batch_info: LoraBatchInfo,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:
# x: (s, 3 * r)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
# output_offset = [0, output_dim_q, output_dim_q + output_dim_kv,
# output_dim_q + 2 * output_dim_kv]
# max_qkv_out_dim = max(output_dim_q, output_dim_kv)
# output: (s, output_dim_q + 2 * output_dim_kv)
# Compute lora_output with shape (s, output_dim) as follows:
    # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], q_lora_b)
# lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
# = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
# lora_output[:, output_dim_q + output_dim_kv: ]
# = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
# Get dims
s = x.shape[0]
input_dim = x.shape[1]
r = qkv_lora_b.shape[-1]
output_dim = qkv_lora_b.shape[-2]
assert input_dim == 3 * r
assert output_offset.shape[0] == 4
BLOCK_S = 16
BLOCK_R = 16
BLOCK_OUT = 64
grid_b = (
triton.cdiv(batch_info.max_len, BLOCK_S)
* triton.cdiv(max_qkv_out_dim, BLOCK_OUT),
        3,  # this dimension selects whether the block computes q, k, or v
batch_info.bs,
)
if base_output is None:
output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
else:
output = base_output
fuse_scaling_add = True
_qkv_lora_b_kernel[grid_b](
x,
qkv_lora_b,
output,
r,
max_qkv_out_dim,
x.stride(0),
x.stride(1),
qkv_lora_b.stride(0),
qkv_lora_b.stride(1),
qkv_lora_b.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
output_offset,
BLOCK_S,
BLOCK_OUT,
BLOCK_R,
fuse_scaling_add,
scaling,
)
return output
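
A worked example of the launch grid above, with illustrative numbers: for max_len = 30, BLOCK_S = 16, max_qkv_out_dim = 4096, BLOCK_OUT = 64, and bs = 2,

# axis 0: cdiv(30, 16) * cdiv(4096, 64) = 2 * 64 = 128 output tiles
# axis 1: 3 (q, k, or v)
# axis 2: 2 (one per sequence in the batch)
# so grid_b = (128, 3, 2). Tiles that fall outside a short sequence or outside
# the narrower k/v output slice are masked out inside the kernel.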

View File

@@ -0,0 +1,143 @@
import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
@triton.jit
def _sgemm_lora_a_kernel(
# Pointers to matrices
x,
weights,
output,
# Matrix dimensions
N, # r
K, # input_dim
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
# x: (s, K), s is the sum of sequence lengths
# weights: (num_lora, N, K)
# output: (s, N)
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len
batch_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights[batch_id]
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
    # Iterate to compute the block in the output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
tl.store(output_ptr, partial_sum, mask=output_mask)
def sgemm_lora_a_fwd(
x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
) -> torch.Tensor:
# x: (s, input_dim)
# weights: (num_lora, r, input_dim)
# output: (s, r)
# when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
# input_dim is much larger than r
assert x.is_contiguous()
assert weights.is_contiguous()
assert len(x.shape) == 2
assert len(weights.shape) == 3
S = x.shape[0]
R = weights.shape[-2]
K = weights.shape[-1]
assert x.shape[-1] == K
# Block shapes
BLOCK_S = 16
BLOCK_K = 256
BLOCK_R = 16
grid = (
triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R),
batch_info.bs,
)
output = torch.empty((S, R), device=x.device, dtype=x.dtype)
_sgemm_lora_a_kernel[grid](
x,
weights,
output,
R,
K,
x.stride(0),
x.stride(1),
weights.stride(0),
weights.stride(1),
weights.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
BLOCK_S,
BLOCK_R,
BLOCK_K,
)
return output

View File

@@ -0,0 +1,159 @@
import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
@triton.jit
def _sgemm_lora_b_kernel(
# Pointers to matrices
x,
weights,
output,
# Matrix dimensions
N, # output_dim
K, # r
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
):
# x: (s, K), s is the sum of sequence lengths
# weights: (num_lora, N, K)
# output: (s, N)
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len
batch_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights[batch_id]
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
    # Iterate to compute the block in the output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum *= scaling
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = s_offset[:, None] < seg_len
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
def sgemm_lora_b_fwd(
x: torch.Tensor,
weights: torch.Tensor,
batch_info: LoraBatchInfo,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:
# x: (s, r)
# weights: (num_lora, output_dim, r)
# output: (s, output_dim)
# output_dim is much larger than r
assert x.is_contiguous()
assert weights.is_contiguous()
assert len(x.shape) == 2
assert len(weights.shape) == 3
S = x.shape[0]
N = weights.shape[-2]
R = weights.shape[-1]
assert x.shape[-1] == R
# Block shapes
BLOCK_S = 16
BLOCK_R = 16
BLOCK_N = 256
grid = (
triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N),
batch_info.bs,
)
if base_output is None:
output = torch.empty((S, N), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
else:
output = base_output
fuse_scaling_add = True
_sgemm_lora_b_kernel[grid](
x,
weights,
output,
N,
R,
x.stride(0),
x.stride(1),
weights.stride(0),
weights.stride(1),
weights.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
BLOCK_S,
BLOCK_N,
BLOCK_R,
fuse_scaling_add,
scaling,
)
return output

View File

@@ -530,6 +530,7 @@ class ModelRunner:
max_loras_per_batch=self.server_args.max_loras_per_batch,
load_config=self.load_config,
dtype=self.dtype,
lora_backend=self.server_args.lora_backend,
)
logger.info("LoRA manager ready.")

View File

@@ -113,6 +113,7 @@ class ServerArgs:
# LoRA
lora_paths: Optional[List[str]] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"
# Kernel backend
attention_backend: Optional[str] = None
@@ -653,13 +654,19 @@ class ServerArgs:
nargs="*",
default=None,
action=LoRAPathAction,
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
)
parser.add_argument(
"--max-loras-per-batch",
type=int,
default=8,
help="Maximum number of adapters for a running batch, include base-only request",
help="Maximum number of adapters for a running batch, include base-only request.",
)
parser.add_argument(
"--lora-backend",
type=str,
default="triton",
help="Choose the kernel backend for multi-LoRA serving.",
)
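
For example, the new flag would be used at launch roughly like this (a sketch; the model and adapter paths are placeholders):

# python -m sglang.launch_server --model-path <model> \
#     --lora-paths mylora=<path-to-adapter> --lora-backend triton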
# Kernel backend

View File

@@ -272,6 +272,7 @@ class SRTRunner:
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
lora_paths: List[str] = None,
max_loras_per_batch: int = 4,
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
):
@@ -287,6 +288,7 @@ class SRTRunner:
is_embedding=not self.is_generation,
lora_paths=lora_paths,
max_loras_per_batch=max_loras_per_batch,
lora_backend=lora_backend,
disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache,
)