[Feature] Define backends and add Triton backend for Lora (#3161)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
Baizhou Zhang
2025-02-03 22:09:13 -08:00
committed by GitHub
parent 7b5a374114
commit 70817a7eae
18 changed files with 1129 additions and 135 deletions

View File

@@ -18,8 +18,8 @@
# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
import re
from dataclasses import dataclass
import torch
from torch import nn
@@ -34,14 +34,32 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_loader.loader import DefaultModelLoader
@dataclass
class LoraBatchInfo:
# Batch size
bs: int
# Lengths of each sequence in shape (bs,)
seg_lens: torch.Tensor
# Indice pointers of each sequence in shape (bs + 1, )
seg_indptr: torch.Tensor
# Maximum sequence length of current batch
max_len: int
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices: torch.Tensor
class BaseLayerWithLoRA(nn.Module):
def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
def __init__(self, base_layer, lora_rank, scaling, lora_backend):
super().__init__()
self.base_layer = base_layer
self.segment_gemm = segment_gemm
self.lora_rank = lora_rank
self.scaling = scaling
self.set_lora = False
self.lora_backend = lora_backend
def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
@@ -52,17 +70,17 @@ class BaseLayerWithLoRA(nn.Module):
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
self.weight = base_layer.weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
# TODO
@@ -88,136 +106,127 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
def set_lora_info(
self,
A_buffer,
B_buffer,
):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# FIXME
lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
output_dim = base_output.shape[-1]
lora_output = torch.empty_like(base_output)
output_dim = lora_output.shape[-1] // 2
for i in range(2):
left = output_dim * i
right = left + output_dim
lora_output[:, left:right] = self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * i : self.lora_rank * (i + 1)
].contiguous(),
weights=self.B_buffer[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
weights=self.B_buffer[0],
)
lora_output[:, output_dim : 2 * output_dim] = (
self.lora_backend.run_lora_b_sgemm(
x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
weights=self.B_buffer[1],
)
)
return base_output + lora_output * self.scaling
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
def init__(
self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
self,
A_buffer_qkv,
B_buffer_q,
B_buffer_kv,
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
self.B_buffer_q = B_buffer_q
self.B_buffer_kv = B_buffer_kv
self.bs = bs
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
if self.lora_backend.fuse_qkv_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
self.output_offset = None
self.max_qkv_out_dim = None
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer_qkv,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
lora_output = self.lora_backend.run_qkv_lora(
x,
self.A_buffer_qkv,
self.B_buffer_qkv,
output_offset=self.output_offset,
max_qkv_out_dim=self.max_qkv_out_dim,
base_output=base_output,
scaling=self.scaling,
)
# FIXME parallelize qkv
lora_output = torch.empty_like(base_output)
# q
output_dim_q = self.B_buffer_q.shape[-2]
lora_output[:, :output_dim_q] = self.segment_gemm.run(
x=lora_a_output[:, : self.lora_rank].contiguous(),
weights=self.B_buffer_q,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
# kv
output_dim_kv = self.B_buffer_kv.shape[-2] // 2
for i in range(2):
left = output_dim_kv * i
right = left + output_dim_kv
lora_output[:, output_dim_q + left : output_dim_q + right] = (
self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
].contiguous(),
weights=self.B_buffer_kv[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
)
return base_output + lora_output * self.scaling
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
def set_lora_info(self, A_buffer, B_buffer):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
base_output=base_output,
scaling=self.scaling,
)
lora_output = self.segment_gemm.run(
x=lora_output,
weights=self.B_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
return base_output + lora_output * self.scaling
def forward(self, input_):
# duplicate the logic in RowParallelLinear
@@ -255,7 +264,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def get_lora_layer(
layer: nn.Module, segment_gemm, lora_rank, scaling
layer: nn.Module, lora_rank, scaling, lora_backend
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
@@ -267,7 +276,7 @@ def get_lora_layer(
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
return ret
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
@@ -297,13 +306,14 @@ class LoRALayer(nn.Module):
class LoRAAdapter(nn.Module):
def __init__(self, uid, config, base_hf_config, load_config):
def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
super().__init__()
self.uid = uid
self.config = config
assert self.config.hf_config["peft_type"].lower() == "lora"
self.base_hf_config = base_hf_config
self.load_config = load_config
self.lora_backend = lora_backend
self.scaling = self.config.lora_alpha / self.config.r
self.layers = nn.ModuleList(
@@ -376,20 +386,25 @@ class LoRAAdapter(nn.Module):
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
else:
layer.weights[kv_name] = torch.cat(
(
layer.weights[kv_name] = torch.stack(
[
layer.weights[weight_name],
layer.weights[v_name],
),
0,
],
dim=0,
)
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
elif "gate_proj" in weight_name:
up_name = weight_name.replace("gate_proj", "up_proj")
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
layer.weights[gate_up_name] = torch.cat(
(layer.weights[weight_name], layer.weights[up_name]), 0
)
if "lora_A" in weight_name:
layer.weights[gate_up_name] = torch.cat(
(layer.weights[weight_name], layer.weights[up_name]), 0
)
else:
layer.weights[gate_up_name] = torch.stack(
[layer.weights[weight_name], layer.weights[up_name]], dim=0
)
layer.weights.pop(weight_name)
layer.weights.pop(up_name)