[LoRA, Performance] Speedup multi-LoRA serving - Step 1 (#1587)

This commit is contained in:
Ying Sheng
2024-10-06 10:33:44 -07:00
committed by GitHub
parent 58d1082e39
commit 9c064bf78a
3 changed files with 34 additions and 32 deletions

View File

@@ -101,12 +101,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seq_lens = seq_lens
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -115,11 +115,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# FIXME
assert lora_a_output.shape[-1] == self.lora_rank * 2
lora_output = torch.empty_like(base_output)
output_dim = lora_output.shape[-1] // 2
for i in range(2):
@@ -132,7 +131,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
weights=self.B_buffer[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
return base_output + lora_output * self.scaling
@@ -145,14 +144,14 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
def set_lora_info(
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
self.B_buffer_q = B_buffer_q
self.B_buffer_kv = B_buffer_kv
self.bs = bs
self.seq_lens = seq_lens
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -161,7 +160,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
weights=self.A_buffer_qkv,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# FIXME parallelize qkv
@@ -173,7 +172,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
weights=self.B_buffer_q,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# kv
@@ -189,7 +188,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
weights=self.B_buffer_kv[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
)
@@ -202,12 +201,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seq_lens = seq_lens
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -216,7 +215,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
lora_output = self.segment_gemm.run(
@@ -224,7 +223,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
weights=self.B_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
return base_output + lora_output * self.scaling