Feat: support cuda graph for LoRA (#4115)
Co-authored-by: Beichen Ma <mabeichen12@gmail.com>
This commit is contained in:
@@ -136,11 +136,19 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
self.set_lora = True
|
||||
self.A_buffer_gate_up = A_buffer
|
||||
if self.lora_backend.fuse_stacked_lora_b:
|
||||
# TODO: avoid using contiguous() in GPU.
|
||||
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
|
||||
self.B_buffer_gate_up = torch.cat(
|
||||
(B_buffer[0], B_buffer[1]), dim=-2
|
||||
).contiguous()
|
||||
if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
|
||||
self.B_buffer_gate_up = torch.empty(
|
||||
(
|
||||
B_buffer[0].shape[0],
|
||||
2 * B_buffer[0].shape[1],
|
||||
B_buffer[0].shape[2],
|
||||
),
|
||||
dtype=B_buffer[0].dtype,
|
||||
device=B_buffer[0].device,
|
||||
)
|
||||
self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
|
||||
self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
|
||||
else:
|
||||
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
|
||||
|
||||
@@ -171,7 +179,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
|
||||
|
||||
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def init__(
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: QKVParallelLinear,
|
||||
lora_backend: BaseLoRABackend,
|
||||
@@ -194,12 +202,30 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
|
||||
|
||||
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
||||
self.B_buffer_qkv = torch.cat(
|
||||
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
|
||||
).contiguous()
|
||||
if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
|
||||
self.B_buffer_qkv = torch.empty(
|
||||
(
|
||||
B_buffer_q[0].shape[0],
|
||||
output_dim_q + 2 * output_dim_kv,
|
||||
B_buffer_q[0].shape[2],
|
||||
),
|
||||
dtype=B_buffer_q[0].dtype,
|
||||
device=B_buffer_q[0].device,
|
||||
)
|
||||
self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
|
||||
self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
|
||||
B_buffer_kv[0]
|
||||
)
|
||||
self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
|
||||
B_buffer_kv[1]
|
||||
)
|
||||
|
||||
# Offsets of q/k/v in output dimension
|
||||
self.output_offset = torch.tensor(
|
||||
if not hasattr(self, "output_offset") or self.output_offset is None:
|
||||
self.output_offset = torch.empty(
|
||||
4, dtype=torch.int32, device=B_buffer_q.device
|
||||
)
|
||||
self.output_offset[:4] = torch.tensor(
|
||||
[
|
||||
0,
|
||||
output_dim_q,
|
||||
|
||||
@@ -72,6 +72,23 @@ class LoRAManager:
|
||||
self.init_loras()
|
||||
self.init_lora_memory_pool()
|
||||
|
||||
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
|
||||
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
|
||||
with torch.device("cuda"):
|
||||
self.cuda_graph_batch_info = LoRABatchInfo(
|
||||
bs=self.max_bs_in_cuda_graph,
|
||||
seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
|
||||
seg_indptr=torch.zeros(
|
||||
self.max_bs_in_cuda_graph + 1, dtype=torch.int32
|
||||
),
|
||||
max_len=0,
|
||||
weight_indices=torch.zeros(
|
||||
self.max_bs_in_cuda_graph, dtype=torch.int32
|
||||
),
|
||||
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
|
||||
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
|
||||
)
|
||||
|
||||
def init_loras(self):
|
||||
# Config of each LoRA adapter
|
||||
self.configs: Dict[str, LoRAConfig] = {}
|
||||
@@ -140,39 +157,73 @@ class LoRAManager:
|
||||
if cur_uids == set([None]):
|
||||
return
|
||||
|
||||
# set up batch info shared by all lora moruldes
|
||||
# set up batch info shared by all lora modules
|
||||
bs = forward_batch.batch_size
|
||||
seg_lens = (
|
||||
forward_batch.extend_seq_lens
|
||||
if forward_batch.forward_mode.is_extend()
|
||||
else torch.ones(bs, device=self.device)
|
||||
)
|
||||
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
|
||||
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
|
||||
max_len = int(torch.max(seg_lens))
|
||||
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||
|
||||
lora_ranks = torch.empty(
|
||||
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
|
||||
)
|
||||
scalings = torch.empty(
|
||||
(self.max_loras_per_batch,), dtype=torch.float, device="cuda"
|
||||
)
|
||||
for i, lora_path in enumerate(forward_batch.lora_paths):
|
||||
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
|
||||
lora = self.loras[lora_path]
|
||||
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
|
||||
scalings[weight_indices[i]] = lora.scaling
|
||||
if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
|
||||
# Do in-place updates when CUDA graph is enabled. Note that
|
||||
# if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
|
||||
# will also use these preallocated buffers, no matter whether
|
||||
# the batch can use CUDA graph or not.
|
||||
self.cuda_graph_batch_info.bs = bs
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
self.cuda_graph_batch_info.seg_lens[:bs].copy_(
|
||||
forward_batch.extend_seq_lens
|
||||
)
|
||||
else:
|
||||
self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
|
||||
torch.cumsum(
|
||||
self.cuda_graph_batch_info.seg_lens[:bs],
|
||||
dim=0,
|
||||
out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
|
||||
)
|
||||
self.cuda_graph_batch_info.max_len = int(
|
||||
torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
|
||||
)
|
||||
|
||||
batch_info = LoRABatchInfo(
|
||||
bs=bs,
|
||||
seg_lens=seg_lens,
|
||||
seg_indptr=seg_indptr,
|
||||
max_len=max_len,
|
||||
weight_indices=weight_indices,
|
||||
lora_ranks=lora_ranks,
|
||||
scalings=scalings,
|
||||
)
|
||||
for i, lora_path in enumerate(forward_batch.lora_paths):
|
||||
self.cuda_graph_batch_info.weight_indices[i] = (
|
||||
self.memory_pool.get_buffer_id(lora_path)
|
||||
)
|
||||
lora = self.loras[lora_path]
|
||||
self.cuda_graph_batch_info.lora_ranks[
|
||||
self.cuda_graph_batch_info.weight_indices[i]
|
||||
] = lora.config.hf_config["r"]
|
||||
self.cuda_graph_batch_info.scalings[
|
||||
self.cuda_graph_batch_info.weight_indices[i]
|
||||
] = lora.scaling
|
||||
batch_info = self.cuda_graph_batch_info
|
||||
else:
|
||||
seg_lens = (
|
||||
forward_batch.extend_seq_lens
|
||||
if forward_batch.forward_mode.is_extend()
|
||||
else torch.ones(bs, device=self.device)
|
||||
)
|
||||
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
|
||||
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
|
||||
max_len = int(torch.max(seg_lens))
|
||||
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||
|
||||
lora_ranks = torch.empty(
|
||||
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
|
||||
)
|
||||
scalings = torch.empty(
|
||||
(self.max_loras_per_batch,), dtype=torch.float, device="cuda"
|
||||
)
|
||||
for i, lora_path in enumerate(forward_batch.lora_paths):
|
||||
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
|
||||
lora = self.loras[lora_path]
|
||||
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
|
||||
scalings[weight_indices[i]] = lora.scaling
|
||||
batch_info = LoRABatchInfo(
|
||||
bs=bs,
|
||||
seg_lens=seg_lens,
|
||||
seg_indptr=seg_indptr,
|
||||
max_len=max_len,
|
||||
weight_indices=weight_indices,
|
||||
lora_ranks=lora_ranks,
|
||||
scalings=scalings,
|
||||
)
|
||||
self.lora_backend.set_batch_info(batch_info)
|
||||
|
||||
# call set_lora_info for each lora modules
|
||||
|
||||
@@ -220,6 +220,9 @@ class CudaGraphRunner:
|
||||
if self.enable_torch_compile:
|
||||
set_torch_compile_config()
|
||||
|
||||
if self.model_runner.server_args.lora_paths is not None:
|
||||
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
|
||||
|
||||
# Graph inputs
|
||||
with torch.device("cuda"):
|
||||
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||
@@ -403,6 +406,13 @@ class CudaGraphRunner:
|
||||
self.capture_hidden_mode = (
|
||||
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
||||
)
|
||||
if self.model_runner.server_args.lora_paths is not None:
|
||||
# Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
|
||||
# different logic to handle lora, so we need to set `lora_paths` to a list of non-None
|
||||
# values if lora is enabled.
|
||||
lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
|
||||
else:
|
||||
lora_paths = None
|
||||
|
||||
forward_batch = ForwardBatch(
|
||||
forward_mode=self.capture_forward_mode,
|
||||
@@ -424,8 +434,12 @@ class CudaGraphRunner:
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
spec_info=spec_info,
|
||||
capture_hidden_mode=self.capture_hidden_mode,
|
||||
lora_paths=lora_paths,
|
||||
)
|
||||
|
||||
if lora_paths is not None:
|
||||
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||
bs,
|
||||
|
||||
@@ -1242,7 +1242,6 @@ class ServerArgs:
|
||||
assert (
|
||||
self.max_loras_per_batch > 0
|
||||
# FIXME
|
||||
and (self.lora_paths is None or self.disable_cuda_graph)
|
||||
and (self.lora_paths is None or self.disable_radix_cache)
|
||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
||||
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
||||
|
||||
Reference in New Issue
Block a user