[1/2] Refactor LoRA to support backend-specific batch preprocessing. (#10251)
File: sglang/srt/lora/backend/base_backend.py

@@ -1,8 +1,9 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch

 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class BaseLoRABackend:
@@ -10,13 +11,14 @@ class BaseLoRABackend:
     Each backend has its own implementation of Lora kernels.

     Args:
-        name: name of backend
-        batch_info: information of current batch for use
+        max_loras_per_batch: maximum number of different lora weights
+            that can be applied in a single forward batch.
+        device: the device where the backend runs.
     """

-    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-        self.name = name
-        self.batch_info = batch_info
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        self.max_loras_per_batch = max_loras_per_batch
+        self.device = device

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -93,8 +95,44 @@ class BaseLoRABackend:
         """
         pass

-    def set_batch_info(self, batch_info: LoRABatchInfo):
-        self.batch_info = batch_info
+    def init_cuda_graph_batch_info(
+        self,
+        cuda_graph_batch_info: LoRABatchInfo,
+        max_bs_in_cuda_graph: int,
+    ):
+        """Initialize the batch info for CUDA Graph mode.
+
+        This method provides a hook for each backend to conduct its own initialization
+        logic for CUDA Graph mode.
+
+        Args:
+            cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
+            max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
+        """
+        pass
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        """Prepare the lora weights and batch info for current forward batch.
+
+        This method provides a hook for each backend to conduct its own preparation
+        logic for each forward batch.
+
+        Args:
+            forward_batch: the ForwardBatch object for current forward pass
+            weight_indices: list of indices of lora weights to be applied for current batch
+            lora_ranks: list of lora ranks corresponding to weight_indices
+            scalings: list of scaling factors corresponding to weight_indices
+            batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
+                internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
+        """
+        pass


 def get_backend_from_name(name: str) -> BaseLoRABackend:
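The two new hooks make batch preprocessing a backend responsibility instead of something the manager does for every backend. As a rough illustration (not part of this PR), a custom backend could subclass BaseLoRABackend and fill in both hooks; MyLoRABackend and its internals below are hypothetical placeholders:

    # Hypothetical sketch, not part of this diff: the shape of a custom backend
    # under the new API. A real backend would also build segment metadata.
    from typing import Optional

    import torch

    from sglang.srt.lora.backend.base_backend import BaseLoRABackend
    from sglang.srt.lora.utils import LoRABatchInfo
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


    class MyLoRABackend(BaseLoRABackend):
        def init_cuda_graph_batch_info(
            self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
        ):
            # Precompute whatever stays constant across CUDA Graph replays,
            # e.g. unit-length decode segments.
            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)

        def prepare_lora_batch(
            self,
            forward_batch: ForwardBatch,
            weight_indices: list[int],
            lora_ranks: list[int],
            scalings: list[float],
            batch_info: Optional[LoRABatchInfo] = None,
        ):
            # Stage per-adapter metadata on the backend's device; a real backend
            # would also populate seg_lens / seg_indptr for the batch.
            self.weight_indices = torch.tensor(
                weight_indices, dtype=torch.int32, device=self.device
            )
            self.lora_ranks = torch.tensor(
                lora_ranks, dtype=torch.int32, device=self.device
            )
            self.scalings = torch.tensor(
                scalings, dtype=torch.float, device=self.device
            )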
@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
         from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

         return TritonLoRABackend
+    # elif name == "csgmv":
+    #     from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+
+    #     return ChunkedSgmvLoRABackend
     elif name == "flashinfer":
         raise ValueError(
             "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
File: sglang/srt/lora/backend/triton_backend.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch

 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -8,12 +10,14 @@ from sglang.srt.lora.triton_ops import (
     sgemm_lora_b_fwd,
 )
 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class TritonLoRABackend(BaseLoRABackend):
-    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-        super().__init__(name, batch_info)
+    name = "triton"
+
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        super().__init__(max_loras_per_batch, device)

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -86,3 +90,87 @@ class TritonLoRABackend(BaseLoRABackend):
             base_output,
         )
         return lora_output
+
+    def init_cuda_graph_batch_info(
+        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
+    ):
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
+            dim=0,
+            out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
+        )
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        # Use pinned memory to avoid synchronizations during host-to-device transfer
+        weight_indices_tensor = torch.tensor(
+            weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        lora_ranks_tensor = torch.tensor(
+            lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        scalings_tensor = torch.tensor(
+            scalings, dtype=torch.float, pin_memory=True, device="cpu"
+        )
+
+        bs = forward_batch.batch_size
+
+        if batch_info is not None:
+            assert (
+                batch_info.use_cuda_graph
+            ), "batch_info.use_cuda_graph must be True when batch_info is provided"
+            batch_info.bs = forward_batch.batch_size
+            batch_info.num_segments = forward_batch.batch_size
+        else:
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+
+            batch_info = LoRABatchInfo(
+                bs=forward_batch.batch_size,
+                num_segments=forward_batch.batch_size,
+                max_len=max_len,
+                use_cuda_graph=False,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                weight_indices=torch.empty(
+                    (bs,), dtype=torch.int32, device=self.device
+                ),
+                lora_ranks=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+                ),
+                scalings=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+                ),
+                permutation=None,
+            )
+
+        # Copy to device asynchronously
+        batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
+            lora_ranks_tensor, non_blocking=True
+        )
+        batch_info.scalings[: self.max_loras_per_batch].copy_(
+            scalings_tensor, non_blocking=True
+        )
+        batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
+
+        self.batch_info = batch_info
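prepare_lora_batch stages the host-side metadata in pinned (page-locked) CPU tensors so the later copy_(..., non_blocking=True) calls can overlap with GPU work instead of forcing a synchronization. A standalone sketch of that pattern, independent of the sglang classes (buffer names and sizes are illustrative):

    # Illustrative only: pinned host staging + asynchronous host-to-device copy.
    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Pre-allocated device buffer (analogous to batch_info.weight_indices).
        device_buf = torch.empty(8, dtype=torch.int32, device=device)

        # Stage the Python list in pinned CPU memory so the copy does not
        # block the host thread.
        host_vals = torch.tensor(
            [3, 1, 4, 1, 5, 9, 2, 6], dtype=torch.int32, pin_memory=True, device="cpu"
        )
        device_buf.copy_(host_vals, non_blocking=True)
        # The copy is ordered on the current CUDA stream, so kernels launched
        # afterwards on this stream see the updated values.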
File: LoRA layer wrappers (ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA)

@@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        shard_size = self.base_layer.output_partition_sizes[0]
+        self.output_offset = torch.tensor(
+            [
+                0,
+                shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )

     def set_lora_info(
         self,
@@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_output = self.lora_backend.run_lora_b_sgemm(
             x=lora_a_output,
             weights=self.B_buffer,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
@@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.A_buffer_gate_up = A_buffer
         self.B_buffer_gate_up = B_buffer

+        shard_size = self.base_layer.output_partition_sizes[0]
+        self.output_offset = torch.tensor(
+            [
+                0,
+                shard_size,
+                2 * shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
+
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_output = self.lora_backend.run_gate_up_lora(
             x=x,
             gate_up_lora_a=self.A_buffer_gate_up,
             gate_up_lora_b=self.B_buffer_gate_up,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
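The output_offset tensor passed to the backend marks where each fused sub-projection begins in the LoRA-B output; for the merged gate_up projection the boundaries are [0, shard_size, 2 * shard_size]. A rough numeric sketch of that interpretation (the concrete value of shard_size is made up, not from the diff):

    # Illustrative only: how [0, shard_size, 2 * shard_size] partitions the
    # fused gate_up output. shard_size = 11008 is an example value.
    import torch

    shard_size = 11008
    output_offset = torch.tensor([0, shard_size, 2 * shard_size], dtype=torch.int32)

    fused_out = torch.randn(4, 2 * shard_size)  # (num_tokens, gate + up)
    gate = fused_out[:, output_offset[0] : output_offset[1]]  # [0, shard_size)
    up = fused_out[:, output_offset[1] : output_offset[2]]    # [shard_size, 2 * shard_size)
    assert gate.shape[1] == up.shape[1] == shard_size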
@@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
+        output_size = self.base_layer.output_size
+        self.output_offset = torch.tensor(
+            [
+                0,
+                output_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             x=lora_a_output,
             weights=self.B_buffer,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
File: LoRAAdapter

@@ -28,6 +28,9 @@ from torch import nn
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+
+# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader

@@ -156,7 +159,7 @@ class LoRAAdapter(nn.Module):
                 gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                 if up_name not in weights:
                     weights[up_name] = torch.zeros_like(weights[weight_name])
-                    assert self.lora_backend.name == "triton", (
+                    assert isinstance(self.lora_backend, TritonLoRABackend), (
                         f"LoRA weight initialization currently only supported for 'triton' backend. "
                         f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                         f"or consider implementing custom initialization logic for other backends."
File: LoRAManager

@@ -69,7 +69,10 @@ class LoRAManager:
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
-        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
+        self.lora_backend: BaseLoRABackend = backend_type(
+            max_loras_per_batch=max_loras_per_batch,
+            device=self.device,
+        )

         # Initialize mutable internal state of the LoRAManager.
         self.init_state(
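Backends are now constructed with the adapter-pool size and target device rather than the backend name string. A minimal usage sketch of the new constructor path (the literal arguments 8 and "cuda" are illustrative, not taken from the diff):

    # Illustrative only: resolving and constructing a backend under the new
    # signature introduced by this refactor.
    import torch

    from sglang.srt.lora.backend.base_backend import get_backend_from_name

    backend_cls = get_backend_from_name("triton")
    backend = backend_cls(max_loras_per_batch=8, device=torch.device("cuda"))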
@@ -82,29 +85,22 @@ class LoRAManager:
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
         with torch.device("cuda"):
             self.cuda_graph_batch_info = LoRABatchInfo(
-                bs=self.max_bs_in_cuda_graph,
-                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
-                seg_indptr=torch.zeros(
-                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
-                ),
+                bs=max_bs_in_cuda_graph,
+                use_cuda_graph=True,
+                num_segments=None,
+                seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
                 max_len=1,
-                weight_indices=torch.zeros(
-                    self.max_bs_in_cuda_graph, dtype=torch.int32
-                ),
+                weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
                 lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
                 scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
             )

-        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
-        # across batches.
-        self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
-        torch.cumsum(
-            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
-            dim=0,
-            out=self.cuda_graph_batch_info.seg_indptr[
-                1 : self.max_bs_in_cuda_graph + 1
-            ],
-        )
+        self.lora_backend.init_cuda_graph_batch_info(
+            cuda_graph_batch_info=self.cuda_graph_batch_info,
+            max_bs_in_cuda_graph=max_bs_in_cuda_graph,
+        )

     def create_lora_update_result(
         self, success: bool, error_message: str = ""
@@ -232,7 +228,6 @@ class LoRAManager:
         return required_slots <= mem_pool_vacancy

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-
         # Load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_ids)

@@ -247,102 +242,30 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size

-        def transfer_adapter_info(
-            weight_indices_out: torch.Tensor,
-            lora_ranks_out: torch.Tensor,
-            scalings_out: torch.Tensor,
-        ):
-            """
-            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
-            to device (CUDA) asynchronously.
-            """
-            weight_indices = [0] * len(forward_batch.lora_ids)
-            lora_ranks = [0] * self.max_loras_per_batch
-            scalings = [0] * self.max_loras_per_batch
-            for i, uid in enumerate(forward_batch.lora_ids):
-                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
-                if uid is not None:
-                    lora = self.loras[uid]
-                    lora_ranks[weight_indices[i]] = lora.config.r
-                    scalings[weight_indices[i]] = lora.scaling
-
-            # Use pinned memory to avoid synchronizations during host-to-device transfer
-            weight_indices_tensor = torch.tensor(
-                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            lora_ranks_tensor = torch.tensor(
-                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            scalings_tensor = torch.tensor(
-                scalings, dtype=torch.float, pin_memory=True, device="cpu"
-            )
-
-            # Copy to device tensors asynchronously
-            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
-            lora_ranks_out[: self.max_loras_per_batch].copy_(
-                lora_ranks_tensor, non_blocking=True
-            )
-            scalings_out[: self.max_loras_per_batch].copy_(
-                scalings_tensor, non_blocking=True
-            )
-
-        if (
+        use_cuda_graph = (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
             and forward_batch.forward_mode.is_cuda_graph()
-        ):
-            # Do in-place updates when CUDA graph is enabled and the batch forward mode
-            # could use CUDA graph.
-
-            transfer_adapter_info(
-                self.cuda_graph_batch_info.weight_indices,
-                self.cuda_graph_batch_info.lora_ranks,
-                self.cuda_graph_batch_info.scalings,
-            )
-
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.max_len = 1
-            batch_info = self.cuda_graph_batch_info
-        else:
-            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
-            )
-            transfer_adapter_info(
-                weight_indices,
-                lora_ranks,
-                scalings,
-            )
-
-            seg_lens = (
-                forward_batch.extend_seq_lens
-                if forward_batch.forward_mode.is_extend()
-                else torch.ones(bs, device=self.device)
-            )
-
-            max_len = (
-                # Calculate max_len from the CPU copy to avoid D2H transfer.
-                max(forward_batch.extend_seq_lens_cpu)
-                if forward_batch.forward_mode.is_extend()
-                else 1
-            )
-
-            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-
-            batch_info = LoRABatchInfo(
-                bs=bs,
-                seg_lens=seg_lens,
-                seg_indptr=seg_indptr,
-                max_len=max_len,
-                weight_indices=weight_indices,
-                lora_ranks=lora_ranks,
-                scalings=scalings,
-            )
-        self.lora_backend.set_batch_info(batch_info)
+        )
+
+        weight_indices = [0] * len(forward_batch.lora_ids)
+        lora_ranks = [0] * self.max_loras_per_batch
+        scalings = [0] * self.max_loras_per_batch
+        for i, uid in enumerate(forward_batch.lora_ids):
+            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+            if uid is not None:
+                lora = self.loras[uid]
+                lora_ranks[weight_indices[i]] = lora.config.r
+                scalings[weight_indices[i]] = lora.scaling
+        # Do in-place updates when CUDA graph is enabled and the batch forward mode
+        # could use CUDA graph.
+        self.lora_backend.prepare_lora_batch(
+            forward_batch=forward_batch,
+            weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
+            batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
+        )

     def update_lora_info(self):
         """
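When the batch can run under CUDA Graph, the backend must update the pre-allocated cuda_graph_batch_info tensors in place (via copy_) rather than binding new tensors, because a captured graph replays against fixed storage. A rough sketch of the distinction, using toy tensors rather than the sglang objects:

    # Illustrative only: why in-place copy_ matters for CUDA Graph replay.
    import torch

    buf = torch.zeros(4, dtype=torch.int32)

    def update_in_place(values):
        # Same storage before and after: safe for a captured graph.
        buf[: len(values)].copy_(torch.tensor(values, dtype=torch.int32))

    def update_by_rebinding(values):
        # Allocates a new tensor; a captured graph would keep reading the old one.
        return torch.tensor(values, dtype=torch.int32)

    update_in_place([1, 2, 3])
    assert buf.tolist() == [1, 2, 3, 0]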
File: sglang/srt/lora/utils.py

@@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig

 @dataclass
 class LoRABatchInfo:
+    # The forward mode is using CUDA Graph.
+    use_cuda_graph: bool
+
     # Batch size
     bs: int

-    # Lengths of each sequence in shape (bs,)
-    seg_lens: torch.Tensor
+    # Number of segments. For triton backend, it is equal to batch size.
+    num_segments: int

-    # Indice pointers of each sequence in shape (bs + 1, )
+    # Indice pointers of each segment in shape (num_segments + 1, )
     seg_indptr: torch.Tensor

-    # Maximum sequence length of current batch
-    max_len: int
-
-    # The index of lora adapter used by each sequence, in shape (bs,)
+    # The index of lora adapter used by each segment, in shape (num_segments,)
     weight_indices: torch.Tensor

     # ranks of each lora adapter, in shape (lora_num,)
@@ -31,6 +31,15 @@ class LoRABatchInfo:
     # scaling of each lora adapter, in shape (lora_num,)
     scalings: torch.Tensor

+    # Lengths of each segments in shape (num_segments,)
+    seg_lens: Optional[torch.Tensor]
+
+    # Maximum segment length of current batch
+    max_len: Optional[int]
+
+    # The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
+    permutation: Optional[torch.Tensor]
+

 class LoRAType(Enum):
     LORA_A = 0
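For reference, seg_lens holds the per-segment token counts and seg_indptr is their exclusive prefix sum, so segment i covers rows seg_indptr[i]:seg_indptr[i+1] of the flattened token dimension. A tiny worked example (numbers illustrative, not from the diff):

    # Illustrative only: segment bookkeeping as used by LoRABatchInfo.
    import torch

    seg_lens = torch.tensor([3, 1, 2], dtype=torch.int32)    # three segments
    seg_indptr = torch.zeros(len(seg_lens) + 1, dtype=torch.int32)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)            # -> [0, 3, 4, 6]
    max_len = int(seg_lens.max())                             # -> 3

    # Segment 1 owns rows seg_indptr[1]:seg_indptr[2] of the flattened tokens.
    assert (seg_indptr.tolist(), max_len) == ([0, 3, 4, 6], 3)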