From 021f76e4f49861b2e9ea9ccff06a46d577e3c548 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Wed, 11 Jun 2025 16:18:57 -0700 Subject: [PATCH] [Perf] Refactor LoRAManager to eliminate stream syncs and redundant computations (#6994) --- python/sglang/srt/lora/lora_manager.py | 115 +++++++++++++++++-------- python/sglang/srt/lora/mem_pool.py | 9 +- 2 files changed, 84 insertions(+), 40 deletions(-) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 45050df53..9d0295808 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -81,7 +81,7 @@ class LoRAManager: seg_indptr=torch.zeros( self.max_bs_in_cuda_graph + 1, dtype=torch.int32 ), - max_len=0, + max_len=1, weight_indices=torch.zeros( self.max_bs_in_cuda_graph, dtype=torch.int32 ), @@ -89,6 +89,17 @@ class LoRAManager: scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), ) + # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant + # across batches. + self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1) + torch.cumsum( + self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph], + dim=0, + out=self.cuda_graph_batch_info.seg_indptr[ + 1 : self.max_bs_in_cuda_graph + 1 + ], + ) + def init_loras(self): # Config of each LoRA adapter self.configs: Dict[str, LoRAConfig] = {} @@ -159,6 +170,45 @@ class LoRAManager: # set up batch info shared by all lora modules bs = forward_batch.batch_size + def transfer_adapter_info( + weight_indices_out: torch.Tensor, + lora_ranks_out: torch.Tensor, + scalings_out: torch.Tensor, + ): + """ + Transfer adapter metadata (weight indices, LoRA rank, scalings) from host + to device (CUDA) asynchronously. + """ + weight_indices = [0] * len(forward_batch.lora_paths) + lora_ranks = [0] * self.max_loras_per_batch + scalings = [0] * self.max_loras_per_batch + for i, lora_path in enumerate(forward_batch.lora_paths): + weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) + if lora_path is not None: + lora = self.loras[lora_path] + lora_ranks[weight_indices[i]] = lora.config.hf_config["r"] + scalings[weight_indices[i]] = lora.scaling + + # Use pinned memory to avoid synchronizations during host-to-device transfer + weight_indices_tensor = torch.tensor( + weight_indices, dtype=torch.int32, pin_memory=True, device="cpu" + ) + lora_ranks_tensor = torch.tensor( + lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu" + ) + scalings_tensor = torch.tensor( + scalings, dtype=torch.float, pin_memory=True, device="cpu" + ) + + # Copy to device tensors asynchronously + weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True) + lora_ranks_out[: self.max_loras_per_batch].copy_( + lora_ranks_tensor, non_blocking=True + ) + scalings_out[: self.max_loras_per_batch].copy_( + scalings_tensor, non_blocking=True + ) + if ( hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph @@ -166,51 +216,46 @@ class LoRAManager: ): # Do in-place updates when CUDA graph is enabled and the batch forward mode # could use CUDA graph. - self.cuda_graph_batch_info.bs = bs - self.cuda_graph_batch_info.seg_lens[:bs].fill_(1) - torch.cumsum( - self.cuda_graph_batch_info.seg_lens[:bs], - dim=0, - out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1], - ) - self.cuda_graph_batch_info.max_len = 1 - for i, lora_path in enumerate(forward_batch.lora_paths): - self.cuda_graph_batch_info.weight_indices[i] = ( - self.memory_pool.get_buffer_id(lora_path) - ) - if lora_path is not None: - lora = self.loras[lora_path] - self.cuda_graph_batch_info.lora_ranks[ - self.cuda_graph_batch_info.weight_indices[i] - ] = lora.config.hf_config["r"] - self.cuda_graph_batch_info.scalings[ - self.cuda_graph_batch_info.weight_indices[i] - ] = lora.scaling + transfer_adapter_info( + self.cuda_graph_batch_info.weight_indices, + self.cuda_graph_batch_info.lora_ranks, + self.cuda_graph_batch_info.scalings, + ) + + self.cuda_graph_batch_info.bs = bs + self.cuda_graph_batch_info.max_len = 1 batch_info = self.cuda_graph_batch_info else: + weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device) + lora_ranks = torch.zeros( + (self.max_loras_per_batch,), dtype=torch.int64, device=self.device + ) + scalings = torch.zeros( + (self.max_loras_per_batch,), dtype=torch.float, device=self.device + ) + transfer_adapter_info( + weight_indices, + lora_ranks, + scalings, + ) + seg_lens = ( forward_batch.extend_seq_lens if forward_batch.forward_mode.is_extend() else torch.ones(bs, device=self.device) ) + + max_len = ( + # Calculate max_len from the CPU copy to avoid D2H transfer. + max(forward_batch.extend_seq_lens_cpu) + if forward_batch.forward_mode.is_extend() + else 1 + ) + seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device) seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) - max_len = int(torch.max(seg_lens)) - weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device) - lora_ranks = torch.zeros( - (self.max_loras_per_batch,), dtype=torch.int64, device="cuda" - ) - scalings = torch.zeros( - (self.max_loras_per_batch,), dtype=torch.float, device="cuda" - ) - for i, lora_path in enumerate(forward_batch.lora_paths): - weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) - if lora_path is not None: - lora = self.loras[lora_path] - lora_ranks[weight_indices[i]] = lora.config.hf_config["r"] - scalings[weight_indices[i]] = lora.scaling batch_info = LoRABatchInfo( bs=bs, seg_lens=seg_lens, diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 8b8d21332..7e69c4aab 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -132,12 +132,13 @@ class LoRAMemoryPool: for buffer_id in range(self.max_loras_per_batch): # Prioritize empty slots if self.buffer_id_to_uid[buffer_id] == "": - return buffer_id, "" + return buffer_id for buffer_id in range(self.max_loras_per_batch): # Evict unneeded lora if self.buffer_id_to_uid[buffer_id] not in cur_uids: - return buffer_id, self.buffer_id_to_uid[buffer_id] + self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id]) + return buffer_id raise ValueError( "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch." @@ -145,9 +146,7 @@ class LoRAMemoryPool: for uid in cur_uids: if uid not in self.uid_to_buffer_id: - buffer_id, evicted_lora_uid = get_available_buffer_slot() - if evicted_lora_uid != "": - self.uid_to_buffer_id.pop(evicted_lora_uid) + buffer_id = get_available_buffer_slot() self.load_lora_weight_to_buffer( uid, buffer_id, lora_adapters.get(uid, None) )