[1/2] Refactor LoRA to support backend-specific batch preprocessing. (#10251)
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from typing import Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class BaseLoRABackend:
|
||||
@@ -10,13 +11,14 @@ class BaseLoRABackend:
|
||||
Each backend has its own implementation of Lora kernels.
|
||||
|
||||
Args:
|
||||
name: name of backend
|
||||
batch_info: information of current batch for use
|
||||
max_loras_per_batch: maximum number of different lora weights
|
||||
that can be applied in a single forward batch.
|
||||
device: the device where the backend runs.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
||||
self.name = name
|
||||
self.batch_info = batch_info
|
||||
def __init__(self, max_loras_per_batch: int, device: torch.device):
|
||||
self.max_loras_per_batch = max_loras_per_batch
|
||||
self.device = device
|
||||
|
||||
def run_lora_a_sgemm(
|
||||
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
||||
@@ -93,8 +95,44 @@ class BaseLoRABackend:
|
||||
"""
|
||||
pass
|
||||
|
||||
def set_batch_info(self, batch_info: LoRABatchInfo):
|
||||
self.batch_info = batch_info
|
||||
def init_cuda_graph_batch_info(
    self,
    cuda_graph_batch_info: LoRABatchInfo,
    max_bs_in_cuda_graph: int,
):
    """Backend hook for setting up the batch info used in CUDA Graph mode.

    The base implementation does nothing. A backend overrides this method when it
    needs its own initialization logic for CUDA Graph mode.

    Args:
        cuda_graph_batch_info: the LoRABatchInfo object created in LoRAManager
        max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
    """
    # Intentionally a no-op in the base class.
    return None
|
||||
|
||||
def prepare_lora_batch(
    self,
    forward_batch: ForwardBatch,
    weight_indices: list[int],
    lora_ranks: list[int],
    scalings: list[float],
    batch_info: Optional[LoRABatchInfo] = None,
):
    """Backend hook invoked to set up LoRA weights and batch info for a forward batch.

    The base implementation does nothing. A backend overrides this method to run
    its own per-batch preparation logic.

    Args:
        forward_batch: the ForwardBatch object for the current forward pass
        weight_indices: indices of the lora weights to apply to the current batch
        lora_ranks: lora ranks corresponding to weight_indices
        scalings: scaling factors corresponding to weight_indices
        batch_info: optional LoRABatchInfo object; when it is not provided, the
            backend should use its own internal batch info
            (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
    """
    # Intentionally a no-op in the base class.
    return None
|
||||
|
||||
|
||||
def get_backend_from_name(name: str) -> BaseLoRABackend:
|
||||
@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
|
||||
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
||||
|
||||
return TritonLoRABackend
|
||||
# elif name == "csgmv":
|
||||
# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
|
||||
|
||||
# return ChunkedSgmvLoRABackend
|
||||
elif name == "flashinfer":
|
||||
raise ValueError(
|
||||
"FlashInfer LoRA backend has been deprecated, please use `triton` instead."
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||
@@ -8,12 +10,14 @@ from sglang.srt.lora.triton_ops import (
|
||||
sgemm_lora_b_fwd,
|
||||
)
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class TritonLoRABackend(BaseLoRABackend):
|
||||
name = "triton"
|
||||
|
||||
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
||||
super().__init__(name, batch_info)
|
||||
def __init__(self, max_loras_per_batch: int, device: torch.device):
|
||||
super().__init__(max_loras_per_batch, device)
|
||||
|
||||
def run_lora_a_sgemm(
|
||||
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
||||
@@ -86,3 +90,87 @@ class TritonLoRABackend(BaseLoRABackend):
|
||||
base_output,
|
||||
)
|
||||
return lora_output
|
||||
|
||||
def init_cuda_graph_batch_info(
    self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
):
    """Pre-fill the segment metadata of the CUDA Graph batch info in place.

    The first ``max_bs_in_cuda_graph`` entries of ``seg_lens`` are set to 1, and
    ``seg_indptr[1 : max_bs_in_cuda_graph + 1]`` receives their running sum.
    ``seg_indptr[0]`` is not written here.

    Args:
        cuda_graph_batch_info: batch info whose ``seg_lens`` and ``seg_indptr``
            tensors are updated in place
        max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
    """
    # seg_lens and seg_indptr remain constant across batches in CUDA Graph mode,
    # so they are filled once here instead of on every forward batch.
    unit_lens = cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph]
    unit_lens.fill_(1)

    indptr_tail = cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1]
    torch.cumsum(unit_lens, dim=0, out=indptr_tail)
|
||||
|
||||
def prepare_lora_batch(
    self,
    forward_batch: ForwardBatch,
    weight_indices: list[int],
    lora_ranks: list[int],
    scalings: list[float],
    batch_info: Optional[LoRABatchInfo] = None,
):
    """Build or refresh the LoRABatchInfo for the current forward batch.

    When ``batch_info`` is given it is updated in place: its ``bs`` and
    ``num_segments`` are set to the batch size and its existing tensors are
    overwritten. Otherwise a new LoRABatchInfo is created with freshly
    allocated tensors on ``self.device``. In both cases the adapter metadata
    is copied into the batch info and the result is stored in ``self.batch_info``.

    Args:
        forward_batch: the ForwardBatch object for the current forward pass
        weight_indices: index of the lora weight used by each sequence
        lora_ranks: lora ranks, indexed by the values in weight_indices
        scalings: scaling factors, indexed by the values in weight_indices
        batch_info: optional pre-allocated LoRABatchInfo to update in place;
            it must have ``use_cuda_graph`` set to True

    Raises:
        AssertionError: if ``batch_info`` is provided but its ``use_cuda_graph``
            flag is not True.
    """
    bs = forward_batch.batch_size

    # Stage the metadata in pinned host memory to avoid synchronizations during
    # the host-to-device transfer. Pinning requires an accelerator runtime, so
    # it is skipped when the backend itself runs on the CPU.
    use_pinned = torch.device(self.device).type != "cpu"
    weight_indices_tensor = torch.tensor(
        weight_indices, dtype=torch.int32, pin_memory=use_pinned, device="cpu"
    )
    lora_ranks_tensor = torch.tensor(
        lora_ranks, dtype=torch.int32, pin_memory=use_pinned, device="cpu"
    )
    scalings_tensor = torch.tensor(
        scalings, dtype=torch.float, pin_memory=use_pinned, device="cpu"
    )

    if batch_info is not None:
        assert (
            batch_info.use_cuda_graph
        ), "batch_info.use_cuda_graph must be True when batch_info is provided"
        # seg_lens, seg_indptr and max_len are not touched on this path; only
        # the batch size fields and the adapter tensors below are refreshed.
        batch_info.bs = bs
        batch_info.num_segments = bs
    else:
        if forward_batch.forward_mode.is_extend():
            # Calculate max_len from the CPU copy to avoid D2H transfer.
            max_len = max(forward_batch.extend_seq_lens_cpu)
            seg_lens = forward_batch.extend_seq_lens
        else:
            max_len = 1
            # Integer dtype keeps seg_lens consistent with the int32 seg_indptr
            # (torch.ones would otherwise default to a floating-point tensor).
            seg_lens = torch.ones(bs, dtype=torch.int32, device=self.device)

        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

        batch_info = LoRABatchInfo(
            bs=bs,
            num_segments=bs,
            max_len=max_len,
            use_cuda_graph=False,
            seg_lens=seg_lens,
            seg_indptr=seg_indptr,
            weight_indices=torch.empty((bs,), dtype=torch.int32, device=self.device),
            lora_ranks=torch.empty(
                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
            ),
            scalings=torch.empty(
                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
            ),
            permutation=None,
        )

    # Copy to device asynchronously
    batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
        lora_ranks_tensor, non_blocking=True
    )
    batch_info.scalings[: self.max_loras_per_batch].copy_(
        scalings_tensor, non_blocking=True
    )
    batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)

    self.batch_info = batch_info
|
||||
|
||||
@@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
lora_backend: BaseLoRABackend,
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_backend)
|
||||
shard_size = self.base_layer.output_partition_sizes[0]
|
||||
self.output_offset = torch.tensor(
|
||||
[
|
||||
0,
|
||||
shard_size,
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=next(self.base_layer.parameters()).device,
|
||||
)
|
||||
|
||||
def set_lora_info(
|
||||
self,
|
||||
@@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
lora_output = self.lora_backend.run_lora_b_sgemm(
|
||||
x=lora_a_output,
|
||||
weights=self.B_buffer,
|
||||
output_offset=self.output_offset,
|
||||
base_output=base_output,
|
||||
)
|
||||
return lora_output
|
||||
@@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
self.A_buffer_gate_up = A_buffer
|
||||
self.B_buffer_gate_up = B_buffer
|
||||
|
||||
shard_size = self.base_layer.output_partition_sizes[0]
|
||||
self.output_offset = torch.tensor(
|
||||
[
|
||||
0,
|
||||
shard_size,
|
||||
2 * shard_size,
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=next(self.base_layer.parameters()).device,
|
||||
)
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_output = self.lora_backend.run_gate_up_lora(
|
||||
x=x,
|
||||
gate_up_lora_a=self.A_buffer_gate_up,
|
||||
gate_up_lora_b=self.B_buffer_gate_up,
|
||||
output_offset=self.output_offset,
|
||||
base_output=base_output,
|
||||
)
|
||||
return lora_output
|
||||
@@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
self.set_lora = True
|
||||
self.A_buffer = A_buffer
|
||||
self.B_buffer = B_buffer
|
||||
output_size = self.base_layer.output_size
|
||||
self.output_offset = torch.tensor(
|
||||
[
|
||||
0,
|
||||
output_size,
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=next(self.base_layer.parameters()).device,
|
||||
)
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
||||
lora_output = self.lora_backend.run_lora_b_sgemm(
|
||||
x=lora_a_output,
|
||||
weights=self.B_buffer,
|
||||
output_offset=self.output_offset,
|
||||
base_output=base_output,
|
||||
)
|
||||
return lora_output
|
||||
|
||||
@@ -28,6 +28,9 @@ from torch import nn
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||
|
||||
# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
|
||||
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||
|
||||
@@ -156,7 +159,7 @@ class LoRAAdapter(nn.Module):
|
||||
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
||||
if up_name not in weights:
|
||||
weights[up_name] = torch.zeros_like(weights[weight_name])
|
||||
assert self.lora_backend.name == "triton", (
|
||||
assert isinstance(self.lora_backend, TritonLoRABackend), (
|
||||
f"LoRA weight initialization currently only supported for 'triton' backend. "
|
||||
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
|
||||
f"or consider implementing custom initialization logic for other backends."
|
||||
|
||||
@@ -69,7 +69,10 @@ class LoRAManager:
|
||||
# LoRA backend for running sgemm kernels
|
||||
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
|
||||
backend_type = get_backend_from_name(lora_backend)
|
||||
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
|
||||
self.lora_backend: BaseLoRABackend = backend_type(
|
||||
max_loras_per_batch=max_loras_per_batch,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Initialize mutable internal state of the LoRAManager.
|
||||
self.init_state(
|
||||
@@ -82,29 +85,22 @@ class LoRAManager:
|
||||
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
|
||||
with torch.device("cuda"):
|
||||
self.cuda_graph_batch_info = LoRABatchInfo(
|
||||
bs=self.max_bs_in_cuda_graph,
|
||||
seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
|
||||
seg_indptr=torch.zeros(
|
||||
self.max_bs_in_cuda_graph + 1, dtype=torch.int32
|
||||
),
|
||||
bs=max_bs_in_cuda_graph,
|
||||
use_cuda_graph=True,
|
||||
num_segments=None,
|
||||
seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
|
||||
seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
|
||||
max_len=1,
|
||||
weight_indices=torch.zeros(
|
||||
self.max_bs_in_cuda_graph, dtype=torch.int32
|
||||
),
|
||||
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
|
||||
permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
|
||||
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
|
||||
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
|
||||
)
|
||||
|
||||
# Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
|
||||
# across batches.
|
||||
self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
|
||||
torch.cumsum(
|
||||
self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
|
||||
dim=0,
|
||||
out=self.cuda_graph_batch_info.seg_indptr[
|
||||
1 : self.max_bs_in_cuda_graph + 1
|
||||
],
|
||||
)
|
||||
self.lora_backend.init_cuda_graph_batch_info(
|
||||
cuda_graph_batch_info=self.cuda_graph_batch_info,
|
||||
max_bs_in_cuda_graph=max_bs_in_cuda_graph,
|
||||
)
|
||||
|
||||
def create_lora_update_result(
|
||||
self, success: bool, error_message: str = ""
|
||||
@@ -232,7 +228,6 @@ class LoRAManager:
|
||||
return required_slots <= mem_pool_vacancy
|
||||
|
||||
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
||||
|
||||
# Load active loras into lora memory pool
|
||||
cur_uids = set(forward_batch.lora_ids)
|
||||
|
||||
@@ -247,102 +242,30 @@ class LoRAManager:
|
||||
# set up batch info shared by all lora modules
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
def transfer_adapter_info(
|
||||
weight_indices_out: torch.Tensor,
|
||||
lora_ranks_out: torch.Tensor,
|
||||
scalings_out: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
|
||||
to device (CUDA) asynchronously.
|
||||
"""
|
||||
weight_indices = [0] * len(forward_batch.lora_ids)
|
||||
lora_ranks = [0] * self.max_loras_per_batch
|
||||
scalings = [0] * self.max_loras_per_batch
|
||||
for i, uid in enumerate(forward_batch.lora_ids):
|
||||
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
|
||||
if uid is not None:
|
||||
lora = self.loras[uid]
|
||||
lora_ranks[weight_indices[i]] = lora.config.r
|
||||
scalings[weight_indices[i]] = lora.scaling
|
||||
|
||||
# Use pinned memory to avoid synchronizations during host-to-device transfer
|
||||
weight_indices_tensor = torch.tensor(
|
||||
weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
|
||||
)
|
||||
lora_ranks_tensor = torch.tensor(
|
||||
lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
|
||||
)
|
||||
scalings_tensor = torch.tensor(
|
||||
scalings, dtype=torch.float, pin_memory=True, device="cpu"
|
||||
)
|
||||
|
||||
# Copy to device tensors asynchronously
|
||||
weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
|
||||
lora_ranks_out[: self.max_loras_per_batch].copy_(
|
||||
lora_ranks_tensor, non_blocking=True
|
||||
)
|
||||
scalings_out[: self.max_loras_per_batch].copy_(
|
||||
scalings_tensor, non_blocking=True
|
||||
)
|
||||
|
||||
if (
|
||||
use_cuda_graph = (
|
||||
hasattr(self, "max_bs_in_cuda_graph")
|
||||
and bs <= self.max_bs_in_cuda_graph
|
||||
and forward_batch.forward_mode.is_cuda_graph()
|
||||
):
|
||||
# Do in-place updates when CUDA graph is enabled and the batch forward mode
|
||||
# could use CUDA graph.
|
||||
)
|
||||
|
||||
transfer_adapter_info(
|
||||
self.cuda_graph_batch_info.weight_indices,
|
||||
self.cuda_graph_batch_info.lora_ranks,
|
||||
self.cuda_graph_batch_info.scalings,
|
||||
)
|
||||
|
||||
self.cuda_graph_batch_info.bs = bs
|
||||
self.cuda_graph_batch_info.max_len = 1
|
||||
batch_info = self.cuda_graph_batch_info
|
||||
else:
|
||||
weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
||||
lora_ranks = torch.zeros(
|
||||
(self.max_loras_per_batch,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
scalings = torch.zeros(
|
||||
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
|
||||
)
|
||||
transfer_adapter_info(
|
||||
weight_indices,
|
||||
lora_ranks,
|
||||
scalings,
|
||||
)
|
||||
|
||||
seg_lens = (
|
||||
forward_batch.extend_seq_lens
|
||||
if forward_batch.forward_mode.is_extend()
|
||||
else torch.ones(bs, device=self.device)
|
||||
)
|
||||
|
||||
max_len = (
|
||||
# Calculate max_len from the CPU copy to avoid D2H transfer.
|
||||
max(forward_batch.extend_seq_lens_cpu)
|
||||
if forward_batch.forward_mode.is_extend()
|
||||
else 1
|
||||
)
|
||||
|
||||
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
|
||||
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
|
||||
|
||||
batch_info = LoRABatchInfo(
|
||||
bs=bs,
|
||||
seg_lens=seg_lens,
|
||||
seg_indptr=seg_indptr,
|
||||
max_len=max_len,
|
||||
weight_indices=weight_indices,
|
||||
lora_ranks=lora_ranks,
|
||||
scalings=scalings,
|
||||
)
|
||||
self.lora_backend.set_batch_info(batch_info)
|
||||
weight_indices = [0] * len(forward_batch.lora_ids)
|
||||
lora_ranks = [0] * self.max_loras_per_batch
|
||||
scalings = [0] * self.max_loras_per_batch
|
||||
for i, uid in enumerate(forward_batch.lora_ids):
|
||||
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
|
||||
if uid is not None:
|
||||
lora = self.loras[uid]
|
||||
lora_ranks[weight_indices[i]] = lora.config.r
|
||||
scalings[weight_indices[i]] = lora.scaling
|
||||
# Do in-place updates when CUDA graph is enabled and the batch forward mode
|
||||
# could use CUDA graph.
|
||||
self.lora_backend.prepare_lora_batch(
|
||||
forward_batch=forward_batch,
|
||||
weight_indices=weight_indices,
|
||||
lora_ranks=lora_ranks,
|
||||
scalings=scalings,
|
||||
batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
|
||||
)
|
||||
|
||||
def update_lora_info(self):
|
||||
"""
|
||||
|
||||
@@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig
|
||||
|
||||
@dataclass
|
||||
class LoRABatchInfo:
|
||||
# The forward mode is using CUDA Graph.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Batch size
|
||||
bs: int
|
||||
|
||||
# Lengths of each sequence in shape (bs,)
|
||||
seg_lens: torch.Tensor
|
||||
# Number of segments. For triton backend, it is equal to batch size.
|
||||
num_segments: int
|
||||
|
||||
# Indice pointers of each sequence in shape (bs + 1, )
|
||||
# Index pointers of each segment in shape (num_segments + 1, )
|
||||
seg_indptr: torch.Tensor
|
||||
|
||||
# Maximum sequence length of current batch
|
||||
max_len: int
|
||||
|
||||
# The index of lora adapter used by each sequence, in shape (bs,)
|
||||
# The index of lora adapter used by each segment, in shape (num_segments,)
|
||||
weight_indices: torch.Tensor
|
||||
|
||||
# ranks of each lora adapter, in shape (lora_num,)
|
||||
@@ -31,6 +31,15 @@ class LoRABatchInfo:
|
||||
# scaling of each lora adapter, in shape (lora_num,)
|
||||
scalings: torch.Tensor
|
||||
|
||||
# Length of each segment, in shape (num_segments,)
|
||||
seg_lens: Optional[torch.Tensor]
|
||||
|
||||
# Maximum segment length of current batch
|
||||
max_len: Optional[int]
|
||||
|
||||
# The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
|
||||
permutation: Optional[torch.Tensor]
|
||||
|
||||
|
||||
class LoRAType(Enum):
|
||||
LORA_A = 0
|
||||
|
||||
Reference in New Issue
Block a user