[Perf] Refactor LoRAManager to eliminate stream syncs and redundant computations (#6994)
@@ -81,7 +81,7 @@ class LoRAManager:
                 seg_indptr=torch.zeros(
                     self.max_bs_in_cuda_graph + 1, dtype=torch.int32
                 ),
-                max_len=0,
+                max_len=1,
                 weight_indices=torch.zeros(
                     self.max_bs_in_cuda_graph, dtype=torch.int32
                 ),
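(Note: with CUDA graphs each captured batch is decode-only, so every segment length is 1; seeding max_len with 1 instead of 0 matches the constant seg_lens/seg_indptr initialization added in the next hunk, so max_len no longer needs to be recomputed per batch.)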
@@ -89,6 +89,17 @@ class LoRAManager:
                 scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
             )
 
+            # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+            # across batches.
+            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
+            torch.cumsum(
+                self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
+                dim=0,
+                out=self.cuda_graph_batch_info.seg_indptr[
+                    1 : self.max_bs_in_cuda_graph + 1
+                ],
+            )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
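Note on the pattern: torch.cumsum with an out= argument writes the prefix sums into a preallocated buffer in place, which is what lets seg_indptr be computed once here and then reused unchanged by every CUDA graph replay. A minimal standalone sketch of the idiom (the batch size 4 is hypothetical):

import torch

max_bs = 4  # hypothetical max batch size captured by the CUDA graph
seg_lens = torch.ones(max_bs, dtype=torch.int32)
seg_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)

# Writes [1, 2, 3, 4] into seg_indptr[1:] in place; seg_indptr becomes
# [0, 1, 2, 3, 4] with no new allocation.
torch.cumsum(seg_lens, dim=0, out=seg_indptr[1:])
print(seg_indptr)  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)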
@@ -159,6 +170,45 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
+        def transfer_adapter_info(
+            weight_indices_out: torch.Tensor,
+            lora_ranks_out: torch.Tensor,
+            scalings_out: torch.Tensor,
+        ):
+            """
+            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
+            to device (CUDA) asynchronously.
+            """
+            weight_indices = [0] * len(forward_batch.lora_paths)
+            lora_ranks = [0] * self.max_loras_per_batch
+            scalings = [0] * self.max_loras_per_batch
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+
+            # Use pinned memory to avoid synchronizations during host-to-device transfer
+            weight_indices_tensor = torch.tensor(
+                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            lora_ranks_tensor = torch.tensor(
+                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            scalings_tensor = torch.tensor(
+                scalings, dtype=torch.float, pin_memory=True, device="cpu"
+            )
+
+            # Copy to device tensors asynchronously
+            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
+            lora_ranks_out[: self.max_loras_per_batch].copy_(
+                lora_ranks_tensor, non_blocking=True
+            )
+            scalings_out[: self.max_loras_per_batch].copy_(
+                scalings_tensor, non_blocking=True
+            )
+
         if (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
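Note on the pattern: the copies above are only truly asynchronous because the staging tensors live in pinned (page-locked) host memory; copy_(..., non_blocking=True) from ordinary pageable memory may silently fall back to a synchronizing path. A minimal sketch of the idiom outside the manager (values are placeholders; requires a CUDA device):

import torch

# Staging buffer in pinned host memory (hypothetical adapter indices).
host = torch.tensor([0, 2, 1, 0], dtype=torch.int32, pin_memory=True, device="cpu")
dev = torch.empty(4, dtype=torch.int32, device="cuda")

# Enqueues the H2D copy on the current CUDA stream and returns immediately,
# instead of stalling the CPU until the transfer finishes.
dev.copy_(host, non_blocking=True)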
@@ -166,51 +216,46 @@ class LoRAManager:
         ):
             # Do in-place updates when CUDA graph is enabled and the batch forward mode
             # could use CUDA graph.
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
-            torch.cumsum(
-                self.cuda_graph_batch_info.seg_lens[:bs],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
-            )
-            self.cuda_graph_batch_info.max_len = 1
-
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                self.cuda_graph_batch_info.weight_indices[i] = (
-                    self.memory_pool.get_buffer_id(lora_path)
-                )
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    self.cuda_graph_batch_info.lora_ranks[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.config.hf_config["r"]
-                    self.cuda_graph_batch_info.scalings[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.scaling
+            transfer_adapter_info(
+                self.cuda_graph_batch_info.weight_indices,
+                self.cuda_graph_batch_info.lora_ranks,
+                self.cuda_graph_batch_info.scalings,
+            )
+
+            self.cuda_graph_batch_info.bs = bs
+            self.cuda_graph_batch_info.max_len = 1
             batch_info = self.cuda_graph_batch_info
         else:
+            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
+            lora_ranks = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+            )
+            scalings = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+            )
+            transfer_adapter_info(
+                weight_indices,
+                lora_ranks,
+                scalings,
+            )
+
             seg_lens = (
                 forward_batch.extend_seq_lens
                 if forward_batch.forward_mode.is_extend()
                 else torch.ones(bs, device=self.device)
             )
+
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+
             seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-            max_len = int(torch.max(seg_lens))
-            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
-
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-            )
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-                    scalings[weight_indices[i]] = lora.scaling
             batch_info = LoRABatchInfo(
                 bs=bs,
                 seg_lens=seg_lens,
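Note on the max_len change: int(torch.max(seg_lens)) on a device tensor forces a device-to-host transfer of the scalar, blocking the CPU until the stream catches up, while max(forward_batch.extend_seq_lens_cpu) is pure host arithmetic. A toy illustration (hypothetical lengths; requires a CUDA device):

import torch

seg_lens_gpu = torch.tensor([3, 7, 2], device="cuda")
seg_lens_cpu = [3, 7, 2]  # host-side mirror, analogous to extend_seq_lens_cpu

max_len_sync = int(torch.max(seg_lens_gpu))  # implicit D2H copy: synchronizes the stream
max_len_host = max(seg_lens_cpu)             # no GPU involvement at all
assert max_len_sync == max_len_host == 7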
@@ -132,12 +132,13 @@ class LoRAMemoryPool:
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == "":
-                    return buffer_id
+                    return buffer_id, ""
 
             for buffer_id in range(self.max_loras_per_batch):
                 # Evict unneeded lora
                 if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
-                    return buffer_id
+                    return buffer_id, self.buffer_id_to_uid[buffer_id]
 
             raise ValueError(
                 "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
@@ -145,9 +146,7 @@ class LoRAMemoryPool:
 
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
-                buffer_id = get_available_buffer_slot()
+                buffer_id, evicted_lora_uid = get_available_buffer_slot()
+                if evicted_lora_uid != "":
+                    self.uid_to_buffer_id.pop(evicted_lora_uid)
                 self.load_lora_weight_to_buffer(
                     uid, buffer_id, lora_adapters.get(uid, None)
                 )
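Note on the new contract: get_available_buffer_slot now returns a (buffer_id, evicted_lora_uid) pair and defers the uid_to_buffer_id bookkeeping to the caller, which pops the evicted uid only when a slot was actually reclaimed. A simplified standalone sketch of the contract (names and state are illustrative, not the actual LoRAMemoryPool):

from typing import List, Set, Tuple

def get_available_buffer_slot(
    buffer_id_to_uid: List[str], cur_uids: Set[str]
) -> Tuple[int, str]:
    # Prioritize empty slots ("" marks an empty slot).
    for buffer_id, uid in enumerate(buffer_id_to_uid):
        if uid == "":
            return buffer_id, ""
    # Otherwise evict a slot whose adapter the current batch does not need.
    for buffer_id, uid in enumerate(buffer_id_to_uid):
        if uid not in cur_uids:
            return buffer_id, uid
    raise ValueError("No available buffer slots found.")

# The caller performs the eviction bookkeeping only when something was evicted.
slots = ["adapter_a", "", "adapter_b"]
uid_to_buffer_id = {"adapter_a": 0, "adapter_b": 2}
buffer_id, evicted = get_available_buffer_slot(slots, {"adapter_a"})
if evicted != "":
    uid_to_buffer_id.pop(evicted)
print(buffer_id, repr(evicted))  # -> 1 '' (an empty slot was found; nothing evicted)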