Feat: support cuda graph for LoRA (#4115)

Co-authored-by: Beichen Ma <mabeichen12@gmail.com>
This commit is contained in:
Qiaolin Yu
2025-04-29 02:30:44 -04:00
committed by GitHub
parent 2c3ea29476
commit 8c0cfca87d
13 changed files with 366 additions and 55 deletions

View File

@@ -136,11 +136,19 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.set_lora = True
self.A_buffer_gate_up = A_buffer
if self.lora_backend.fuse_stacked_lora_b:
# TODO: avoid using contiguous() in GPU.
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
self.B_buffer_gate_up = torch.cat(
(B_buffer[0], B_buffer[1]), dim=-2
).contiguous()
if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
self.B_buffer_gate_up = torch.empty(
(
B_buffer[0].shape[0],
2 * B_buffer[0].shape[1],
B_buffer[0].shape[2],
),
dtype=B_buffer[0].dtype,
device=B_buffer[0].device,
)
self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
else:
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
@@ -171,7 +179,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def init__(
def __init__(
self,
base_layer: QKVParallelLinear,
lora_backend: BaseLoRABackend,
@@ -194,12 +202,30 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
self.B_buffer_qkv = torch.empty(
(
B_buffer_q[0].shape[0],
output_dim_q + 2 * output_dim_kv,
B_buffer_q[0].shape[2],
),
dtype=B_buffer_q[0].dtype,
device=B_buffer_q[0].device,
)
self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
B_buffer_kv[0]
)
self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
B_buffer_kv[1]
)
# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
if not hasattr(self, "output_offset") or self.output_offset is None:
self.output_offset = torch.empty(
4, dtype=torch.int32, device=B_buffer_q.device
)
self.output_offset[:4] = torch.tensor(
[
0,
output_dim_q,

View File

@@ -72,6 +72,23 @@ class LoRAManager:
self.init_loras()
self.init_lora_memory_pool()
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
    """Preallocate the shared LoRABatchInfo buffers used with CUDA graphs.

    Buffers are sized for the largest capturable batch
    (``max_bs_in_cuda_graph``) and per-adapter slots
    (``self.max_loras_per_batch``); they are presumably updated in place
    later so their device addresses stay stable across graph replays.
    """
    self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
    max_bs = self.max_bs_in_cuda_graph
    num_slots = self.max_loras_per_batch
    # Allocate every tensor directly on the GPU.
    with torch.device("cuda"):
        seg_lens = torch.zeros(max_bs, dtype=torch.int32)
        seg_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)
        weight_indices = torch.zeros(max_bs, dtype=torch.int32)
        lora_ranks = torch.zeros(num_slots, dtype=torch.int32)
        scalings = torch.zeros(num_slots, dtype=torch.float)
        self.cuda_graph_batch_info = LoRABatchInfo(
            bs=max_bs,
            seg_lens=seg_lens,
            seg_indptr=seg_indptr,
            max_len=0,
            weight_indices=weight_indices,
            lora_ranks=lora_ranks,
            scalings=scalings,
        )
def init_loras(self):
# Config of each LoRA adapter
self.configs: Dict[str, LoRAConfig] = {}
@@ -140,39 +157,73 @@ class LoRAManager:
if cur_uids == set([None]):
return
# set up batch info shared by all lora moruldes
# set up batch info shared by all lora modules
bs = forward_batch.batch_size
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device=self.device)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
lora_ranks = torch.empty(
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
)
scalings = torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device="cuda"
)
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
scalings[weight_indices[i]] = lora.scaling
if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
# Do in-place updates when CUDA graph is enabled. Note that
# if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
# will also use these preallocated buffers, no matter whether
# the batch can use CUDA graph or not.
self.cuda_graph_batch_info.bs = bs
if forward_batch.forward_mode.is_extend():
self.cuda_graph_batch_info.seg_lens[:bs].copy_(
forward_batch.extend_seq_lens
)
else:
self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
torch.cumsum(
self.cuda_graph_batch_info.seg_lens[:bs],
dim=0,
out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
)
self.cuda_graph_batch_info.max_len = int(
torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
)
batch_info = LoRABatchInfo(
bs=bs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
max_len=max_len,
weight_indices=weight_indices,
lora_ranks=lora_ranks,
scalings=scalings,
)
for i, lora_path in enumerate(forward_batch.lora_paths):
self.cuda_graph_batch_info.weight_indices[i] = (
self.memory_pool.get_buffer_id(lora_path)
)
lora = self.loras[lora_path]
self.cuda_graph_batch_info.lora_ranks[
self.cuda_graph_batch_info.weight_indices[i]
] = lora.config.hf_config["r"]
self.cuda_graph_batch_info.scalings[
self.cuda_graph_batch_info.weight_indices[i]
] = lora.scaling
batch_info = self.cuda_graph_batch_info
else:
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device=self.device)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
lora_ranks = torch.empty(
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
)
scalings = torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device="cuda"
)
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
scalings[weight_indices[i]] = lora.scaling
batch_info = LoRABatchInfo(
bs=bs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
max_len=max_len,
weight_indices=weight_indices,
lora_ranks=lora_ranks,
scalings=scalings,
)
self.lora_backend.set_batch_info(batch_info)
# call set_lora_info for each lora modules

View File

@@ -220,6 +220,9 @@ class CudaGraphRunner:
if self.enable_torch_compile:
set_torch_compile_config()
if self.model_runner.server_args.lora_paths is not None:
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -403,6 +406,13 @@ class CudaGraphRunner:
self.capture_hidden_mode = (
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
)
if self.model_runner.server_args.lora_paths is not None:
# Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
# different logic to handle lora, so we need to set `lora_paths` to a list of non-None
# values if lora is enabled.
lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
else:
lora_paths = None
forward_batch = ForwardBatch(
forward_mode=self.capture_forward_mode,
@@ -424,8 +434,12 @@ class CudaGraphRunner:
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=self.capture_hidden_mode,
lora_paths=lora_paths,
)
if lora_paths is not None:
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs,

View File

@@ -1242,7 +1242,6 @@ class ServerArgs:
assert (
self.max_loras_per_batch > 0
# FIXME
and (self.lora_paths is None or self.disable_cuda_graph)
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"