diff --git a/benchmark/lora/launch_server.py b/benchmark/lora/launch_server.py
index 1fa4d7135..f139f0df6 100644
--- a/benchmark/lora/launch_server.py
+++ b/benchmark/lora/launch_server.py
@@ -1,7 +1,7 @@
 import argparse
 import os

-NUM_LORAS = 128
+NUM_LORAS = 8
 LORA_PATH = {
     "base": "mistralai/Mistral-7B-Instruct-v0.3",
     "lora": "/home/ying/test_lora",
@@ -11,12 +11,11 @@ LORA_PATH = {
 def launch_server(args):
     base_path = LORA_PATH["base"]
     lora_path = LORA_PATH["lora"]
-    max_loras_per_batch = 4

     if args.base_only:
-        cmd = f"python -m sglang.launch_server --model {base_path} "
+        cmd = f"python3 -m sglang.launch_server --model {base_path} "
     else:
-        cmd = f"python -m sglang.launch_server --model {base_path} --lora-paths "
+        cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
         for i in range(NUM_LORAS):
             lora_name = f"lora{i}"
             cmd += f"{lora_name}={lora_path} "
@@ -29,11 +28,6 @@ def launch_server(args):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--num-loras",
-        type=int,
-        default=128,
-    )
     parser.add_argument(
         "--base-only",
         action="store_true",
diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py
index 379b233bd..85470996f 100644
--- a/python/sglang/srt/lora/lora.py
+++ b/python/sglang/srt/lora/lora.py
@@ -101,12 +101,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     ) -> None:
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -115,11 +115,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.A_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # FIXME
-        assert lora_a_output.shape[-1] == self.lora_rank * 2
         lora_output = torch.empty_like(base_output)
         output_dim = lora_output.shape[-1] // 2
         for i in range(2):
@@ -132,7 +131,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                 weights=self.B_buffer[:, left:right, :].contiguous(),
                 batch_size=self.bs,
                 weight_column_major=True,
-                seg_lens=self.seq_lens,
+                seg_indptr=self.seg_indptr,
                 weight_indices=self.weight_indices,
             )
         return base_output + lora_output * self.scaling
@@ -145,14 +144,14 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

     def set_lora_info(
-        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
+        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
     ):
         self.set_lora = True
         self.A_buffer_qkv = A_buffer_qkv
         self.B_buffer_q = B_buffer_q
         self.B_buffer_kv = B_buffer_kv
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -161,7 +160,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.A_buffer_qkv,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # FIXME parallelize qkv
@@ -173,7 +172,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.B_buffer_q,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # kv
@@ -189,7 +188,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                     weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                     batch_size=self.bs,
                     weight_column_major=True,
-                    seg_lens=self.seq_lens,
+                    seg_indptr=self.seg_indptr,
                     weight_indices=self.weight_indices,
                 )
             )
@@ -202,12 +201,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     ) -> None:
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -216,7 +215,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             weights=self.A_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         lora_output = self.segment_gemm.run(
@@ -224,7 +223,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             weights=self.B_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         return base_output + lora_output * self.scaling
diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index 59cd7e157..dd46212ed 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -274,18 +274,24 @@ class LoRAManager:
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
+        j = len(self.active_uids)
         evictable_uids = list(self.active_uids)
         for uid in cur_uids:
             if uid not in self.active_uids:
-                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
-                    i += 1
-                if i < len(evictable_uids):
+                if j < self.max_loras_per_batch:
+                    index = j
+                    j += 1
+                else:
+                    while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
+                        i += 1
+                    assert i < len(evictable_uids)
                     self.active_uids.remove(evictable_uids[i])
                     self.buffer_id.pop(evictable_uids[i])
-                self.load_lora(uid, i)
+                    index = i
+                    i += 1
+                self.load_lora(uid, index)
                 self.active_uids.add(uid)
-                self.buffer_id[uid] = i
-                i += 1
+                self.buffer_id[uid] = index

         if cur_uids == set([None]):
             return
@@ -295,8 +301,11 @@ class LoRAManager:
         seg_lens = (
             forward_batch.extend_seq_lens
             if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs)
+            else torch.ones(bs, device="cuda")
         )
+        # FIXME: reuse the data rather than recompute
+        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.buffer_id[lora_path]
@@ -310,7 +319,7 @@ class LoRAManager:
                     self.A_buffer[weight_name][layer_id],
                     self.B_buffer[weight_name][layer_id],
                     bs,
-                    seg_lens,
+                    seg_indptr,
                     weight_indices,
                 )
             else:
@@ -319,6 +328,6 @@ class LoRAManager:
                     self.B_buffer["q_proj"][layer_id],
                     self.B_buffer["kv_proj"][layer_id],
                     bs,
-                    seg_lens,
+                    seg_indptr,
                     weight_indices,
                 )
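
Notes (not part of the patch):

The seg_lens -> seg_indptr migration hands the segment-gemm calls an
exclusive prefix sum of the per-request lengths instead of the lengths
themselves, computed once per batch in LoRAManager rather than re-derived
for every layer. A minimal sketch of the conversion, with made-up lengths
(the real code allocates these tensors on "cuda"; the device is omitted
here so the snippet runs on CPU):

    import torch

    # Hypothetical per-request extend lengths for a batch of 3.
    seg_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
    bs = seg_lens.numel()

    # Segment i spans rows [seg_indptr[i], seg_indptr[i + 1]) of the
    # flattened input, so seg_indptr needs bs + 1 entries.
    seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
    print(seg_indptr.tolist())  # [0, 3, 8, 10]

The decode branch also gains device="cuda", so seg_lens is built on the GPU
in decode mode just as the extend path already provides it.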
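On the buffer-slot allocation in prepare_lora_batch: the old loop used the
eviction cursor i both to choose a victim and as the destination slot, so it
evicted live adapters even while free slots remained (with a pool of 4 and
two adapters resident, two new uids would evict both residents instead of
taking slots 2 and 3). The rewrite takes free slots first (j) and only walks
the eviction cursor once the pool is full, asserting that a victim outside
the current batch exists. A self-contained simulation of the new policy
(assign_slots and the print are illustrative stand-ins, not SGLang APIs):

    def assign_slots(cur_uids, active_uids, buffer_id, max_loras_per_batch):
        """Mirror of the rewritten allocation loop over plain Python state."""
        i, j = 0, len(active_uids)
        evictable_uids = list(active_uids)
        for uid in cur_uids:
            if uid in active_uids:
                continue
            if j < max_loras_per_batch:
                index = j  # a slot is still free: no eviction needed
                j += 1
            else:
                # Skip residents that the current batch still needs.
                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
                    i += 1
                assert i < len(evictable_uids), "no evictable adapter left"
                active_uids.remove(evictable_uids[i])
                buffer_id.pop(evictable_uids[i])
                index = i
                i += 1
            print(f"load {uid} -> slot {index}")  # stands in for load_lora
            active_uids.add(uid)
            buffer_id[uid] = index

    active, slots = set(), {}
    assign_slots(["lora0", "lora1"], active, slots, max_loras_per_batch=4)
    assign_slots(["lora2", "lora3"], active, slots, max_loras_per_batch=4)
    # All four adapters now occupy slots 0-3; the old cursor-based loop
    # would have evicted lora0 and lora1 on the second call.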