Fix lora batch processing when input lora_path contains None (#5930)

This commit is contained in:
Qiaolin Yu
2025-04-30 22:42:42 -04:00
committed by GitHub
parent 11383cec3c
commit 7bcd8b1cb2
4 changed files with 60 additions and 279 deletions

View File

@@ -153,10 +153,6 @@ class LoRAManager:
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
-# FIXME: Handle lora uid with None more safely
-if cur_uids == set([None]):
-    return
# set up batch info shared by all lora modules
bs = forward_batch.batch_size
@@ -185,13 +181,14 @@ class LoRAManager:
self.cuda_graph_batch_info.weight_indices[i] = (
self.memory_pool.get_buffer_id(lora_path)
)
-lora = self.loras[lora_path]
-self.cuda_graph_batch_info.lora_ranks[
-    self.cuda_graph_batch_info.weight_indices[i]
-] = lora.config.hf_config["r"]
-self.cuda_graph_batch_info.scalings[
-    self.cuda_graph_batch_info.weight_indices[i]
-] = lora.scaling
+if lora_path is not None:
+    lora = self.loras[lora_path]
+    self.cuda_graph_batch_info.lora_ranks[
+        self.cuda_graph_batch_info.weight_indices[i]
+    ] = lora.config.hf_config["r"]
+    self.cuda_graph_batch_info.scalings[
+        self.cuda_graph_batch_info.weight_indices[i]
+    ] = lora.scaling
batch_info = self.cuda_graph_batch_info
else:
seg_lens = (
@@ -212,9 +209,10 @@ class LoRAManager:
)
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-lora = self.loras[lora_path]
-lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-scalings[weight_indices[i]] = lora.scaling
+if lora_path is not None:
+    lora = self.loras[lora_path]
+    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+    scalings[weight_indices[i]] = lora.scaling
batch_info = LoRABatchInfo(
bs=bs,
seg_lens=seg_lens,

View File

@@ -423,9 +423,9 @@ class HFRunner:
)
del input_logits
-if lora_paths is not None and lora_paths[i] is not None:
-    # Unload the LoRA adapter if it is used
-    model.unload()
+if lora_paths is not None and lora_paths[i] is not None:
+    # Unload the LoRA adapter if it is used
+    model.unload()
return ModelOutput(
output_strs=output_strs,