Fix flaky issues of lora and add multi batch tests (#5957)

Author:    Qiaolin Yu
Date:      2025-05-04 16:11:40 -04:00
Committer: GitHub
Parent:    2b63798c7d
Commit:    3042f1da61

4 changed files with 205 additions and 96 deletions


@@ -156,18 +156,15 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
-        if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
-            # Do in-place updates when CUDA graph is enabled. Note that
-            # if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
-            # will also use these preallocated buffers, no matter whether
-            # the batch can use CUDA graph or not.
+        if (
+            hasattr(self, "max_bs_in_cuda_graph")
+            and bs <= self.max_bs_in_cuda_graph
+            and forward_batch.forward_mode.is_cuda_graph()
+        ):
+            # Do in-place updates when CUDA graph is enabled and the batch
+            # forward mode could use CUDA graph.
             self.cuda_graph_batch_info.bs = bs
-            if forward_batch.forward_mode.is_extend():
-                self.cuda_graph_batch_info.seg_lens[:bs].copy_(
-                    forward_batch.extend_seq_lens
-                )
-            else:
-                self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
             torch.cumsum(
                 self.cuda_graph_batch_info.seg_lens[:bs],
                 dim=0,
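This hunk removes one source of flakiness: previously, any batch with bs <= max_bs_in_cuda_graph wrote into the preallocated CUDA-graph buffers, including extend (prefill) batches that are never replayed through a captured graph. Gating the in-place update on forward_batch.forward_mode.is_cuda_graph() restricts buffer reuse to batches that can actually run under a captured graph, where every request decodes exactly one token, so seg_lens can simply be filled with ones and the extend branch disappears. A minimal sketch of the resulting pattern, using stand-in names (CudaGraphBatchInfo, update_cuda_graph_batch_info, is_cuda_graph_mode) rather than sglang's exact definitions:

```python
import torch
from dataclasses import dataclass


@dataclass
class CudaGraphBatchInfo:
    # Stand-in for the preallocated per-graph buffers, created once and
    # mutated in place so captured graph replays see fresh batch metadata.
    bs: int
    seg_lens: torch.Tensor    # shape: (max_bs_in_cuda_graph,)
    seg_indptr: torch.Tensor  # shape: (max_bs_in_cuda_graph + 1,)


def update_cuda_graph_batch_info(
    info: CudaGraphBatchInfo, bs: int, max_bs: int, is_cuda_graph_mode: bool
) -> bool:
    """Return True if the shared buffers were updated in place."""
    # Only touch the shared buffers when this batch can actually be replayed
    # by a captured CUDA graph; otherwise the caller builds fresh tensors
    # per batch and no shared state is mutated.
    if not (bs <= max_bs and is_cuda_graph_mode):
        return False
    info.bs = bs
    info.seg_lens[:bs].fill_(1)  # decode: each request generates one token
    torch.cumsum(info.seg_lens[:bs], dim=0, out=info.seg_indptr[1 : bs + 1])
    return True


# Example: an 8-slot buffer updated for a batch of 3 decode requests.
info = CudaGraphBatchInfo(
    bs=0,
    seg_lens=torch.zeros(8, dtype=torch.int64),
    seg_indptr=torch.zeros(9, dtype=torch.int64),
)
assert update_cuda_graph_batch_info(info, bs=3, max_bs=8, is_cuda_graph_mode=True)
print(info.seg_indptr[:4])  # tensor([0, 1, 2, 3])
```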
@@ -201,10 +198,10 @@ class LoRAManager:
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
-        lora_ranks = torch.empty(
+        lora_ranks = torch.zeros(
             (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
         )
-        scalings = torch.empty(
+        scalings = torch.zeros(
             (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
         )
         for i, lora_path in enumerate(forward_batch.lora_paths):
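This hunk swaps torch.empty for torch.zeros when building the per-slot lora_ranks and scalings tables. The loop that follows only writes the slots occupied by adapters in the current batch, so with torch.empty, any slot the loop never touches holds whatever bytes the allocator returned; a kernel that reads such a slot sees values that change from run to run, which matches the flakiness the commit title describes. A self-contained illustration of the hazard, with hypothetical slot assignments standing in for the real loop over forward_batch.lora_paths:

```python
import torch

# Hypothetical setup: 8 adapter slots, but only two are written by the
# batch-preparation loop (slot -> (rank, scaling) pairs are illustrative).
max_loras_per_batch = 8
occupied = {0: (16, 0.5), 2: (8, 1.0)}

lora_ranks = torch.zeros((max_loras_per_batch,), dtype=torch.int64)
scalings = torch.zeros((max_loras_per_batch,), dtype=torch.float)
for slot, (rank, scaling) in occupied.items():
    lora_ranks[slot] = rank
    scalings[slot] = scaling

# Slot 1 was never written. With torch.zeros it reads as rank 0 / scaling 0.0,
# which a LoRA kernel can treat as "no adapter, contribute nothing".
print(lora_ranks[1].item(), scalings[1].item())  # 0 0.0

# Had the buffers been created with torch.empty, slot 1 would contain
# uninitialized memory, so a kernel indexing an unoccupied slot would read
# garbage ranks/scalings that vary between runs -- nondeterministic output.
```

Zero is a safe sentinel here because a rank of 0 makes the low-rank delta contribute nothing, so unoccupied slots can no longer perturb the result regardless of which slots a given batch happens to use.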