[LoRA, Performance] Speedup multi-LoRA serving - Step 1 (#1587)
@@ -1,7 +1,7 @@
 import argparse
 import os
 
-NUM_LORAS = 128
+NUM_LORAS = 8
 LORA_PATH = {
     "base": "mistralai/Mistral-7B-Instruct-v0.3",
     "lora": "/home/ying/test_lora",
@@ -11,12 +11,11 @@ LORA_PATH = {
 def launch_server(args):
     base_path = LORA_PATH["base"]
     lora_path = LORA_PATH["lora"]
-    max_loras_per_batch = 4
 
     if args.base_only:
-        cmd = f"python -m sglang.launch_server --model {base_path} "
+        cmd = f"python3 -m sglang.launch_server --model {base_path} "
     else:
-        cmd = f"python -m sglang.launch_server --model {base_path} --lora-paths "
+        cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
     for i in range(NUM_LORAS):
         lora_name = f"lora{i}"
         cmd += f"{lora_name}={lora_path} "
@@ -29,11 +28,6 @@ def launch_server(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--num-loras",
-        type=int,
-        default=128,
-    )
    parser.add_argument(
        "--base-only",
        action="store_true",
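For context, with NUM_LORAS = 8 the benchmark script above builds a launch command equivalent to this sketch (paths taken from LORA_PATH; the join-based concatenation is just a compact rewrite of the loop):

    # Hypothetical reconstruction of the command string the script produces.
    cmd = "python3 -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths "
    cmd += " ".join(f"lora{i}=/home/ying/test_lora" for i in range(8))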
@@ -101,12 +101,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     ) -> None:
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
 
-    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -115,11 +115,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.A_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # FIXME
-        assert lora_a_output.shape[-1] == self.lora_rank * 2
         lora_output = torch.empty_like(base_output)
         output_dim = lora_output.shape[-1] // 2
         for i in range(2):
@@ -132,7 +131,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                 weights=self.B_buffer[:, left:right, :].contiguous(),
                 batch_size=self.bs,
                 weight_column_major=True,
-                seg_lens=self.seq_lens,
+                seg_indptr=self.seg_indptr,
                 weight_indices=self.weight_indices,
             )
         return base_output + lora_output * self.scaling
@@ -145,14 +144,14 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
 
     def set_lora_info(
-        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
+        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
     ):
         self.set_lora = True
         self.A_buffer_qkv = A_buffer_qkv
         self.B_buffer_q = B_buffer_q
         self.B_buffer_kv = B_buffer_kv
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -161,7 +160,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.A_buffer_qkv,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # FIXME parallelize qkv
@@ -173,7 +172,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.B_buffer_q,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # kv
@@ -189,7 +188,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                 weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                 batch_size=self.bs,
                 weight_column_major=True,
-                seg_lens=self.seq_lens,
+                seg_indptr=self.seg_indptr,
                 weight_indices=self.weight_indices,
             )
         )
@@ -202,12 +201,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     ) -> None:
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
 
-    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -216,7 +215,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             weights=self.A_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         lora_output = self.segment_gemm.run(
@@ -224,7 +223,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             weights=self.B_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
        )
        return base_output + lora_output * self.scaling
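The hunks above replace the per-request seg_lens argument with seg_indptr, a precomputed prefix sum over segment lengths, so every segment-gemm call receives ready-made row offsets instead of re-deriving them from raw lengths. A minimal reference loop for the semantics these calls are assumed to have (flashinfer-style segment GEMM; an illustration of the math, not the kernel's implementation):

    import torch

    def segment_gemm_ref(x, weights, seg_indptr, weight_indices):
        # x:              (total_tokens, in_dim), all requests' rows concatenated
        # weights:        (num_loras, out_dim, in_dim), per weight_column_major=True
        # seg_indptr:     (bs + 1,), request i owns rows seg_indptr[i]:seg_indptr[i + 1]
        # weight_indices: (bs,), buffer slot of the adapter each request uses
        out = x.new_empty((x.shape[0], weights.shape[1]))
        for i in range(len(weight_indices)):
            s, e = seg_indptr[i], seg_indptr[i + 1]
            out[s:e] = x[s:e] @ weights[weight_indices[i]].T
        return out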
@@ -274,18 +274,24 @@ class LoRAManager:
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
+        j = len(self.active_uids)
         evictable_uids = list(self.active_uids)
         for uid in cur_uids:
             if uid not in self.active_uids:
-                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
-                    i += 1
-                if i < len(evictable_uids):
+                if j < self.max_loras_per_batch:
+                    index = j
+                    j += 1
+                else:
+                    while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
+                        i += 1
+                    assert i < len(evictable_uids)
                     self.active_uids.remove(evictable_uids[i])
                     self.buffer_id.pop(evictable_uids[i])
-                self.load_lora(uid, i)
+                    index = i
+                    i += 1
+                self.load_lora(uid, index)
                 self.active_uids.add(uid)
-                self.buffer_id[uid] = i
-                i += 1
+                self.buffer_id[uid] = index
 
        if cur_uids == set([None]):
            return
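The rewritten pool logic above separates slot allocation from eviction: j tracks the first free buffer slot, so new adapters fill empty slots first, and the eviction scan over evictable_uids only runs once the pool is full. A standalone sketch of the policy (assign_slots is a hypothetical helper mirroring the diff; the load_lora weight copy is elided):

    def assign_slots(cur_uids, active_uids, buffer_id, max_loras_per_batch):
        i = 0                 # next eviction candidate among active adapters
        j = len(active_uids)  # first free buffer slot, if any
        evictable_uids = list(active_uids)
        for uid in cur_uids:
            if uid in active_uids:
                continue
            if j < max_loras_per_batch:  # free slot available: no eviction needed
                index = j
                j += 1
            else:                        # pool full: evict an adapter unused by this batch
                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
                    i += 1
                assert i < len(evictable_uids)
                active_uids.remove(evictable_uids[i])
                buffer_id.pop(evictable_uids[i])
                index = i
                i += 1
            active_uids.add(uid)
            buffer_id[uid] = index       # self.load_lora(uid, index) happens here
        return buffer_id

    # With a pool of 4 and slots 0/1 already holding "A" and "B":
    print(assign_slots(["A", "C", "D"], {"A", "B"}, {"A": 0, "B": 1}, 4))
    # -> {'A': 0, 'B': 1, 'C': 2, 'D': 3}; "B" stays resident because free slots exist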
@@ -295,8 +301,11 @@ class LoRAManager:
         seg_lens = (
             forward_batch.extend_seq_lens
             if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs)
+            else torch.ones(bs, device="cuda")
         )
+        # FIXME: reuse the data rather than recompute
+        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.buffer_id[lora_path]
@@ -310,7 +319,7 @@ class LoRAManager:
                     self.A_buffer[weight_name][layer_id],
                     self.B_buffer[weight_name][layer_id],
                     bs,
-                    seg_lens,
+                    seg_indptr,
                     weight_indices,
                 )
             else:
@@ -319,6 +328,6 @@ class LoRAManager:
                     self.B_buffer["q_proj"][layer_id],
                     self.B_buffer["kv_proj"][layer_id],
                     bs,
-                    seg_lens,
+                    seg_indptr,
                    weight_indices,
                )
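With this change, seg_indptr is computed once per forward batch in the manager and shared by every layer's set_lora_info, rather than each segment-gemm call working from raw lengths. A toy example of the values being prepared (CPU tensors and a hypothetical batch of 3 requests, for illustration):

    import torch

    seg_lens = torch.tensor([3, 5, 2])  # extend lengths per request
    seg_indptr = torch.zeros((len(seg_lens) + 1,), dtype=torch.int32)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
    print(seg_indptr)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
    # Request i owns rows seg_indptr[i]:seg_indptr[i+1] of the stacked input.
    # weight_indices maps each request to its adapter's buffer slot, e.g.:
    weight_indices = torch.tensor([0, 2, 0])  # requests 0 and 2 share slot 0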