[Feature] Define backends and add Triton backend for Lora (#3161)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -18,8 +18,8 @@
|
||||
# LoRA layers class inheritance adapted from:
|
||||
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
|
||||
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -34,14 +34,32 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraBatchInfo:
|
||||
# Batch size
|
||||
bs: int
|
||||
|
||||
# Lengths of each sequence in shape (bs,)
|
||||
seg_lens: torch.Tensor
|
||||
|
||||
# Index pointers of each sequence, in shape (bs + 1,)
|
||||
seg_indptr: torch.Tensor
|
||||
|
||||
# Maximum sequence length of current batch
|
||||
max_len: int
|
||||
|
||||
# Index of the LoRA adapter used by each sequence, in shape (bs,)
|
||||
weight_indices: torch.Tensor
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
|
||||
def __init__(self, base_layer, lora_rank, scaling, lora_backend):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.segment_gemm = segment_gemm
|
||||
self.lora_rank = lora_rank
|
||||
self.scaling = scaling
|
||||
self.set_lora = False
|
||||
self.lora_backend = lora_backend
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.base_layer.forward(x)
|
||||
@@ -52,17 +70,17 @@ class BaseLayerWithLoRA(nn.Module):
|
||||
|
||||
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
|
||||
self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
|
||||
) -> None:
|
||||
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
self.weight = base_layer.weight
|
||||
|
||||
|
||||
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
|
||||
self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
|
||||
) -> None:
|
||||
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
# TODO
|
||||
@@ -88,136 +106,127 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
|
||||
self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
|
||||
) -> None:
|
||||
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
|
||||
def set_lora_info(
|
||||
self,
|
||||
A_buffer,
|
||||
B_buffer,
|
||||
):
|
||||
self.set_lora = True
|
||||
self.A_buffer = A_buffer
|
||||
self.B_buffer = B_buffer
|
||||
self.bs = bs
|
||||
self.seg_indptr = seg_indptr
|
||||
self.weight_indices = weight_indices
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_a_output = self.segment_gemm.run(
|
||||
x=x,
|
||||
weights=self.A_buffer,
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_indptr=self.seg_indptr,
|
||||
weight_indices=self.weight_indices,
|
||||
)
|
||||
# FIXME
|
||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
|
||||
|
||||
output_dim = base_output.shape[-1]
|
||||
lora_output = torch.empty_like(base_output)
|
||||
output_dim = lora_output.shape[-1] // 2
|
||||
for i in range(2):
|
||||
left = output_dim * i
|
||||
right = left + output_dim
|
||||
lora_output[:, left:right] = self.segment_gemm.run(
|
||||
x=lora_a_output[
|
||||
:, self.lora_rank * i : self.lora_rank * (i + 1)
|
||||
].contiguous(),
|
||||
weights=self.B_buffer[:, left:right, :].contiguous(),
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_indptr=self.seg_indptr,
|
||||
weight_indices=self.weight_indices,
|
||||
lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
|
||||
x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
|
||||
weights=self.B_buffer[0],
|
||||
)
|
||||
|
||||
lora_output[:, output_dim : 2 * output_dim] = (
|
||||
self.lora_backend.run_lora_b_sgemm(
|
||||
x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
|
||||
weights=self.B_buffer[1],
|
||||
)
|
||||
)
|
||||
|
||||
return base_output + lora_output * self.scaling
|
||||
|
||||
|
||||
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
|
||||
def __init__(
|
||||
self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
|
||||
) -> None:
|
||||
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def set_lora_info(
|
||||
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
|
||||
self,
|
||||
A_buffer_qkv,
|
||||
B_buffer_q,
|
||||
B_buffer_kv,
|
||||
):
|
||||
self.set_lora = True
|
||||
self.A_buffer_qkv = A_buffer_qkv
|
||||
self.B_buffer_q = B_buffer_q
|
||||
self.B_buffer_kv = B_buffer_kv
|
||||
self.bs = bs
|
||||
self.seg_indptr = seg_indptr
|
||||
self.weight_indices = weight_indices
|
||||
|
||||
if self.lora_backend.fuse_qkv_lora_b:
|
||||
assert (
|
||||
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
|
||||
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
|
||||
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
|
||||
|
||||
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
||||
self.B_buffer_qkv = torch.cat(
|
||||
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
|
||||
).contiguous()
|
||||
|
||||
# Offsets of q/k/v in output dimension
|
||||
self.output_offset = torch.tensor(
|
||||
[
|
||||
0,
|
||||
output_dim_q,
|
||||
output_dim_q + output_dim_kv,
|
||||
output_dim_q + 2 * output_dim_kv,
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=B_buffer_q.device,
|
||||
)
|
||||
# For computing number of launched blocks
|
||||
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
|
||||
else:
|
||||
self.B_buffer_qkv = (
|
||||
B_buffer_q,
|
||||
B_buffer_kv,
|
||||
)
|
||||
self.output_offset = None
|
||||
self.max_qkv_out_dim = None
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_a_output = self.segment_gemm.run(
|
||||
x=x,
|
||||
weights=self.A_buffer_qkv,
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_indptr=self.seg_indptr,
|
||||
weight_indices=self.weight_indices,
|
||||
lora_output = self.lora_backend.run_qkv_lora(
|
||||
x,
|
||||
self.A_buffer_qkv,
|
||||
self.B_buffer_qkv,
|
||||
output_offset=self.output_offset,
|
||||
max_qkv_out_dim=self.max_qkv_out_dim,
|
||||
base_output=base_output,
|
||||
scaling=self.scaling,
|
||||
)
|
||||
# FIXME parallelize qkv
|
||||
lora_output = torch.empty_like(base_output)
|
||||
# q
|
||||
output_dim_q = self.B_buffer_q.shape[-2]
|
||||
lora_output[:, :output_dim_q] = self.segment_gemm.run(
|
||||
x=lora_a_output[:, : self.lora_rank].contiguous(),
|
||||
weights=self.B_buffer_q,
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_indptr=self.seg_indptr,
|
||||
weight_indices=self.weight_indices,
|
||||
return (
|
||||
lora_output
|
||||
if self.lora_backend.fuse_output_scaling_add
|
||||
else base_output + lora_output * self.scaling
|
||||
)
|
||||
# kv
|
||||
output_dim_kv = self.B_buffer_kv.shape[-2] // 2
|
||||
for i in range(2):
|
||||
left = output_dim_kv * i
|
||||
right = left + output_dim_kv
|
||||
lora_output[:, output_dim_q + left : output_dim_q + right] = (
|
||||
self.segment_gemm.run(
|
||||
x=lora_a_output[
|
||||
:, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
|
||||
].contiguous(),
|
||||
weights=self.B_buffer_kv[:, left:right, :].contiguous(),
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_indptr=self.seg_indptr,
|
||||
weight_indices=self.weight_indices,
|
||||
)
|
||||
)
|
||||
return base_output + lora_output * self.scaling
|
||||
|
||||
|
||||
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
|
||||
self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
|
||||
) -> None:
|
||||
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
|
||||
def set_lora_info(self, A_buffer, B_buffer):
|
||||
self.set_lora = True
|
||||
self.A_buffer = A_buffer
|
||||
self.B_buffer = B_buffer
|
||||
self.bs = bs
|
||||
self.seg_indptr = seg_indptr
|
||||
self.weight_indices = weight_indices
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_output = self.segment_gemm.run(
|
||||
x=x,
|
||||
weights=self.A_buffer,
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_indptr=self.seg_indptr,
|
||||
weight_indices=self.weight_indices,
|
||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
||||
lora_output = self.lora_backend.run_lora_b_sgemm(
|
||||
lora_a_output,
|
||||
self.B_buffer[0],
|
||||
base_output=base_output,
|
||||
scaling=self.scaling,
|
||||
)
|
||||
lora_output = self.segment_gemm.run(
|
||||
x=lora_output,
|
||||
weights=self.B_buffer,
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_indptr=self.seg_indptr,
|
||||
weight_indices=self.weight_indices,
|
||||
return (
|
||||
lora_output
|
||||
if self.lora_backend.fuse_output_scaling_add
|
||||
else base_output + lora_output * self.scaling
|
||||
)
|
||||
return base_output + lora_output * self.scaling
|
||||
|
||||
def forward(self, input_):
|
||||
# duplicate the logic in RowParallelLinear
|
||||
@@ -255,7 +264,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
|
||||
def get_lora_layer(
|
||||
layer: nn.Module, segment_gemm, lora_rank, scaling
|
||||
layer: nn.Module, lora_rank, scaling, lora_backend
|
||||
) -> BaseLayerWithLoRA:
|
||||
supported_layer_types = {
|
||||
# the order matters
|
||||
@@ -267,7 +276,7 @@ def get_lora_layer(
|
||||
}
|
||||
for src_layer_type, lora_layer_type in supported_layer_types.items():
|
||||
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
|
||||
ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
|
||||
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
|
||||
return ret
|
||||
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
|
||||
|
||||
@@ -297,13 +306,14 @@ class LoRALayer(nn.Module):
|
||||
|
||||
|
||||
class LoRAAdapter(nn.Module):
|
||||
def __init__(self, uid, config, base_hf_config, load_config):
|
||||
def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
|
||||
super().__init__()
|
||||
self.uid = uid
|
||||
self.config = config
|
||||
assert self.config.hf_config["peft_type"].lower() == "lora"
|
||||
self.base_hf_config = base_hf_config
|
||||
self.load_config = load_config
|
||||
self.lora_backend = lora_backend
|
||||
self.scaling = self.config.lora_alpha / self.config.r
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
@@ -376,20 +386,25 @@ class LoRAAdapter(nn.Module):
|
||||
layer.weights.pop(weight_name)
|
||||
layer.weights.pop(v_name)
|
||||
else:
|
||||
layer.weights[kv_name] = torch.cat(
|
||||
(
|
||||
layer.weights[kv_name] = torch.stack(
|
||||
[
|
||||
layer.weights[weight_name],
|
||||
layer.weights[v_name],
|
||||
),
|
||||
0,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
layer.weights.pop(weight_name)
|
||||
layer.weights.pop(v_name)
|
||||
elif "gate_proj" in weight_name:
|
||||
up_name = weight_name.replace("gate_proj", "up_proj")
|
||||
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
||||
layer.weights[gate_up_name] = torch.cat(
|
||||
(layer.weights[weight_name], layer.weights[up_name]), 0
|
||||
)
|
||||
if "lora_A" in weight_name:
|
||||
layer.weights[gate_up_name] = torch.cat(
|
||||
(layer.weights[weight_name], layer.weights[up_name]), 0
|
||||
)
|
||||
else:
|
||||
layer.weights[gate_up_name] = torch.stack(
|
||||
[layer.weights[weight_name], layer.weights[up_name]], dim=0
|
||||
)
|
||||
layer.weights.pop(weight_name)
|
||||
layer.weights.pop(up_name)
|
||||
|
||||
Reference in New Issue
Block a user