Fix flaky LoRA issues and add multi-batch tests (#5957)
This commit is contained in:
@@ -156,18 +156,15 @@ class LoRAManager:
|
||||
# set up batch info shared by all lora modules
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
|
||||
# Do in-place updates when CUDA graph is enabled. Note that
|
||||
# if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
|
||||
# will also use these preallocated buffers, no matter whether
|
||||
# the batch can use CUDA graph or not.
|
||||
if (
|
||||
hasattr(self, "max_bs_in_cuda_graph")
|
||||
and bs <= self.max_bs_in_cuda_graph
|
||||
and forward_batch.forward_mode.is_cuda_graph()
|
||||
):
|
||||
# Do in-place updates when CUDA graph is enabled and the batch forward mode
|
||||
# can use CUDA graph.
|
||||
self.cuda_graph_batch_info.bs = bs
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
self.cuda_graph_batch_info.seg_lens[:bs].copy_(
|
||||
forward_batch.extend_seq_lens
|
||||
)
|
||||
else:
|
||||
self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
|
||||
self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
|
||||
torch.cumsum(
|
||||
self.cuda_graph_batch_info.seg_lens[:bs],
|
||||
dim=0,
|
||||
@@ -201,10 +198,10 @@ class LoRAManager:
|
||||
max_len = int(torch.max(seg_lens))
|
||||
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||
|
||||
lora_ranks = torch.empty(
|
||||
lora_ranks = torch.zeros(
|
||||
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
|
||||
)
|
||||
scalings = torch.empty(
|
||||
scalings = torch.zeros(
|
||||
(self.max_loras_per_batch,), dtype=torch.float, device="cuda"
|
||||
)
|
||||
for i, lora_path in enumerate(forward_batch.lora_paths):
|
||||
|
||||
Reference in New Issue
Block a user