From 941002945b26f3b188038fa362df62333b502fe6 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Wed, 10 Sep 2025 09:58:37 -0700 Subject: [PATCH] [1/2] Refactor LoRA to support backend-specific batch preprocessing. (#10251) --- .../sglang/srt/lora/backend/base_backend.py | 58 ++++++- .../sglang/srt/lora/backend/triton_backend.py | 92 ++++++++++- python/sglang/srt/lora/layers.py | 32 ++++ python/sglang/srt/lora/lora.py | 5 +- python/sglang/srt/lora/lora_manager.py | 147 +++++------------- python/sglang/srt/lora/utils.py | 23 ++- 6 files changed, 227 insertions(+), 130 deletions(-) diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index fe8bd3d20..7c2c232d5 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -1,8 +1,9 @@ -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch from sglang.srt.lora.utils import LoRABatchInfo +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class BaseLoRABackend: @@ -10,13 +11,14 @@ class BaseLoRABackend: Each backend has its own implementation of Lora kernels. Args: - name: name of backend - batch_info: information of current batch for use + max_loras_per_batch: maximum number of different lora weights + that can be applied in a single forward batch. + device: the device where the backend runs. 
""" - def __init__(self, name: str, batch_info: LoRABatchInfo = None): - self.name = name - self.batch_info = batch_info + def __init__(self, max_loras_per_batch: int, device: torch.device): + self.max_loras_per_batch = max_loras_per_batch + self.device = device def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs @@ -93,8 +95,44 @@ class BaseLoRABackend: """ pass - def set_batch_info(self, batch_info: LoRABatchInfo): - self.batch_info = batch_info + def init_cuda_graph_batch_info( + self, + cuda_graph_batch_info: LoRABatchInfo, + max_bs_in_cuda_graph: int, + ): + """Initialize the batch info for CUDA Graph mode. + + This method provides a hook for each backend to conduct its own initialization + logic for CUDA Graph mode. + + Args: + cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager + max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode + """ + pass + + def prepare_lora_batch( + self, + forward_batch: ForwardBatch, + weight_indices: list[int], + lora_ranks: list[int], + scalings: list[float], + batch_info: Optional[LoRABatchInfo] = None, + ): + """Prepare the lora weights and batch info for current forward batch. + + This method provides a hook for each backend to conduct its own preparation + logic for each forward batch. 
+ + Args: + forward_batch: the ForwardBatch object for current forward pass + weight_indices: list of indices of lora weights to be applied for current batch + lora_ranks: list of lora ranks corresponding to weight_indices + scalings: list of scaling factors corresponding to weight_indices + batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own + internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode) + """ + pass def get_backend_from_name(name: str) -> BaseLoRABackend: @@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend: from sglang.srt.lora.backend.triton_backend import TritonLoRABackend return TritonLoRABackend + # elif name == "csgmv": + # from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend + + # return ChunkedSgmvLoRABackend elif name == "flashinfer": raise ValueError( "FlashInfer LoRA backend has been deprecated, please use `triton` instead." diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index d3a854b40..7abeef770 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from sglang.srt.lora.backend.base_backend import BaseLoRABackend @@ -8,12 +10,14 @@ from sglang.srt.lora.triton_ops import ( sgemm_lora_b_fwd, ) from sglang.srt.lora.utils import LoRABatchInfo +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class TritonLoRABackend(BaseLoRABackend): + name = "triton" - def __init__(self, name: str, batch_info: LoRABatchInfo = None): - super().__init__(name, batch_info) + def __init__(self, max_loras_per_batch: int, device: torch.device): + super().__init__(max_loras_per_batch, device) def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs @@ -86,3 +90,87 @@ class TritonLoRABackend(BaseLoRABackend): base_output, ) return 
lora_output + + def init_cuda_graph_batch_info( + self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int + ): + # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant + # across batches. + cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1) + torch.cumsum( + cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph], + dim=0, + out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1], + ) + + def prepare_lora_batch( + self, + forward_batch: ForwardBatch, + weight_indices: list[int], + lora_ranks: list[int], + scalings: list[float], + batch_info: Optional[LoRABatchInfo] = None, + ): + # Use pinned memory to avoid synchronizations during host-to-device transfer + weight_indices_tensor = torch.tensor( + weight_indices, dtype=torch.int32, pin_memory=True, device="cpu" + ) + lora_ranks_tensor = torch.tensor( + lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu" + ) + scalings_tensor = torch.tensor( + scalings, dtype=torch.float, pin_memory=True, device="cpu" + ) + + bs = forward_batch.batch_size + + if batch_info is not None: + assert ( + batch_info.use_cuda_graph + ), "batch_info.use_cuda_graph must be True when batch_info is provided" + batch_info.bs = forward_batch.batch_size + batch_info.num_segments = forward_batch.batch_size + else: + max_len = ( + # Calculate max_len from the CPU copy to avoid D2H transfer. 
+ max(forward_batch.extend_seq_lens_cpu) + if forward_batch.forward_mode.is_extend() + else 1 + ) + seg_lens = ( + forward_batch.extend_seq_lens + if forward_batch.forward_mode.is_extend() + else torch.ones(bs, device=self.device) + ) + seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device) + seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) + + batch_info = LoRABatchInfo( + bs=forward_batch.batch_size, + num_segments=forward_batch.batch_size, + max_len=max_len, + use_cuda_graph=False, + seg_lens=seg_lens, + seg_indptr=seg_indptr, + weight_indices=torch.empty( + (bs,), dtype=torch.int32, device=self.device + ), + lora_ranks=torch.empty( + (self.max_loras_per_batch,), dtype=torch.int64, device=self.device + ), + scalings=torch.empty( + (self.max_loras_per_batch,), dtype=torch.float, device=self.device + ), + permutation=None, + ) + + # Copy to device asynchronously + batch_info.lora_ranks[: self.max_loras_per_batch].copy_( + lora_ranks_tensor, non_blocking=True + ) + batch_info.scalings[: self.max_loras_per_batch].copy_( + scalings_tensor, non_blocking=True + ) + batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True) + + self.batch_info = batch_info diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index f9a877cd5..4426faccb 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): lora_backend: BaseLoRABackend, ) -> None: super().__init__(base_layer, lora_backend) + shard_size = self.base_layer.output_partition_sizes[0] + self.output_offset = torch.tensor( + [ + 0, + shard_size, + ], + dtype=torch.int32, + device=next(self.base_layer.parameters()).device, + ) def set_lora_info( self, @@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): lora_output = self.lora_backend.run_lora_b_sgemm( x=lora_a_output, weights=self.B_buffer, + output_offset=self.output_offset, 
base_output=base_output, ) return lora_output @@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): self.A_buffer_gate_up = A_buffer self.B_buffer_gate_up = B_buffer + shard_size = self.base_layer.output_partition_sizes[0] + self.output_offset = torch.tensor( + [ + 0, + shard_size, + 2 * shard_size, + ], + dtype=torch.int32, + device=next(self.base_layer.parameters()).device, + ) + def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: lora_output = self.lora_backend.run_gate_up_lora( x=x, gate_up_lora_a=self.A_buffer_gate_up, gate_up_lora_b=self.B_buffer_gate_up, + output_offset=self.output_offset, base_output=base_output, ) return lora_output @@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): self.set_lora = True self.A_buffer = A_buffer self.B_buffer = B_buffer + output_size = self.base_layer.output_size + self.output_offset = torch.tensor( + [ + 0, + output_size, + ], + dtype=torch.int32, + device=next(self.base_layer.parameters()).device, + ) def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) lora_output = self.lora_backend.run_lora_b_sgemm( x=lora_a_output, weights=self.B_buffer, + output_offset=self.output_offset, base_output=base_output, ) return lora_output diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index dfd5acda9..e7569624c 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -28,6 +28,9 @@ from torch import nn from sglang.srt.configs.load_config import LoadConfig from sglang.srt.hf_transformers_utils import AutoConfig from sglang.srt.lora.backend.base_backend import BaseLoRABackend + +# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend +from sglang.srt.lora.backend.triton_backend import TritonLoRABackend from sglang.srt.lora.lora_config import LoRAConfig from 
sglang.srt.model_loader.loader import DefaultModelLoader @@ -156,7 +159,7 @@ class LoRAAdapter(nn.Module): gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") if up_name not in weights: weights[up_name] = torch.zeros_like(weights[weight_name]) - assert self.lora_backend.name == "triton", ( + assert isinstance(self.lora_backend, TritonLoRABackend), ( f"LoRA weight initialization currently only supported for 'triton' backend. " f"Received backend: {self.lora_backend.name}. Please verify your backend configuration " f"or consider implementing custom initialization logic for other backends." diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index e3560e05d..baf120ca2 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -69,7 +69,10 @@ class LoRAManager: # LoRA backend for running sgemm kernels logger.info(f"Using {lora_backend} as backend of LoRA kernels.") backend_type = get_backend_from_name(lora_backend) - self.lora_backend: BaseLoRABackend = backend_type(lora_backend) + self.lora_backend: BaseLoRABackend = backend_type( + max_loras_per_batch=max_loras_per_batch, + device=self.device, + ) # Initialize mutable internal state of the LoRAManager. 
self.init_state( @@ -82,29 +85,22 @@ class LoRAManager: self.max_bs_in_cuda_graph = max_bs_in_cuda_graph with torch.device("cuda"): self.cuda_graph_batch_info = LoRABatchInfo( - bs=self.max_bs_in_cuda_graph, - seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32), - seg_indptr=torch.zeros( - self.max_bs_in_cuda_graph + 1, dtype=torch.int32 - ), + bs=max_bs_in_cuda_graph, + use_cuda_graph=True, + num_segments=None, + seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), + seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32), max_len=1, - weight_indices=torch.zeros( - self.max_bs_in_cuda_graph, dtype=torch.int32 - ), + weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), + permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), ) - # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant - # across batches. 
- self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1) - torch.cumsum( - self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph], - dim=0, - out=self.cuda_graph_batch_info.seg_indptr[ - 1 : self.max_bs_in_cuda_graph + 1 - ], - ) + self.lora_backend.init_cuda_graph_batch_info( + cuda_graph_batch_info=self.cuda_graph_batch_info, + max_bs_in_cuda_graph=max_bs_in_cuda_graph, + ) def create_lora_update_result( self, success: bool, error_message: str = "" @@ -232,7 +228,6 @@ class LoRAManager: return required_slots <= mem_pool_vacancy def prepare_lora_batch(self, forward_batch: ForwardBatch): - # Load active loras into lora memory pool cur_uids = set(forward_batch.lora_ids) @@ -247,102 +242,30 @@ class LoRAManager: # set up batch info shared by all lora modules bs = forward_batch.batch_size - def transfer_adapter_info( - weight_indices_out: torch.Tensor, - lora_ranks_out: torch.Tensor, - scalings_out: torch.Tensor, - ): - """ - Transfer adapter metadata (weight indices, LoRA rank, scalings) from host - to device (CUDA) asynchronously. 
- """ - weight_indices = [0] * len(forward_batch.lora_ids) - lora_ranks = [0] * self.max_loras_per_batch - scalings = [0] * self.max_loras_per_batch - for i, uid in enumerate(forward_batch.lora_ids): - weight_indices[i] = self.memory_pool.get_buffer_id(uid) - if uid is not None: - lora = self.loras[uid] - lora_ranks[weight_indices[i]] = lora.config.r - scalings[weight_indices[i]] = lora.scaling - - # Use pinned memory to avoid synchronizations during host-to-device transfer - weight_indices_tensor = torch.tensor( - weight_indices, dtype=torch.int32, pin_memory=True, device="cpu" - ) - lora_ranks_tensor = torch.tensor( - lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu" - ) - scalings_tensor = torch.tensor( - scalings, dtype=torch.float, pin_memory=True, device="cpu" - ) - - # Copy to device tensors asynchronously - weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True) - lora_ranks_out[: self.max_loras_per_batch].copy_( - lora_ranks_tensor, non_blocking=True - ) - scalings_out[: self.max_loras_per_batch].copy_( - scalings_tensor, non_blocking=True - ) - - if ( + use_cuda_graph = ( hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph and forward_batch.forward_mode.is_cuda_graph() - ): - # Do in-place updates when CUDA graph is enabled and the batch forward mode - # could use CUDA graph. 
+ ) - transfer_adapter_info( - self.cuda_graph_batch_info.weight_indices, - self.cuda_graph_batch_info.lora_ranks, - self.cuda_graph_batch_info.scalings, - ) - - self.cuda_graph_batch_info.bs = bs - self.cuda_graph_batch_info.max_len = 1 - batch_info = self.cuda_graph_batch_info - else: - weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device) - lora_ranks = torch.zeros( - (self.max_loras_per_batch,), dtype=torch.int64, device=self.device - ) - scalings = torch.zeros( - (self.max_loras_per_batch,), dtype=torch.float, device=self.device - ) - transfer_adapter_info( - weight_indices, - lora_ranks, - scalings, - ) - - seg_lens = ( - forward_batch.extend_seq_lens - if forward_batch.forward_mode.is_extend() - else torch.ones(bs, device=self.device) - ) - - max_len = ( - # Calculate max_len from the CPU copy to avoid D2H transfer. - max(forward_batch.extend_seq_lens_cpu) - if forward_batch.forward_mode.is_extend() - else 1 - ) - - seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device) - seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) - - batch_info = LoRABatchInfo( - bs=bs, - seg_lens=seg_lens, - seg_indptr=seg_indptr, - max_len=max_len, - weight_indices=weight_indices, - lora_ranks=lora_ranks, - scalings=scalings, - ) - self.lora_backend.set_batch_info(batch_info) + weight_indices = [0] * len(forward_batch.lora_ids) + lora_ranks = [0] * self.max_loras_per_batch + scalings = [0] * self.max_loras_per_batch + for i, uid in enumerate(forward_batch.lora_ids): + weight_indices[i] = self.memory_pool.get_buffer_id(uid) + if uid is not None: + lora = self.loras[uid] + lora_ranks[weight_indices[i]] = lora.config.r + scalings[weight_indices[i]] = lora.scaling + # Do in-place updates when CUDA graph is enabled and the batch forward mode + # could use CUDA graph. 
+ self.lora_backend.prepare_lora_batch( + forward_batch=forward_batch, + weight_indices=weight_indices, + lora_ranks=lora_ranks, + scalings=scalings, + batch_info=self.cuda_graph_batch_info if use_cuda_graph else None, + ) def update_lora_info(self): """ diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 6528e2691..459c943b7 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig @dataclass class LoRABatchInfo: + # Whether the forward mode is using CUDA Graph. + use_cuda_graph: bool + # Batch size bs: int - # Lengths of each sequence in shape (bs,) - seg_lens: torch.Tensor + # Number of segments. For triton backend, it is equal to batch size. + num_segments: int - # Indice pointers of each sequence in shape (bs + 1, ) + # Index pointers of each segment in shape (num_segments + 1, ) seg_indptr: torch.Tensor - # Maximum sequence length of current batch - max_len: int - - # The index of lora adapter used by each sequence, in shape (bs,) + # The index of lora adapter used by each segment, in shape (num_segments,) weight_indices: torch.Tensor # ranks of each lora adapter, in shape (lora_num,) @@ -31,6 +31,15 @@ class LoRABatchInfo: # scaling of each lora adapter, in shape (lora_num,) scalings: torch.Tensor + # Lengths of each segment in shape (num_segments,) + seg_lens: Optional[torch.Tensor] + + # Maximum segment length of current batch + max_len: Optional[int] + + # The logical (re)ordering of input rows (tokens), in shape (num_tokens,) + permutation: Optional[torch.Tensor] + class LoRAType(Enum): LORA_A = 0